将 tensorflow keras 训练数据集转换为 Yolo 训练数据集

以 https://www.kaggle.com/datasets/vipoooool/new-plant-diseases-dataset 为例

1.  图像分类数据集文件结构 (例如用于 yolov11n-cls.pt 训练)

import os
import csv
import random
from PIL import Image
from sklearn.model_selection import train_test_split
import shutil# ====================== 配置参数 ======================
# 从 Kaggle Hub 下载植物病害数据集
# https://www.kaggle.com/datasets/vipoooool/new-plant-diseases-dataset
import kagglehub
tf_download_path = kagglehub.dataset_download("vipoooool/new-plant-diseases-dataset")
print("Path to dataset files:", tf_download_path)
# 定义数据集路径
tf_dataset_path = f"{tf_download_path}/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)"INPUT_DATA_DIR = tf_dataset_path  # 输入数据集路径(解压后的根目录)
OUTPUT_YOLO_DIR = "./runs/traindata/yolo/yolo_plant_diseases_classify"        # 输出YOLO数据集路径
if os.path.exists(OUTPUT_YOLO_DIR):shutil.rmtree(OUTPUT_YOLO_DIR)
os.makedirs(OUTPUT_YOLO_DIR, exist_ok=True)TRAIN_SIZE = 0.8                                 # 训练集比例
IMAGE_EXTENSIONS = [".JPG", ".jpg", ".jpeg", ".png"]     # 支持的图像扩展名# ====================== 类别映射(需根据实际数据集调整) ======================
# 从原数据集的类别名称生成映射(示例:假设病害类别为文件夹名)
def get_class_mapping(data_dir):class_names = []for folder in os.listdir(data_dir):folder_path = os.path.join(data_dir, folder)if os.path.isdir(folder_path) and not folder.startswith("."):class_names.append(folder)class_names.sort()  # 按字母序排序,确保类别编号固定return {cls: idx for idx, cls in enumerate(class_names)}# ====================== 划分数据集并保存 ======================
def save_dataset(annotations, class_map, output_dir, train_size=0.8):# 划分训练集和验证集random.shuffle(annotations)split_idx = int(len(annotations) * train_size)train_data = annotations[:split_idx]val_data = annotations[split_idx:]# 创建目录结构os.makedirs(os.path.join(output_dir, "train"), exist_ok=True)os.makedirs(os.path.join(output_dir, "val"), exist_ok=True)for cls in class_map.keys():os.makedirs(os.path.join(output_dir, "train", cls), exist_ok=True)os.makedirs(os.path.join(output_dir, "val", cls), exist_ok=True)# 保存训练集for data in train_data:img_path = data["image_path"]cls = data["class_name"]try:shutil.copy2(img_path, os.path.join(output_dir, "train", cls))print(f"图像 {img_path} 复制到训练集 {cls} 类成功")except Exception as e:print(f"图像 {img_path} 复制到训练集 {cls} 类失败,错误信息: {e}")# 保存验证集for data in val_data:img_path = data["image_path"]cls = data["class_name"]try:shutil.copy2(img_path, os.path.join(output_dir, "val", cls))print(f"图像 {img_path} 复制到验证集 {cls} 类成功")except Exception as e:print(f"图像 {img_path} 复制到验证集 {cls} 类失败,错误信息: {e}")# 生成类别名文件(classes.names)with open(os.path.join(output_dir, "classes.names"), "w") as f:for cls in class_map.keys():f.write(f"{cls}\n")# 生成数据集配置文件(dataset.yaml)yaml_path = os.path.join(output_dir, "dataset.yaml")with open(yaml_path, "w") as f:f.write(f"path: {output_dir}\n")  # 数据集根路径f.write(f"train: train\n")  # 训练集路径(相对于path)f.write(f"val: val\n")      # 验证集路径# f.write(f"test: images/test\n")   # 测试集路径(如果有)f.write(f"nc: {len(class_map)}\n")  # 类别数# 修改 names 字段输出格式class_names = list(class_map.keys())f.write(f"names: {class_names}\n")return train_data, val_data# ====================== 主函数 ======================
if __name__ == "__main__":# 1. 检查输入路径是否存在if not os.path.exists(INPUT_DATA_DIR):raise FileNotFoundError(f"请先下载数据集并解压到路径:{INPUT_DATA_DIR}")# 2. 获取类别映射(假设图像按类别存放在子文件夹中)class_map = get_class_mapping(os.path.join(INPUT_DATA_DIR, "train"))  # 假设训练集图像在train子文件夹中,每个子文件夹为一个类别# 3. 解析标注(仅按文件夹分类)annotations = []for cls, idx in class_map.items():cls_dir = os.path.join(INPUT_DATA_DIR, "train", cls)  # 假设类别文件夹路径为train/类别名for img_file in os.listdir(cls_dir):if any(img_file.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):img_path = os.path.join(cls_dir, img_file)annotations.append({"image_path": img_path,"class_name": cls})# 4. 保存为YOLO格式train_data, val_data = save_dataset(annotations, class_map, OUTPUT_YOLO_DIR, train_size=TRAIN_SIZE)print(f"✅ 转换完成!YOLO数据集已保存至:{OUTPUT_YOLO_DIR}")print(f"类别数:{len(class_map)},训练集样本数:{len(train_data)},验证集样本数:{len(val_data)}")

