杰瑞科技汇

神经网络MATLAB教程怎么学?

MATLAB 神经网络教程:从入门到实践

MATLAB 提供了强大的神经网络工具箱,它极大地简化了神经网络的设计、训练、仿真和应用过程,本教程将围绕以下几个核心部分展开:

神经网络MATLAB教程怎么学?-图1
(图片来源网络,侵删)
  1. 核心概念简介:了解神经网络的基本构成。
  2. 工具箱概览:认识 MATLAB 中神经网络的核心函数和 App。
  3. 基础教程:一步步构建你的第一个神经网络
    • 教程 1:拟合简单的一维函数
    • 教程 2:解决经典分类问题 - 鸢尾花
  4. 进阶应用:图像分类
  5. 高级主题与最佳实践
  6. 学习资源推荐

核心概念简介

在开始编码前,我们先快速回顾一下神经网络的基本术语,这有助于你理解后续的操作。

  • 神经元:神经网络的基本单元,接收输入,进行加权求和,通过激活函数产生输出。
  • :神经元的集合,神经网络通常由三层构成:
    • 输入层:接收原始数据。
    • 隐藏层:进行特征提取和转换,可以有多层。
    • 输出层:产生最终的预测结果。
  • 权重 和偏置:连接神经元的参数,神经网络的“知识”就存储在这些参数中,训练的过程就是不断调整它们。
  • 激活函数:为神经元引入非线性,使得网络能够学习复杂的模式,常用 sigmoid, tanh, ReLU
  • 损失函数:衡量模型预测值与真实值之间差距的函数,训练的目标是最小化损失函数
  • 反向传播:训练神经网络的核心算法,通过计算损失函数对各参数的梯度,来更新权重和偏置。
  • 优化器:决定如何根据梯度来更新参数的算法,常用 SGD, Adam, RMSprop

工具箱概览

MATLAB 神经网络工具箱主要通过两种方式使用:

A. 图形用户界面 App - 最适合初学者

MATLAB 提供了一个名为 Deep Network Designer 的可视化 App,让你可以通过拖拽的方式构建和训练网络,无需编写大量代码。

  • 如何打开:在 MATLAB 命令窗口输入 deepNetworkDesigner 并回车。
  • 优点
    • 直观,所见即所得。
    • 自动生成训练代码,方便学习和修改。
    • 集成了数据导入、网络构建、训练、分析和部署的全流程。
  • 适用场景:快速原型设计、教学、以及不熟悉代码的用户。

B. 命令行函数 - 最灵活和强大

这是专业用户和开发者最常用的方式,它提供了对网络构建和训练的完全控制权,核心函数包括:

神经网络MATLAB教程怎么学?-图2
(图片来源网络,侵删)
  • layer:定义网络层。
    • featureInputLayer: 输入层,指定输入数据的大小。
    • fullyConnectedLayer: 全连接层。
    • convolution2dLayer: 二维卷积层(用于图像)。
    • reluLayer: ReLU 激活层。
    • maxPooling2dLayer: 二维最大池化层。
    • fullyConnectedLayer: 输出层,指定类别数。
    • classificationLayer: 分类问题的输出层。
    • regressionLayer: 回归问题的输出层。
  • layerGraph:将多个层组合成一个有向无环图,构建复杂的网络结构。
  • trainingOptions:配置训练参数,如优化器 ('adam')、学习率 ('InitialLearnRate')、最大训练轮数 ('MaxEpochs')、训练数据划分 ('ValidationData') 等。
  • trainNetwork:核心训练函数,接收网络结构、训练数据和训练选项,开始训练过程。

基础教程:一步步构建你的第一个神经网络

我们将从两个经典的机器学习问题开始:函数拟合(回归)和分类。

教程 1:拟合简单的一维函数 (回归问题)

目标:训练一个神经网络来拟合 y = sin(x) 函数。

步骤

准备数据

神经网络MATLAB教程怎么学?-图3
(图片来源网络,侵删)
% 生成数据
x = (0:0.01:2*pi)';
T = sin(x); % 目标值
% 划分训练集和测试集
% 这里我们简单地将一部分数据作为训练,一部分作为验证
percentValidation = 20; % 20% 的数据用于验证
idx = randperm(size(x,1));
numValidation = round(percentValidation/100*size(x,1));
idxTrain = idx(1:end-numValidation);
idxValidation = idx(end-numValidation+1:end);
xTrain = x(idxTrain,:);
TTrain = T(idxTrain,:);
xValidation = x(idxValidation,:);
TValidation = T(idxValidation,:);

定义网络结构 对于回归问题,我们需要一个回归输出层。

% 定义网络层
layers = [
    featureInputLayer(1) % 输入层,1个特征
    fullyConnectedLayer(10) % 1个全连接层,10个神经元
    reluLayer % ReLU激活函数
    fullyConnectedLayer(10) % 第2个全连接层
    reluLayer
    fullyConnectedLayer(1) % 输出层,1个神经元
    regressionLayer % 回归问题的输出层
];

配置训练选项

options = trainingOptions('adam', ...
    'InitialLearnRate', 0.005, ...
    'MaxEpochs', 200, ...
    'MiniBatchSize', 128, ...
    'Shuffle', 'every-epoch', ...
    'Plots', 'training-progress', ... % 显示训练过程图
    'ValidationData', {xValidation, TValidation}, ...
    'ValidationFrequency', 30);

训练网络

net = trainNetwork(xTrain, TTrain, layers, options);

运行后,MATLAB 会自动弹出一个窗口,实时显示损失函数和准确率的变化。

进行预测和评估

