机器学习 - Kaggle项目实践(3)Digit Recognizer 手写数字识别

Digit Recognizer | Kaggle 题面

Digit Recognizer-CNN | Kaggle 下面代码的kaggle版本

使用CNN进行手写数字识别

学习到了网络搭建手法+学习率退火+数据增广 提高训练效果。

使用混淆矩阵 以及对分类出错概率最大的例子单独拎出来分析。

最终以99.546%正确率 排在 86/1035

1. 数据准备

拆成X和Y 输出每个数字 train中的数据量

import pandas as pd
import seaborn as sns
train = pd.read_csv("/kaggle/input/digit-recognizer/train.csv")
test = pd.read_csv("/kaggle/input/digit-recognizer/test.csv")
Y_train = train["label"]
X_train = train.drop(labels = ["label"],axis = 1) 
del train g = sns.countplot(x=Y_train)
Y_train.value_counts()

1.确认数据干净 → 没有缺失值

print(X_train.isnull().any().describe()) # 下面代表所有列 都是False
print(test.isnull().any().describe())

2.灰度归一化(除以255) → 提升训练速度和稳定性。

reshape → 转换成 CNN 所需的 (height, width, channel) 格式。  

原来是(样本数, 784) -> (样本数, 高度, 宽度, 通道数)  样本数, 28, 28, 1

X_train, test = X_train / 255.0, test / 255.0
X_train = X_train.values.reshape(-1, 28, 28, 1)
test = test.values.reshape(-1, 28, 28, 1)

3.One-Hot 编码 只有一个位置为1的标签 → 后续 softmax 输出10个类别的概率。

from tensorflow.keras.utils import to_categorical
Y_train = to_categorical(Y_train, num_classes=10)

4.划分验证集 → 用于调参、防止过拟合。

from sklearn.model_selection import train_test_split
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.1, random_state=2)

5.可视化一个样本 → 检查数据是否正确加载。

import matplotlib.pyplot as plt
plt.imshow(X_train[0][:,:,0], cmap='gray')
plt.title("Example Image")
plt.show()

2. CNN模型建构

卷积块:2个Conv(5 * 5)+ BN + 池化      ~     2个Conv(3 * 3) + BN + 池化

Dropout舍弃一些神经元防止过拟合。