train的时候,使用的文件夹

2. 目标检测数据集文件结构 (例如用于 yolo11n.pt 训练)

import os
import csv
import random
from PIL import Image
from sklearn.model_selection import train_test_split
import shutil# ====================== 配置参数 ======================
# 从 Kaggle Hub 下载植物病害数据集
# https://www.kaggle.com/datasets/vipoooool/new-plant-diseases-dataset
import kagglehub
tf_download_path = kagglehub.dataset_download("vipoooool/new-plant-diseases-dataset")
print("Path to dataset files:", tf_download_path)
# 定义数据集路径
tf_dataset_path = f"{tf_download_path}/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)"INPUT_DATA_DIR = tf_dataset_path  # 输入数据集路径(解压后的根目录)
OUTPUT_YOLO_DIR = "./traindata/yolo/yolo_plant_diseases"        # 输出YOLO数据集路径
if os.path.exists(OUTPUT_YOLO_DIR):shutil.rmtree(OUTPUT_YOLO_DIR)
os.makedirs(OUTPUT_YOLO_DIR, exist_ok=True)TRAIN_SIZE = 0.8                                 # 训练集比例
IMAGE_EXTENSIONS = [".JPG", ".jpg", ".jpeg", ".png"]     # 支持的图像扩展名# ====================== 类别映射(需根据实际数据集调整) ======================
# 从原数据集的类别名称生成映射(示例:假设病害类别为文件夹名)
def get_class_mapping(data_dir):class_names = []for folder in os.listdir(data_dir):folder_path = os.path.join(data_dir, folder)if os.path.isdir(folder_path) and not folder.startswith("."):class_names.append(folder)class_names.sort()  # 按字母序排序,确保类别编号固定return {cls: idx for idx, cls in enumerate(class_names)}# ====================== 解析CSV标注(假设标注在CSV中) ======================
def parse_csv_annotations(csv_path, class_map, image_dir):annotations = []with open(csv_path, "r", encoding="utf-8") as f:reader = csv.DictReader(f)for row in reader:image_name = row["image_path"]class_name = row["disease_class"]  # 需与CSV中的类别列名一致x_min = float(row["x_min"])y_min = float(row["y_min"])x_max = float(row["x_max"])y_max = float(row["y_max"])# 检查图像是否存在image_path = os.path.join(image_dir, image_name)if not os.path.exists(image_path):continue# 获取图像尺寸with Image.open(image_path) as img:img_width, img_height = img.size# 转换为YOLO坐标center_x = (x_min + x_max) / 2 / img_widthcenter_y = (y_min + y_max) / 2 / img_heightwidth = (x_max - x_min) / img_widthheight = (y_max - y_min) / img_heightannotations.append({"image_path": image_path,"class_id": class_map[class_name],"bbox": (center_x, center_y, width, height)})return annotations# ====================== 划分数据集并保存 ======================
def save_dataset(annotations, class_map, output_dir, train_size=0.8):# 划分训练集和验证集random.shuffle(annotations)split_idx = int(len(annotations) * train_size)train_data = annotations[:split_idx]val_data = annotations[split_idx:]# 创建目录结构os.makedirs(os.path.join(output_dir, "images/train"), exist_ok=True)os.makedirs(os.path.join(output_dir, "images/val"), exist_ok=True)os.makedirs(os.path.join(output_dir, "labels/train"), exist_ok=True)os.makedirs(os.path.join(output_dir, "labels/val"), exist_ok=True)# 保存训练集for data in train_data:img_path = data["image_path"]lbl_path = os.path.join(output_dir, "labels/train",os.path.splitext(os.path.basename(img_path))[0] + ".txt")# 复制图像try:shutil.copy2(img_path, os.path.join(output_dir, 'images/train'))print(f"图像 {img_path} 复制到训练集成功")except Exception as e:print(f"图像 {img_path} 复制到训练集失败,错误信息: {e}")# 保存标注with open(lbl_path, "w") as f:f.write(f"{data['class_id']} {' '.join(map(str, data['bbox']))}\n")# 保存验证集for data in val_data:img_path = data["image_path"]lbl_path = os.path.join(output_dir, "labels/val",os.path.splitext(os.path.basename(img_path))[0] + ".txt")# 复制图像try:shutil.copy2(img_path, os.path.join(output_dir, 'images/val'))print(f"图像 {img_path} 复制到验证集成功")except Exception as e:print(f"图像 {img_path} 复制到验证集失败,错误信息: {e}")# 保存标注with open(lbl_path, "w") as f:f.write(f"{data['class_id']} {' '.join(map(str, data['bbox']))}\n")# 生成类别名文件(classes.names)with open(os.path.join(output_dir, "classes.names"), "w") as f:for cls in class_map.keys():f.write(f"{cls}\n")# 生成数据集配置文件(dataset.yaml)yaml_path = os.path.join(output_dir, "dataset.yaml")with open(yaml_path, "w") as f:f.write(f"path: {output_dir}\n")  # 数据集根路径f.write(f"train: images/train\n")  # 训练集路径(相对于path)f.write(f"val: images/val\n")      # 验证集路径# f.write(f"test: images/test\n")   # 测试集路径(如果有)f.write(f"nc: {len(class_map)}\n")  # 类别数f.write("names:\n")for idx, cls in enumerate(class_map.keys()):f.write(f"  {idx}: {cls}\n")return train_data, val_data# ====================== 主函数 ======================
if __name__ == "__main__":# 1. 检查输入路径是否存在if not os.path.exists(INPUT_DATA_DIR):raise FileNotFoundError(f"请先下载数据集并解压到路径:{INPUT_DATA_DIR}")# 2. 获取类别映射(假设图像按类别存放在子文件夹中,无CSV标注时使用此方法)# 若有CSV标注,需手动指定CSV路径和列名,注释掉下方代码并取消注释parse_csv_annotations部分class_map = get_class_mapping(os.path.join(INPUT_DATA_DIR, "train"))  # 假设训练集图像在train子文件夹中,每个子文件夹为一个类别# 3. 解析标注(根据实际情况选择CSV或文件夹分类)# 情况A:无标注,仅按文件夹分类(弱监督,边界框为图像全尺寸)annotations = []for cls, idx in class_map.items():cls_dir = os.path.join(INPUT_DATA_DIR, "train", cls)  # 假设类别文件夹路径为train/类别名for img_file in os.listdir(cls_dir):if any(img_file.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):img_path = os.path.join(cls_dir, img_file)with Image.open(img_path) as img:img_width, img_height = img.size# 边界框为全图(弱监督场景,仅用于分类任务,非检测)annotations.append({"image_path": img_path,"class_id": idx,"bbox": (0.5, 0.5, 1.0, 1.0)  # 全图边界框})# # 情况B:有CSV标注(需取消注释以下代码并调整参数)# CSV_PATH = os.path.join(INPUT_DATA_DIR, "labels.csv")  # CSV标注文件路径# IMAGE_DIR = os.path.join(INPUT_DATA_DIR, "images")     # 图像根目录# class_map = {"Apple Scab": 0, "Black Rot": 1, ...}    # 手动定义类别映射# annotations = parse_csv_annotations(CSV_PATH, class_map, IMAGE_DIR)# 4. 保存为YOLO格式train_data, val_data = save_dataset(annotations, class_map, OUTPUT_YOLO_DIR, train_size=TRAIN_SIZE)print(f"✅ 转换完成!YOLO数据集已保存至:{OUTPUT_YOLO_DIR}")print(f"类别数:{len(class_map)},训练集样本数:{len(train_data)},验证集样本数:{len(val_data)}")

