机器学习 - Kaggle项目实践(6)Dogs vs. Cats Redux: Kernels Edition 猫狗二分类

Dogs vs. Cats Redux: Kernels Edition | Kaggle

任务:给定猫狗图像数据集 进行二分类。

Cats or Dogs - using CNN with Transfer Learning | Kaggle(参考)

Cats or Dogs | Kaggle (我的kaggle)

本文介绍了使用ResNet50预训练模型进行猫狗图像分类的完整流程。

数据预处理、模型构建、训练评估和预测输出,展示了迁移学习在图像分类任务中的高效应用。

首先从Kaggle数据集解压图片并处理数据,将文件名转换为独热标签(猫[0,1],狗[1,0])。

然后构建ResNet50模型,移除原始分类层并改为二分类softmax输出,使用ImageNet预训练权重初始化。

模型在训练集上训练20个epoch后,在验证集上准确率达到98%以上。

1. zip 图片提取与 文件名标签提取

从zip文件提取出 train 地址列表和 test 地址列表

import zipfile
import oswith zipfile.ZipFile('/kaggle/input/dogs-vs-cats-redux-kernels-edition/train.zip', 'r') as z:z.extractall('.') # 将ZIP文件中的所有内容解压到当前目录train_image_list = z.namelist() # 获取名称列表train_image_list = os.listdir("./train/") # 进一步解压with zipfile.ZipFile('/kaggle/input/dogs-vs-cats-redux-kernels-edition/test.zip', 'r') as z:z.extractall('.')test_image_list = z.namelist()test_image_list = os.listdir("./test/")print(train_image_list[0],test_image_list[0]) # 文件名 train格式 类别+数字  test只有数字

把train文件夹地址和图像文件名列表,拼凑出完整的地址;

cv2读取出图片;文件名提取出标签 二分类概率 猫为[0,1] 狗为[1,0]

from random import shuffle
from tqdm import tqdm
import cv2
import numpy as np
import pandas as pdRANDOM_STATE = 2018
IMG_SIZE = 224
def process_data(data_image_list, DATA_FOLDER, isTrain):data_df = []for img in tqdm(data_image_list):if(isTrain):label = [1,0] if img.split('.')[0] == 'cat' else [0,1] # 根据文件名 转换独热标签else:label = img.split('.')[0]path = os.path.join(DATA_FOLDER,img) # 拼接为完整路径img = cv2.imread(path,cv2.IMREAD_COLOR) # 读取img = cv2.resize(img, (IMG_SIZE,IMG_SIZE)) # 设定大小data_df.append([np.array(img),np.array(label)]) # 拼在一起返回shuffle(data_df) # 打乱return data_dftrain = process_data(train_image_list, './train/', True)
test = process_data(test_image_list, './test/', False)

2. EDA 图片探索 训练集图片展示

展示 5*5 张训练集图片和测试集图片

def show_images(data, isTest=False):f, ax = plt.subplots(5,5, figsize=(15,15))for i,data in enumerate(data[:25]):img_data,img_num = data[0],data[1]label = np.argmax(img_num) # 独热向量 [0,1] 为狗 转换为文字标签if label == 1: str_label='Dog'elif label == 0: str_label='Cat'if(isTest):str_label="None"ax[i//5, i%5].imshow(img_data)ax[i//5, i%5].axis('off')ax[i//5, i%5].set_title("Label: {}".format(str_label))plt.show()show_images(train)
show_images(test,True)

3. 建立模型 ResNet

 残差神经网络ResNet预训练参数 迁移学习

移除原始ResNet50最后的1000类分类层,改为softmax 激活函数二分类

使用在ImageNet上预训练的权重(好的初始化快速收敛)允许训练微调

from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Densemodel = Sequential()
model.add(ResNet50(include_top=False, # 移除原始ResNet50最后的1000类分类层pooling='max', # 在卷积特征上添加全局最大池化,将特征图转换为向量weights='imagenet' # 使用在ImageNet上预训练的权重
))
model.add(Dense(2, activation='softmax')) # softmax 激活函数二分类model.layers[0].trainable = True # 允许训练微调
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

4. 准备数据并训练

X = np.array([data[0] for data in train]).reshape(-1,IMG_SIZE,IMG_SIZE,3)
y = np.array([data[1] for data in train])
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=RANDOM_STATE)
train_model = model.fit(X_train, y_train, batch_size=64, epochs=20, verbose=1, validation_data=(X_val, y_val))

verbose=1 训练进度展示

5. 预测+评估

模型评估:model.evaluate 评估分数; 验证集真实和预测 分类报告