% 用训练好的网络进行预测
yPred = predict(net, x);
% 绘制结果
figure;
plot(x, T, 'b-', 'LineWidth', 2); hold on;
plot(x, yPred, 'r--', 'LineWidth', 2);
xlabel('x');
ylabel('y');
legend('真实值 sin(x)', '网络预测值');'神经网络函数拟合结果');
grid on;

你会看到红色的虚线(预测值)非常紧密地跟随蓝色的实线(真实值),说明网络训练成功。


教程 2:解决经典分类问题 - 鸢尾花

目标:根据鸢尾花的花萼和花瓣的长度、宽度,将其分为三个不同的品种。

步骤

加载和准备数据 MATLAB 自带了鸢尾花数据集。

% 加载数据
load fisheriris
% 特征 (4个: 花萼长度/宽度, 花瓣长度/宽度)
X = meas; 
% 标签 (3个品种: setosa, versicolor, virginica)
Y = categorical(species);
% 划分训练集和测试集
cv = cvpartition(Y, 'HoldOut', 0.2); % 80%训练, 20%测试
XTrain = X(cv.training,:);
YTrain = Y(cv.training);
XTest = X(cv.test,:);
YTest = Y(cv.test);

定义网络结构 对于分类问题,我们需要一个分类输出层。

% 定义网络层
layers = [
    featureInputLayer(4) % 输入层,4个特征
    fullyConnectedLayer(10) % 1个全连接层
    reluLayer
    fullyConnectedLayer(10) % 第2个全连接层
    reluLayer
    fullyConnectedLayer(3) % 输出层,3个类别
    classificationLayer % 分类问题的输出层
];

配置训练选项

options = trainingOptions('adam', ...
    'InitialLearnRate', 0.001, ...
    'MaxEpochs', 100, ...
    'MiniBatchSize', 20, ...
    'Shuffle', 'every-epoch', ...
    'Plots', 'training-progress', ...
    'ValidationData', {XTest, YTest}, ...
    'ValidationFrequency', 5);

训练网络

net = trainNetwork(XTrain, YTrain, layers, options);

进行预测和评估

% 预测测试集
YPred = classify(net, XTest);
% 计算准确率
accuracy = sum(YPred == YTest) / numel(YTest);
fprintf('测试集准确率: %.2f%%\n', accuracy * 100);
% 混淆矩阵
figure;
confusionchart(YTest, YPred);

你会得到一个很高的准确率(>95%),并通过混淆矩阵直观地看到分类结果。


进阶应用:图像分类

卷积神经网络是处理图像任务的利器,下面是一个简单的手写数字识别示例。

目标:使用 MNIST 数据集训练一个 CNN 来识别手写数字 (0-9)。

步骤

加载数据 使用 digitTrainCellArray 函数加载 MNIST 数据。

% 加载数据
[XTrain, YTrain, XTest, YTest] = digitTrainCellArray();
% 数据需要调整为 4D 格式: [高度, 宽度, 通道数, 样本数]
XTrain = cat(4, XTrain{:});
XTest = cat(4, XTest{:});
% 标签需要转换为 one-hot 编码
YTrain = onehotencode(YTrain, 1);
YTest = onehotencode(YTest, 1);

定义 CNN 结构

layers = [
    imageInputLayer([28 28 1]) % 28x28 像素,1个通道(灰度图)
    % 第一个卷积块
    convolution2dLayer(3, 16, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    % 第二个卷积块
    convolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    % 全连接层
    fullyConnectedLayer(10)
    softmaxLayer % Softmax激活函数,用于多分类
    classificationLayer
];

配置和训练

options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.01, ...
    'MaxEpochs', 15, ...
    'MiniBatchSize', 128, ...
    'Shuffle', 'every-epoch', ...
    'Plots', 'training-progress', ...
    'ValidationData', {XTest, YTest});
net = trainNetwork(XTrain, YTrain, layers, options);

预测和评估

YPred = classify(net, XTest);
YTestLabels = onehotdecode(YTest, 1:10, 'Classname');
accuracy = sum(YPred == YTestLabels) / numel(YTestLabels);
fprintf('CNN 测试集准确率: %.2f%%\n', accuracy * 100);

这个 CNN 模型的准确率通常会达到 98% 以上。


高级主题与最佳实践

  • 数据预处理与增强
    • 归一化:将输入数据缩放到 [0, 1][-1, 1] 区间,有助于加速收敛。normalize 函数。
    • 数据增强:通过旋转、平移、缩放、翻转等方式人为扩充训练集,防止过拟合。imageDataAugmenteraugmentedImageDatastore 是关键工具。
  • 迁移学习:利用在大型数据集(如 ImageNet)上预训练好的网络(如 VGG16, ResNet, SqueezeNet),替换掉其最后的几层,用你自己的小数据集进行微调。alexnet, resnet50 等函数可以直接加载预训练模型。
  • 超参数调优:学习率、批次大小、网络层数/神经元数等都是超参数,可以使用 bayesopt (贝叶斯优化) 或 gridsearch (网格搜索) 等方法自动寻找最优组合。
  • 部署:训练好的模型可以通过 exportONNXNetworkexportNetworkToONNX 导出为 ONNX 格式,以便在其他框架中使用;也可以通过 codegen 生成 C/C++ 或 CUDA 代码,用于嵌入式设备或服务器部署。

学习资源推荐

  1. 官方文档 (首选)

  2. 官方示例

    • 在 MATLAB 命令窗口输入 help nnet,或直接在文档中浏览 "Deep Learning Toolbox Examples",有大量从简单到复杂的实例代码。
  3. 视频教程

    MATLAB 官方 YouTube 频道上有大量关于深度学习的视频教程,搜索 "MATLAB Deep Learning"。

  4. 社区支持

希望这份详细的教程能帮助你顺利开启 MATLAB 神经网络的学习之旅!祝你编码愉快!

分享:
扫描分享到社交APP
上一篇
下一篇