train的时候,使用的yaml文件路径

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/diannao/86172.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

排序算法-归并排序与快速排序

归并排序与快速排序 快速排序是利用的递归思想:选取一个基准数,把小于基准数的放左边 大于的放右边直到整个序列有序 。快排分割函数 O(lognn), 空间 :没有额外开辟新的数组但是递归树调用函数会占用栈内存 O(logn) 。 归并排序:在递归返回的…

北大开源音频编辑模型PlayDiffusion,可实现音频局部编辑,比传统 AR 模型的效率高出 50 倍!

北大开源了一个音频编辑模型PlayDiffusion,可以实现类似图片修复(inpaint)的局部编辑功能 - 只需修改音频中的特定片段,而无需重新生成整段音频。此外,它还是一个高性能的 TTS 系统,比传统 AR 模型的效率高出 50 倍。 自回归 Tra…

MyBatis————入门

1,配置相关 我们上一期详细讲了一下使用注解来实现操作数据库的方式,我们今天使用xml来实现,有同学可能有疑问,使用注解挺方便呀,为啥还要注解呀,先来说一下注解我感觉挺麻烦的,但是我们后面要…

【推荐算法】推荐算法演进史:从协同过滤到深度强化学习

推荐算法演进史:从协同过滤到深度强化学习 一、传统推荐时代:协同过滤的奠基(1990s-2006)1.1 算法背景:信息爆炸的挑战1.2 核心算法:协同过滤1.3 局限性 二、深度学习黎明:神经网络初探&#xf…

