# Ubuntu 系统:22.04
# Python 版本:3.9
# 安装依赖库:
#   pip install tensorflow==2.13 matplotlib numpy -i https://mirrors.aliyun.com/pypi/simple
# 代码实现:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
import numpy as np
import matplotlib.pyplot as plt# 加载MNIST数据集
# Load MNIST via the Keras dataset helper. load_data() returns the
# (images, labels) pair for the training split followed by the test split.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# --- Data preprocessing ---
# Append a trailing channel axis so each sample is 28x28x1 (the input shape
# the first Conv2D layer declares), cast to float32, then divide by 255.
# The cast happens before the division so the arithmetic stays in float32.
image_shape = (28, 28, 1)
pixel_scale = 255
train_images = train_images.reshape((-1, *image_shape)).astype('float32') / pixel_scale
test_images = test_images.reshape((-1, *image_shape)).astype('float32') / pixel_scale
# --- Build the CNN model ---
# Layer stack, declared in one list instead of incremental add() calls:
#   conv(32, 3x3, relu) -> maxpool(2x2) -> conv(64, 3x3, relu) -> maxpool(2x2)
#   -> flatten -> dense(128, relu) -> dense(10, softmax)
# The first layer fixes the input shape at 28x28 with a single channel.
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(64, kernel_size=(3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax'),
])
# --- Compile the model ---
# The labels are passed to fit() unchanged (no one-hot encoding step), which
# is why the sparse variant of categorical cross-entropy is used here.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# --- Train the model ---
# Note: the test split is also supplied as validation data, so the
# per-epoch "val_*" metrics are measured on the same data evaluated below.
history = model.fit(
    x=train_images,
    y=train_labels,
    batch_size=128,
    epochs=5,
    verbose=1,
    validation_data=(test_images, test_labels),
)

# --- Evaluate the model ---
# With one metric compiled in, evaluate() yields the loss and the accuracy.
test_loss, test_acc = model.evaluate(x=test_images, y=test_labels, verbose=0)
print(f"\n测试准确率: {test_acc:.4f}")

# --- Save the model ---
model.save('mnist_cnn_model.keras')
print("模型已保存为 mnist_cnn_model.keras")
# --- Visualize the training process ---
# Draw the per-epoch training curves recorded by model.fit() as a 1x2 grid
# (accuracy on the left, loss on the right) and write it to
# training_history.png.
#
# Fix: the original fused `plt.legend()` with the following statement on the
# same line (twice), which is a SyntaxError; each call is now on its own line.
#
# NOTE(review): the labels and titles are Chinese. Matplotlib's default font
# may lack these glyphs, in which case they render as empty boxes -- confirm
# that a CJK-capable font is configured on the target machine.
plt.figure(figsize=(10, 5))

# Left panel: training vs. validation accuracy.
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('模型准确率')
plt.ylabel('准确率')
plt.xlabel('训练轮次')
plt.legend()

# Right panel: training vs. validation loss.
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('模型损失')
plt.ylabel('损失')
plt.xlabel('训练轮次')
plt.legend()

# Adjust subplot spacing before saving so labels do not overlap.
plt.tight_layout()
plt.savefig('training_history.png')
print("训练过程图表已保存为 training_history.png")
# --- Test prediction ---
# Pick one random test image, run it through the model, and save a figure
# showing the image together with its true and predicted labels.
#
# Fix: the original fused the `model.predict(...)` assignment and
# `plt.figure(...)` onto one line, which is a SyntaxError; they are now
# separate statements. The true/predicted labels are also computed once and
# reused for both the figure title and the console message.
sample_idx = np.random.randint(0, len(test_images))

# predict() expects a batch, so reshape the single image to (1, 28, 28, 1).
sample_image = test_images[sample_idx].reshape(1, 28, 28, 1)
prediction = model.predict(sample_image, verbose=0)

true_label = test_labels[sample_idx]
# The predicted class is the index of the largest of the 10 softmax outputs.
predicted_label = np.argmax(prediction)

plt.figure(figsize=(5, 3))
# Drop the channel axis again: imshow wants a 2-D array for a grayscale image.
plt.imshow(test_images[sample_idx].reshape(28, 28), cmap='gray')
plt.title(f"真实标签: {true_label}\n预测结果: {predicted_label}")
plt.axis('off')
plt.savefig('sample_prediction.png')
print(f"样本预测图已保存为 sample_prediction.png\n真实标签: {true_label},预测结果: {predicted_label}")