import numpy as np
import tensorflow as tffrom tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D as MaxPool2D
from tensorflow.keras.layers import Dense, Dropout, Flatten, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGeneratormodel = Sequential()# 卷积块 1
model.add(Conv2D(32, (5,5), padding='same', activation='relu', input_shape=(28,28,1)))
model.add(BatchNormalization())   # 批归一化
model.add(Conv2D(32, (5,5), padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPool2D(pool_size=(2,2)))
model.add(Dropout(0.25))# 卷积块 2
model.add(Conv2D(64, (3,3), padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(64, (3,3), padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPool2D(pool_size=(2,2), strides=(2,2)))
model.add(Dropout(0.25))

输出层:Flatten + Dense + Softmax

model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

model.compile     Adam优化器 + 交叉熵损失 + accuracy评估

optimizer = Adam(learning_rate=0.001)model.compile(optimizer=optimizer,loss='categorical_crossentropy',metrics=['accuracy'])

学习率退火 若 3 个 epoch 内没提升,就把 lr 减半

learning_rate_reduction = ReduceLROnPlateau(monitor='val_accuracy', patience=3, verbose=1,factor=0.5, min_lr=1e-5
)

数据增广(ImageDataGenerator) 轻微旋转、缩放、平移等扩充数据

datagen = ImageDataGenerator(rotation_range=10,zoom_range=0.1,width_shift_range=0.1,height_shift_range=0.1
)

训练 设置 epoch  batch_size  steps

datagen.fit(X_train)# ===== 训练 =====
epochs = 30
batch_size = 86
steps = int(np.ceil(len(X_train) / batch_size)) # 一次训练覆盖所有样本history = model.fit(datagen.flow(X_train, Y_train, batch_size=batch_size), # 数据增强epochs=epochs,steps_per_epoch=steps, validation_data=(X_val, Y_val), # 验证数据callbacks=[learning_rate_reduction], # 回调学习率verbose=2 # 每个 epoch 输出一次
)
440/440 - 144s - 327ms/step - accuracy: 0.9950 - loss: 0.0170 - val_accuracy: 0.9957 - val_loss: 0.0134 - learning_rate: 6.2500e-05
Epoch 30/30Epoch 30: ReduceLROnPlateau reducing learning rate to 3.125000148429535e-05.
440/440 - 145s - 329ms/step - accuracy: 0.9952 - loss: 0.0153 - val_accuracy: 0.9955 - val_loss: 0.0141 - learning_rate: 6.2500e-05

这是训练日志的最后一部分输出 准确率达到 99.5%

3. 模型评估 evaluation

1. history报告中 训练集和验证集的 Loss 和 Accuracy 可视化。

val为验证,验证不比训练差 说明没有太过拟合。

fig, ax = plt.subplots(2,1)# 绘制训练集和验证集的 Loss
ax[0].plot(history.history['loss'], color='b', label="Training loss")
ax[0].plot(history.history['val_loss'], color='r', label="validation loss")
ax[0].legend(loc='best', shadow=True)# 绘制训练集和验证集的 Accuracy
ax[1].plot(history.history['accuracy'], color='b', label="Training accuracy")
ax[1].plot(history.history['val_accuracy'], color='r', label="Validation accuracy")
ax[1].legend(loc='best', shadow=True)

2. 混淆矩阵 confusion_matrix

from sklearn.metrics import confusion_matrix
import itertoolsdef plot_confusion_matrix(cm, classes,normalize=False,title='Confusion matrix',cmap=plt.cm.Blues):plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, rotation=45)plt.yticks(tick_marks, classes)if normalize:cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]thresh = cm.max() / 2.for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):plt.text(j, i, cm[i, j],horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.tight_layout()plt.ylabel('True label')plt.xlabel('Predicted label')Y_pred = model.predict(X_val) # 预测值
Y_pred_classes = np.argmax(Y_pred, axis = 1) # 预测对应类别
Y_true = np.argmax(Y_val,axis = 1) # 真实类别confusion_mtx = confusion_matrix(Y_true, Y_pred_classes) 
plot_confusion_matrix(confusion_mtx, classes = range(10)) 

3. 对于分类错误的位置 用预测(错误数字的概率 - 正确标签的概率)

输出预测错误差别最大的图片 是什么样的。

# 找出预测错误的位置 布尔数组
errors = (Y_pred_classes != Y_true )# 提取错误预测的详细数据
Y_pred_classes_errors = Y_pred_classes[errors]
Y_pred_errors = Y_pred[errors]
Y_true_errors = Y_true[errors]
X_val_errors = X_val[errors]# 错误预测的最大概率
Y_pred_errors_prob = np.max(Y_pred_errors,axis = 1)# 正确标签对应的预测概率
true_prob_errors = np.diagonal(np.take(Y_pred_errors, Y_true_errors, axis=1))# 差值:预测的最大概率 - 正确标签概率
delta_pred_true_errors = Y_pred_errors_prob - true_prob_errors# 找出概率差最大的错误
most_important_errors = np.argsort(delta_pred_true_errors)[-6:]def display_errors(errors_index,img_errors,pred_errors, obs_errors):""" This function shows 6 images with their predicted and real labels"""n = 0nrows = 2ncols = 3fig, ax = plt.subplots(nrows,ncols,sharex=True,sharey=True)for row in range(nrows):for col in range(ncols):error = errors_index[n]ax[row,col].imshow((img_errors[error]).reshape((28,28)))ax[row,col].set_title("Predicted label :{}\nTrue label :{}".format(pred_errors[error],obs_errors[error]))n += 1display_errors(most_important_errors, X_val_errors, Y_pred_classes_errors, Y_true_errors)

发现预测错误的几张图片 本身就很容易误解

4. 预测&提交结果

results = model.predict(test)
results = np.argmax(results,axis = 1)  # 概率转类别
results = pd.Series(results,name="Label")submission = pd.concat([pd.Series(range(1,28001),name = "ImageId"),results],axis = 1)
submission.to_csv("cnn_mnist_datagen.csv",index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/93237.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/93237.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

新手如何高效运营亚马逊跨境电商:从传统SP广告到DeepBI智能策略

"为什么我的广告点击量很高但订单转化率却很低?""如何避免新品期广告预算被大词消耗殆尽?""为什么手动调整关键词和出价总是慢市场半拍?""竞品ASIN投放到底该怎么做才有效?""有没有…

【论文阅读 | CVPR 2024 | UniRGB-IR:通过适配器调优实现可见光-红外语义任务的统一框架】

论文阅读 | CVPR 2024 | UniRGB-IR:通过适配器调优实现可见光-红外语义任务的统一框架​1&&2. 摘要&&引言3.方法3.1 整体架构3.2 多模态特征池3.3 补充特征注入器3.4 适配器调优范式4 实验4.1 RGB-IR 目标检测4.2 RGB-IR 语义分割4.3 RGB-IR 显著目…

Hyperf 百度翻译接口实现方案

保留 HTML/XML 标签结构,仅翻译文本内容,避免破坏富文本格式。采用「HTML 解析 → 文本提取 → 批量翻译 → 回填」的流程。百度翻译集成方案:富文本内容翻译系统 HTML 解析 百度翻译 API 集成 文件结构 app/ ├── Controller/ │ └──…

字节跳动 VeOmni 框架开源:统一多模态训练效率飞跃!

资料来源:火山引擎-开发者社区 多模态时代的训练痛点,终于有了“特效药” 当大模型从单一语言向文本 图像 视频的多模态进化时,算法工程师们的训练流程却陷入了 “碎片化困境”: 当业务要同时迭代 DiT、LLM 与 VLM时&#xff0…

配置docker pull走http代理

之前写了一篇自建Docker镜像加速器服务的博客,需要用到境外服务器作为代理,但是一般可能没有境外服务器,只有http代理,所以如果本地使用想走代理可以用以下方式 临时生效(只对当前终端有效) 设置环境变量…

OpenAI 开源模型 gpt-oss 本地部署详细教程

OpenAI 最近发布了其首个开源的开放权重模型gpt-oss,这在AI圈引起了巨大的轰动。对于广大开发者和AI爱好者来说,这意味着我们终于可以在自己的机器上,完全本地化地运行和探索这款强大的模型了。 本教程将一步一步指导你如何在Windows和Linux…

力扣-5.最长回文子串

题目链接 5.最长回文子串 class Solution {public String longestPalindrome(String s) {boolean[][] dp new boolean[s.length()][s.length()];int maxLen 0;String str s.substring(0, 1);for (int i 0; i < s.length(); i) {dp[i][i] true;}for (int len 2; len …

Apache Ignite超时管理核心组件解析

这是一个非常关键且设计精巧的 定时任务与超时管理组件 —— GridTimeoutProcessor&#xff0c;它是 Apache Ignite 内核中负责 统一调度和处理所有异步超时事件的核心模块。&#x1f3af; 一、核心职责统一管理所有需要“在某个时间点触发”的任务或超时逻辑。它相当于 Ignite…

DAY 42 Grad-CAM与Hook函数

知识点回顾回调函数lambda函数hook函数的模块钩子和张量钩子Grad-CAM的示例# 定义一个存储梯度的列表 conv_gradients []# 定义反向钩子函数 def backward_hook(module, grad_input, grad_output):# 模块&#xff1a;当前应用钩子的模块# grad_input&#xff1a;模块输入的梯度…

基于 NVIDIA 生态的 Dynamo 风格分布式 LLM 推理架构

网罗开发&#xff08;小红书、快手、视频号同名&#xff09;大家好&#xff0c;我是 展菲&#xff0c;目前在上市企业从事人工智能项目研发管理工作&#xff0c;平时热衷于分享各种编程领域的软硬技能知识以及前沿技术&#xff0c;包括iOS、前端、Harmony OS、Java、Python等方…

《吃透 C++ 类和对象(中):拷贝构造函数与赋值运算符重载深度解析》

&#x1f525;个人主页&#xff1a;草莓熊Lotso &#x1f3ac;作者简介&#xff1a;C研发方向学习者 &#x1f4d6;个人专栏&#xff1a; 《C语言》 《数据结构与算法》《C语言刷题集》《Leetcode刷题指南》 ⭐️人生格言&#xff1a;生活是默默的坚持&#xff0c;毅力是永久的…

Python 环境隔离实战:venv、virtualenv 与 conda 的差异与最佳实践

那天把项目部署到测试环境&#xff0c;结果依赖冲突把服务拉崩了——本地能跑&#xff0c;线上不能跑。折腾半天才发现&#xff1a;我和同事用的不是同一套 site-packages&#xff0c;版本差异导致运行时异常。那一刻我彻底明白&#xff1a;虚拟环境不是可选项&#xff0c;它是…

[ 数据结构 ] 时间和空间复杂度

1.算法效率算法效率分析分为两种 : ①时间效率, ②空间效率 时间效率即为 时间复杂度 , 时间复杂度主要衡量一个算法的运行速度空间效率即为 空间复杂度 , 空间复杂度主要衡量一个算法所需要的额外空间2.时间复杂度2.1 时间复杂度的概念定义 : 再计算机科学中 , 算法的时间复杂…

一,设计模式-单例模式

目的设计单例模式的目的是为了解决两个问题&#xff1a;保证一个类只有一个实例这种需求是需要控制某些资源的共享权限&#xff0c;比如文件资源、数据库资源。为该实例提供一个全局访问节点相较于通过全局变量保存重要的共享对象&#xff0c;通过一个封装的类对象&#xff0c;…

AIStarter修复macOS 15兼容问题:跨平台AI项目管理新体验

AIStarter是全网唯一支持Windows、Mac和Linux的AI管理平台&#xff0c;为开发者提供便捷的AI项目管理体验。近期&#xff0c;熊哥在视频中分享了针对macOS 15系统无法打开AIStarter的修复方案&#xff0c;最新版已完美兼容。本文基于视频内容&#xff0c;详解修复细节与使用技巧…

LabVIEW 纺织检测数据传递

基于 LabVIEW 实现纺织检测系统中上位机&#xff08;PC 机&#xff09;与下位机&#xff08;单片机&#xff09;的串口数据传递&#xff0c;成功应用于煮茧机温度测量系统。通过采用特定硬件架构与软件设计&#xff0c;实现了温度数据的高效采集、传输与分析&#xff0c;操作简…

ECCV-2018《Variational Wasserstein Clustering》

核心思想 该论文提出了一个基于最优传输(optimal transportation) 理论的新型聚类方法&#xff0c;称为变分Wasserstein聚类(Variational Wasserstein Clustering, VWC)。其核心思想有三点&#xff1a;建立最优传输与k-means聚类的联系&#xff1a;作者指出k-means聚类问题本质…

部署 Docker 应用详解(MySQL + Tomcat + Nginx + Redis)

文章目录一、MySQL二、Tomcat三、Nginx四、Redis一、MySQL 搜索 MySQL 镜像下载 MySQL 镜像创建 MySQL 容器 docker run -i -t/d -p 3307:3306 --namec_mysql -v $PWD/conf:/etc/mysql/conf.d -v $PWD/logs:/logs -v $PWD/data:/var/lib/mysql -e MYSQL_ROOT_PASSWORD123456 m…

VR全景导览在大型活动中的应用实践:优化观众体验与现场管理

大型演出赛事往往吸引海量观众&#xff0c;但复杂的场馆环境常带来诸多困扰&#xff1a;如何快速找到座位看台区域&#xff1f;停车位如何规划&#xff1f;附近公交地铁站在哪&#xff1f;这些痛点直接影响观众体验与现场秩序。VR全景技术为解决这些问题提供了有效方案。通过在…

OpenJDK 17 JIT编译器堆栈分析

##堆栈(gdb) bt #0 PhaseOutput::safepoint_poll_table (this0x7fffd0bfb950) at /home/yym/openjdk17/jdk17-master/src/hotspot/share/opto/output.hpp:173 #1 0x00007ffff689634e in PhaseOutput::fill_buffer (this0x7fffd0bfb950, cb0x7fffd0bfb970, blk_starts0x7fffb0…