Java基于SpringBoot的校园闲置物品交易系统,附源码+文档说明

博主介绍:✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&…

Ajax Systems公司的核心产品有哪些?

Ajax Systems 是一家专注于家庭安全和智能系统的公司,其核心产品如下3: 入侵保护设备:如 MotionCam Outdoor 无线室外运动探测器,配备内置摄像头和两个红外传感器,可通过预装电池运行长达三年,能在 15 米距…

64、js 中require和import有何区别?

在 JavaScript 中,require 和 import 都是用于模块导入的语法,但它们属于不同的模块系统,具有显著的区别: 1. 模块系统不同 require 属于 CommonJS 模块系统(Node.js 默认使用)。 语法:const…

Java+Access综合测评系统源码分享:含论文、开题报告、任务书全套资料

JAVAaccess综合测评系统毕业设计 一、系统概述 本系统采用Java Swing开发前端界面,结合Access数据库实现数据存储,专为教育机构打造的综合测评解决方案。系统包含学生管理、题库管理、在线测评、成绩分析四大核心模块,实现了测评流程的全自…

【python】RGB to YUV and YUV to RGB

文章目录 1、YUV2、YUV vs RGB3、RGB to YUV4、YUV to RGB附录——YUV NV12 vs YUV NV21参考1、YUV YUV 颜色空间,又常被称作 YCbCr 颜色空间,是用于数字电视的颜色空间,在 ITU-R BT.601、BT.709、BT.2020 标准中被明确定义,这三种标准分别针对标清、高清、超高清数字电视…