score = model.evaluate(X_val, y_val, verbose=0) # 评估分数
print('Validation loss:', score[0])
print('Validation accuracy:', score[1])predicted_classes = model.predict_classes(X_val) # 预测
y_true = np.argmax(y_val,axis=1) # 实际from sklearn.metrics import classification_report # 分类报告
print(classification_report(y_true, predicted_classes, target_names=["Cat", "Dog"]))

这三个指标均达到 98%以上

还可以 可视化部分验证集结果(人眼看是否差不多分类正确)

f, ax = plt.subplots(5, 5, figsize=(15, 15))for i, (img_data, _) in enumerate(test[:25]):prediction = model.predict(img_data.reshape(-1, IMG_SIZE, IMG_SIZE, 3))[0]label = 'Dog' if np.argmax(prediction) == 1 else 'Cat'ax[i//5, i%5].imshow(img_data)ax[i//5, i%5].axis('off')ax[i//5, i%5].set_title(f"Predicted: {label}")plt.show()

预测并保存结果

pred_list = []
img_list = []
for img in tqdm(test):data = img[0].reshape(-1,IMG_SIZE,IMG_SIZE,3)pred_list.append(model.predict([data])[0][1])img_list.append(img_idx[1])submission = pd.DataFrame({'id':img_list , 'label':pred_list})
submission.to_csv("submission.csv", index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/95170.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/95170.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基础的汇编指令

目录 1、接上一个csdn特殊功能寄存器 1.1CPSR寄存器 1.2SPSR寄存器 1.3CPSR寄存器的高四位和第四位 ​编辑 2、汇编指令的分类 3、汇编指令的基本格式 4、数据搬移指令(赋值指令) 4.1指令码 4.2指令格式 4.3测试代码 4.5立即数 4.6ldr伪指令 …

Docker实战避坑指南:从入门到精通

摘要:文人结合自身微服务实践,系统梳理从安装适配、镜像拉取,到运行配置、构建优化、多容器编排、数据持久化、监控运维等 Docker 全流程高频踩坑点,给出可落地的解决方案,帮助读者快速规避同类问题并提升容器化效率。…

《Bishop PRML》10.1. Variational Inference(2)理解VAE

通过VAE与AE理解变分分布的变量 如何理解变分推断公式中,Z和X的含义是什么? 知乎 变分自编码器VAE的数学原理。 csdn 变分自编码器(VAE)的数学原理以及实现 Loss functions in Variational Autoencoders (VAEs) 一文解释 VAE+ELBO AE的编码和解码是确定性的。VAE的解码过程…

函数调用中的初始化与赋值——深入理解C++对象的生命周期

技术博客:函数调用中的初始化与赋值——深入理解C对象的生命周期引言在C编程中,理解函数调用过程中参数传递、对象创建和返回值处理的细节对于编写高效且无误的代码至关重要。本文将通过一个具体的例子来探讨函数调用时实参到形参的转换过程,…

矩阵微积分的链式法则(chain rule)

矩阵微积分的链式法则(chain rule)与标量情况一样,用于求复合函数的导数,但由于涉及矩阵和向量的求导,维度匹配和布局约定(numerator-layout vs. denominator-layout)必须格外小心。下面给出常见…

网络编程4-并发服务器、阻塞与非阻塞IO、信号驱动模型、IO多路复用..

一、并发服务器1、单循环服务器(顺序处理) 一次只能处理一个客户端连接,只有当前客户端断开连接后,才能接受新的客户端连接2、多进程/多线程并发服务器while(1) {connfd accept(listenfd);pid fork(); // 或 pthread_cr…

在 WSL2-NVIDIA-Workbench 中安装Anaconda、CUDA 13.0、cuDNN 9.12 及 PyTorch(含完整环境验证)

在 WSL-NVIDIA-Workbench(NVIDIA AI Workbench & Ubuntu 22.04)中 安装 Anaconda、CUDA 13.0、cuDNN 9.12 及 PyTorch 步骤也可参阅: 在WSL2-Ubuntu中安装Anaconda、CUDA13.0、cuDNN9.12及PyTorch(含完整环境验证&#xf…

Shell编程核心入门:参数传递、运算符与流程控制全解析

Shell编程核心入门:参数传递、运算符与流程控制全解析 在Linux/Unix系统中,Shell作为命令解释器和脚本语言,是自动化运维、批量处理任务的核心工具。掌握Shell脚本的参数传递、运算符使用和流程控制,能让你从“手动执行命令”升级…

如何用 Kotlin 在 Android 手机开发一个应用程序获取网络时间

使用 NTP 协议获取网络时间在 build.gradle 文件中添加以下依赖:implementation commons-net:commons-net:3.6创建 NTP 时间获取工具类:import org.apache.commons.net.ntp.NTPUDPClient import org.apache.commons.net.ntp.TimeInfo import java.net.In…

python智慧交通数据分析可视化系统 车流实时检测分析 深度学习 车流量实时检测跟踪 轨迹跟踪 毕业设计✅

博主介绍:✌全网粉丝50W,前互联网大厂软件研发、集结硕博英豪成立软件开发工作室,专注于计算机相关专业项目实战6年之久,累计开发项目作品上万套。凭借丰富的经验与专业实力,已帮助成千上万的学生顺利毕业,…

计算机视觉第一课opencv(四)保姆级教学

目录 简介 一、轮廓检测 1.查找轮廓的API 2.代码分析 2.1.图像二值化处理 2.2轮廓检测 2.3轮廓绘制 2.4轮廓面积计算 2.5轮廓周长计算 2.6筛选特定面积的轮廓 2.7查找最大面积的轮廓 2.8绘制轮廓的外接圆 2.9绘制轮廓的外接矩形 二、轮廓的近似 三、模板匹配 简…

基于Vue2+elementUi实现树形 横向 合并 table不规则表格

1、实现效果 共N行&#xff0c;但是每一列对应的单元格列数固定&#xff0c;行数不固定2、实现方式说明&#xff1a;使用的是vue2 elementUI表格组件 js实现<template><div class"table-container" ><el-table height"100%" :span-metho…

深度学习在计算机视觉中的应用:对象检测

引言 对象检测是计算机视觉领域中的一项基础任务&#xff0c;目标是在图像或视频帧中识别和定位感兴趣的对象。随着深度学习技术的发展&#xff0c;对象检测的准确性和效率都有了显著提升。本文将详细介绍如何使用深度学习进行对象检测&#xff0c;并提供一个实践案例。 环境准…

node.js 安装步骤

在Node.js中安装包通常通过npm(Node Package Manager)来完成,这是Node.js的包管理工具。以下是安装Node.js和通过npm安装包的基本步骤: 1. 安装Node.js 方法一:使用nvm(Node Version Manager) 推荐使用nvm来安装Node.js,因为它允许你安装多个Node.js版本,并轻松地在…

面试-故障案例解析

一、NFS故障&#xff0c;造成系统cpu使用率低而负载极高。故障概述: 公司使用NFS为web节点提供共享存储服务,某一天下午发现web节点CPU使用率低,而负载极高.登录web节点服务器排查发现后段NFS服务器故障. 影响范围: 网站看不到图片了。 处理流程: 通过ssh登录NFS服务…

医疗AI时代的生物医学Go编程:高性能计算与精准医疗的案例分析(一)

摘要: 随着高通量测序、医学影像和电子病历等生物医学数据的爆炸式增长,对高效、可靠、可扩展的计算工具需求日益迫切。Go语言凭借其原生并发模型、卓越的性能、简洁的语法和强大的标准库,在生物医学信息学领域展现出独特优势。本文以“生物医学Go编程探析”为主题,通过三个…

针对 “TCP 连接建立阶段” 的攻击

针对 “TCP 连接建立阶段” 的攻击一、定义二、共性防御思路三、攻击手段3.1、SYN 洪水攻击&#xff08;SYN Flood&#xff09;3.2、Land 攻击&#xff08;Land Attack&#xff09;一、定义 什么是针对 “TCP 连接建立阶段” 的攻击&#xff1f;核心特征是利用 TCP “三次握手…

聊一聊 单体分布式 和 微服务分布式

微服务 与 单体架构对比维度单体架构微服务架构​​架构本质​​一个单一的、功能齐全的应用程序一组​​小型、独立​​的服务集合​​开发​​团队工作在同一个代码库&#xff0c;易产生冲突。技术栈统一。每个服务可以由​​ 独立的小团队 ​​负责&#xff0c;允许使用​​…

【C++八股文】计算机网络篇

网络协议核心知识点详解 TCP头部结构 TCP头部包含多个关键字段&#xff0c;每个字段都有其特定作用&#xff1a; 16位源端口&#xff1a;标识发送方应用程序的端口号16位目的端口&#xff1a;标识接收方应用程序的端口号32位序号&#xff1a;保证数据包有序传输的唯一标识32…

小迪Web自用笔记7

游戏一般不走http https协议&#xff0c;一般的抓包工具抓不到。科来&#xff0c;这个工具是从网卡抓包。你一旦打怪数据就会多起来↓但不是很专业。可以抓到https。wep↑这个西东是全部协议都做流量包&#xff0c;你不知道他是从哪儿来的&#xff0c;他全都抓&#xff08;专业…