Extract layers from a neural network trained on a large dataset, and fine-tune them on a new dataset. This example fine-tunes on a subset of ImageNet.
This example retrains a SqueezeNet neural network using transfer learning. This network has been trained on over a million images, and can classify images into 1000 object categories (such as keyboard, coffee mug, pencil, and many animals). The network has learned rich feature representations for a wide range of images. The network takes an image as input and outputs a prediction score for each of these classes.
Performing transfer learning and fine-tuning of a pretrained neural network typically requires less data, is much faster, and is easier than training a neural network from scratch.
To adapt a pretrained neural network for a new task, replace the last few layers (the network head) so that it outputs prediction scores for each of the classes for the new task. This diagram outlines the architecture of a neural network that makes predictions for the original set of classes, and illustrates how to edit the network so that it outputs predictions for the new classes.
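For illustration, here is a minimal sketch of replacing the head manually. It assumes that the last learnable layer of SqueezeNet is the 1-by-1 convolution named "conv10"; the imagePretrainedNetwork call used later in this example performs this adaptation automatically.

net1000 = imagePretrainedNetwork("squeezenet");        % original 1000-class head
newHead = convolution2dLayer(1,3,Name="conv10_new");   % new head for 3 classes
netNew = replaceLayer(net1000,"conv10",newHead);       % swap in the new head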
ImageNet uses the WordNet hierarchy, and each category has a unique WordNet ID. This example uses three categories:

- Tiger (WordNet ID: n02129604). Subcategories include the Bengal tiger and the Siberian tiger.
- Rabbit (WordNet ID: n02325366). Subcategories include the European rabbit and the hare.
- Chicken (WordNet ID: n01514668). Subcategories include the hen, the rooster, and the chick.

The fine-tuning dataset contains:

- Tiger: 1,300 images (covering different tiger subspecies).
- Rabbit: 1,300 images (including European rabbits and hares).
- Chicken: 1,300 images (covering different breeds and ages).
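The datastore workflow below assumes the images are extracted into one subfolder per WordNet ID, so that the subfolder names become the image labels. The layout might look like this (the top-level folder name is hypothetical):

dataset/
    n01514668/   (1,300 chicken images)
    n02129604/   (1,300 tiger images)
    n02325366/   (1,300 rabbit images)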
Load Training Data
Create an image datastore. An image datastore enables you to store large collections of image data, including data that does not fit in memory, and efficiently read batches of images when training a neural network. Specify the folder with the extracted images, and indicate that the subfolder names correspond to the image labels.
datasetPath = "dataset";   % replace with the folder containing the extracted images
imds = imageDatastore(datasetPath, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
imds.Labels = renamecats(imds.Labels, ...
    {'n01514668','n02129604','n02325366'}, ...
    {'chicken','tiger','rabbit'});
numObsPerClass = countEachLabel(imds)
numObsPerClass =

     Label     Count
    _______    _____

    chicken    1300
    tiger      1300
    rabbit     1300
Load Pretrained Network
Load a pretrained SqueezeNet neural network into the workspace by using the imagePretrainedNetwork function. To return a neural network ready for retraining for the new data, specify the number of classes. When you specify the number of classes, the imagePretrainedNetwork function adapts the neural network so that it outputs prediction scores for each of the specified number of classes.
You can try other pretrained networks. Deep Learning Toolbox™ provides various pretrained networks that have different sizes, speeds, and accuracies. These additional networks usually require a support package. If the support package for a selected network is not installed, then the function provides a download link. For more information, see Pretrained Deep Neural Networks.
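For example, to use GoogLeNet instead of SqueezeNet, you could replace the load step below with this call (it assumes the GoogLeNet support package is installed; the class count of 3 matches this example's dataset):

net = imagePretrainedNetwork("googlenet",NumClasses=3);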
numClasses = numel(categories(imds.Labels));
net = imagePretrainedNetwork("squeezenet",NumClasses=numClasses);
inputSize = net.Layers(1).InputSize
The learnable layer in the network head (the last layer with learnable parameters) requires retraining. The layer is usually a fully connected layer, or a convolutional layer, with an output size that matches the number of classes.
The networkHead function, attached to this example as a supporting file, returns the layer and learnable parameter names of the learnable layer in the network head.
[layerName,learnableNames] = networkHead(net)
For transfer learning, you can freeze the weights of earlier layers in the network by setting the learning rates in those layers to 0. During training, the trainnet function does not update the parameters of these frozen layers. Because the function does not compute the gradients of the frozen layers, freezing the weights can significantly speed up network training. For small datasets, freezing the network layers prevents those layers from overfitting to the new dataset.
Freeze the weights of the network, keeping the last learnable layer unfrozen.
net = freezeNetwork(net,LayerNamesToIgnore=layerName);
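For reference, here is a minimal sketch of what a freezing step like this might do internally (an assumption, not the actual supporting file): loop over the learnable parameters and set the learning-rate factor of every parameter outside the head layer to 0.

netSketch = net;
for k = 1:height(netSketch.Learnables)
    thisLayer = netSketch.Learnables.Layer(k);
    thisParam = netSketch.Learnables.Parameter(k);
    if thisLayer ~= layerName   % keep the head layer trainable
        netSketch = setLearnRateFactor(netSketch,thisLayer,thisParam,0);
    end
end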
Prepare Data for Training
The images in the datastore can have different sizes. To automatically resize the training images, use an augmented image datastore.
augImds = augmentedImageDatastore(inputSize(1:2),imds,ColorPreprocessing='gray2rgb');
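Optionally, you can also apply random augmentation during training to reduce overfitting. A sketch using reflections and small translations (the ranges are illustrative):

augmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=[-10 10], ...
    RandYTranslation=[-10 10]);
augImds = augmentedImageDatastore(inputSize(1:2),imds, ...
    DataAugmentation=augmenter, ...
    ColorPreprocessing="gray2rgb");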
Specify Training Options
Specify the training options. Choosing among the options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.
For this example, use these options:
- Train using the Adam optimizer with an initial learning rate of 1e-4.
- Train for up to 50 epochs with a mini-batch size of 128.
- Display the training progress in a plot, and monitor the accuracy metric.
- Disable the verbose output.

Because this example does not split the data into training and validation sets (see below), do not specify validation options.
opts = trainingOptions("adam", ...
    InitialLearnRate=1e-4, ...
    MaxEpochs=50, ...
    MiniBatchSize=128, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);
Train Neural Network
Train the neural network using the trainnet function. For classification, use cross-entropy loss. By default, the trainnet function uses a GPU if one is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements. Otherwise, the trainnet function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.
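For example, to force CPU execution you could set the option when creating the training options (a sketch; in this example the name-value argument would be added to the trainingOptions call above):

optsCPU = trainingOptions("adam",ExecutionEnvironment="cpu");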
rng default
net = trainnet(augImds,net,"crossentropy",opts);
The data is not split into separate training and validation sets, because the purpose of this example is to observe the feature transformations that the CNN learns.
>> summary(net)
   Initialized: true

   Number of learnables: 724k

   Inputs:
      1   'data'   227×227×3 images
Observe the network's performance on the training set.
You can apply the trained neural network directly to classification problems. To classify new images, use the minibatchpredict function. To convert the predicted classification scores to labels, use the scores2label function. For an example showing how to use a pretrained neural network for classification, see Classify Image Using GoogLeNet.
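As a quick check of performance on the training set, here is a minimal sketch (variable names are illustrative; it assumes the datastore order is unchanged):

scores = minibatchpredict(net,augImds);     % softmax scores, one row per image
classNames = categories(imds.Labels);
YPred = scores2label(scores,classNames);    % convert scores to categorical labels
trainAccuracy = mean(YPred == imds.Labels)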
Ambiguity of Classifications
You can use the softmax activations to calculate the image classifications that are most likely to be incorrect. Define the ambiguity of a classification as the ratio of the second-largest probability to the largest probability. The ambiguity of a classification is between zero (nearly certain classification) and 1 (nearly as likely to be classified to the most likely class as the second class). An ambiguity of near 1 means the network is unsure of the class in which a particular image belongs. This uncertainty might be caused by two classes whose observations appear so similar to the network that it cannot learn the differences between them. Or, a high ambiguity can occur because a particular observation contains elements of more than one class, so the network cannot decide which classification is correct. Note that low ambiguity does not necessarily imply correct classification; even if the network has a high probability for a class, the classification can still be incorrect.
softmaxActivations = minibatchpredict(net,augImds);   % softmax scores for every image
[R,RI] = maxk(softmaxActivations,2,2);
ambiguity = R(:,2)./R(:,1);
Find the most ambiguous images.
[ambiguity,ambiguityIdx] = sort(ambiguity,"descend");
View the most probable classes of the ambiguous images and the true classes.
classList = unique(imds.Labels);
top10Idx = ambiguityIdx(1:10);
top10Ambiguity = ambiguity(1:10);
mostLikely = classList(RI(ambiguityIdx,1));
secondLikely = classList(RI(ambiguityIdx,2));
table(top10Idx,top10Ambiguity,mostLikely(1:10),secondLikely(1:10), ...
    imds.Labels(ambiguityIdx(1:10)), ...
    VariableNames=["Image #","Ambiguity","Likeliest","Second","True Class"])
10×5 table

    Image #    Ambiguity    Likeliest    Second     True Class
    _______    _________    _________    _______    __________

     2268      0.99602      chicken      tiger      tiger
     3330      0.99584      tiger        rabbit     rabbit
      104      0.99187      chicken      tiger      chicken
      304      0.98644      rabbit       chicken    chicken
     1163      0.98466      tiger        chicken    chicken
     3071      0.95684      chicken      rabbit     rabbit
     1925      0.95373      rabbit       tiger      tiger
     3006      0.95209      rabbit       chicken    rabbit
     2772      0.93734      chicken      rabbit     rabbit
     3461      0.9258       tiger        rabbit     rabbit
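To visually inspect the hardest cases, a quick sketch (it assumes the row order of softmaxActivations matches the datastore):

idx = top10Idx(1);          % index of the most ambiguous image
img = readimage(imds,idx);
figure
imshow(img)
title("True: " + string(imds.Labels(idx)) + ", ambiguity: " + top10Ambiguity(1))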
The misclassifications are concentrated in these few clusters. These samples are relatively complex: the foreground is not prominent, or the background is cluttered, so the extracted features are ambiguous.