运行示例程序和一些基本操作

欢迎 ----> 示例 --> 选择sample CTRL B 编译代码 CTRL R 运行exe 项目 中 Shadow build 表示是否 编译生成文件和 源码是否放一块 勾上不在同一个地方 已有项目情况下怎么打开项目 方法一: 左键双击 xxx.pro 方法二: 文件菜单里面 选择打开项目

计算机网络第2章(下):物理层传输介质与核心设备全面解析

目录 一、传输介质1.1 传输介质的分类1.2 导向型传输介质1.2.1 双绞线(Twisted Pair)1.2.2 同轴电缆(Coaxial Cable)1.2.3 光纤(Optical Fiber)1.2.4 以太网对有线传输介质的命名规则 1.3 非导向型传输介质…

PHP文件包含漏洞详解:原理、利用与防御

PHP文件包含漏洞详解:原理、利用与防御 什么是文件包含漏洞? 文件包含漏洞是PHP应用程序中常见的安全问题,当开发者使用包含函数引入文件时,如果传入的文件名参数未经严格校验,攻击者就可能利用这个漏洞读取敏感文件…

5.4.2 Spring Boot整合Redis

本次实战主要围绕Spring Boot与Redis的整合展开,首先创建了一个Spring Boot项目,并配置了Redis的相关属性。接着,定义了三个实体类:Address、Family和Person,分别表示地址、家庭成员和个人信息,并使用Index…

java内存模型JMM

Java 内存模型(Java Memory Model,JMM)定义了 Java 程序中的变量、线程如何和本地内存以及主内存进行交互的规则。它主要涉及到多线程环境下的共享变量可见性、指令重排等问题,是理解并发编程中的关键概念。 核心概念&#xff1a…

配置git命令缩写

以下是 Git 命令缩写的配置方法及常用方案,适用于 Linux/macOS/Windows 系统: 🔧 一、配置方法 1. 命令行设置(推荐) # 基础命令缩写 git config --global alias.st status git config --global alias.co che…

准确--k8s cgroup问题排查

k8s cgroup问题排查 6月 06 17:20:39 k8s-node01 containerd[1515]: time"2025-06-06T17:20:39.42902033408:00" levelerror msg"StartContainer fo r \"46ae0ef9618b96447a1f28fd2229647fe671e8acbcec02c8c46b37051130c8c4\" failed" error&qu…

Go 中 map 的双值检测写法详解

Go 中 map 的双值检测写法详解 在 Go 中,if char, exists : pairs[s[i]]; exists { 是一种利用 Go 语言特性编写的优雅条件语句,用于检测 map 中是否存在某个键。让我们分解解释这种写法: 语法结构解析 if value, ok : mapVariable[key]; …

C# Wkhtmltopdf HTML转PDF碰到的问题

最近碰到一个Html转PDF的需求,看了一下基本上都是需要依赖Wkhtmltopdf,需要在Windows或者linux安装这个可以后使用。找了一下选择了HtmlToPDFCore,这个库是对Wkhtmltopdf.NetCore简单二次封装,这个库的好处就是通过NuGet安装HtmlT…

grafana 批量视图备份及恢复(含数据源)

一、grafana 批量视图备份 import requests import json import urllib3 import osfrom requests.auth import HTTPBasicAuthfilename_folders_map "folders_map.json" type_folder "dash-folder" type_dashboard "dash-db"# Grafana服务器地…

.Net Framework 4/C# 关键字(非常用,持续更新...)

一、is 关键字 is 关键字用于检查对象是否于给定类型兼容,如果兼容将返回 true,如果不兼容则返回 false,在进行类型转换前,可以先使用 is 关键字判断对象是否与指定类型兼容,如果兼容才进行转换,这样的转换是安全的。 例如有:首先创建一个字符串对象,然后将字符串对象隐…