Complete Project Workflow Summary
1. Environment Setup and Dependency Imports
import time
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, ResNet18_Weights
import wandb
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
2. Data Preparation and Augmentation
# Data augmentation transforms for the training set
transform = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])

# Test set transforms (no augmentation)
transformtest = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])

# Dataset loading (datapath and batch_size are assumed to be defined in the config)
train_dataset = CIFAR10(
    root=datapath,
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
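The block above only builds the training loader. For completeness, a minimal sketch of the matching test loader, reusing `transformtest` and the same assumed `datapath` / `batch_size` config values, could look like this:

# Test set and loader, mirroring the training setup but without shuffling or augmentation
test_dataset = CIFAR10(
    root=datapath,
    train=False,
    download=True,
    transform=transformtest,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)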
3. Model Construction and Initialization
# Build ResNet18 and replace the fully connected layer for the 10 CIFAR-10 classes
model = resnet18(weights=None)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features=in_features, out_features=10)

# Load pretrained weights if a checkpoint exists, skipping the mismatched fc layer
if os.path.exists(weightpath):
    weights_default = torch.load(weightpath, map_location=device)
    weights_default.pop("fc.weight", None)
    weights_default.pop("fc.bias", None)
    new_state_dict = model.state_dict()
    weights_default_process = {k: v for k, v in weights_default.items() if k in new_state_dict}
    new_state_dict.update(weights_default_process)
    model.load_state_dict(new_state_dict)
model.to(device)
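Because this is a transfer-learning project, a common optional step (not in the original code, shown here only as a hedged sketch) is to freeze the pretrained backbone and fine-tune only the new fc head:

# Optional: freeze all pretrained layers so only the new fc layer is trained
for name, param in model.named_parameters():
    if not name.startswith("fc."):
        param.requires_grad = False

# The optimizer would then receive only the trainable parameters, e.g.
# optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)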
4. Training Process
# Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Experiment tracking: Weights & Biases and TensorBoard
wandb.init(project="my-qianyi-project", config={...})
write1 = SummaryWriter(log_dir=log_dir)
write1.add_graph(model, input_to_model=torch.randn(1, 3, 32, 32).to(device))

# Training loop
for epoch in range(epochs):
    model.train()
    # training code ...
    torch.save(model.state_dict(), weightpath)
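The loop body is elided above ("training code ..."). A minimal sketch of one way to fill it in, logging loss and accuracy to both wandb and TensorBoard and keeping the best checkpoint; the best-accuracy heuristic and metric names are assumptions, not the original code:

best_acc = 0.0
for epoch in range(epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    wandb.log({"train/loss": epoch_loss, "train/acc": epoch_acc, "epoch": epoch})
    write1.add_scalar("train/loss", epoch_loss, epoch)
    write1.add_scalar("train/acc", epoch_acc, epoch)

    # Keep the checkpoint with the best training accuracy (a simple heuristic)
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(model.state_dict(), weightpath)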
5. Validation and Evaluation
# Load the best checkpoint for validation
model.load_state_dict(torch.load(weightpath))
model.eval()
# Validation loop
# Save predictions to CSV
# Generate the classification report and confusion matrix
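A hedged sketch of what those elided validation steps could look like, assuming the test_loader from section 2 and the sklearn metrics imported in section 1; the CSV file name and column names are illustrative:

all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
        all_labels.extend(labels.tolist())

# Save predictions to CSV
pd.DataFrame({"label": all_labels, "prediction": all_preds}).to_csv("predictions.csv", index=False)

# Classification report and confusion matrix
print(classification_report(all_labels, all_preds))
print(confusion_matrix(all_labels, all_preds))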
6. Model Application (Inference)
# Load the model for inference
def predict_image(image_path):
    # Preprocess the image
    # Run the model
    # Return the prediction
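A fuller sketch of that stub, assuming single-image inference with PIL, resizing to the 32×32 CIFAR-10 input size, and the same normalization statistics used for the test set; the function body is an illustration, not the original code:

from PIL import Image

def predict_image(image_path):
    # Preprocess: resize to the CIFAR-10 input size and normalize like the test set
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
    ])
    image = Image.open(image_path).convert("RGB")
    tensor = preprocess(image).unsqueeze(0).to(device)

    # Predict
    model.eval()
    with torch.no_grad():
        pred = model(tensor).argmax(dim=1).item()

    # Return the predicted class index (map to a class name if desired)
    return pred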
7. Model Porting and Deployment
7.1 Model Conversion (PyTorch → ONNX)
# Convert the trained model to ONNX format
def convert_to_onnx(model, input_size, onnx_path):
    model.eval()
    dummy_input = torch.randn(1, *input_size).to(device)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
    print(f"Model converted to ONNX and saved to {onnx_path}")

# Usage example
convert_to_onnx(model, (3, 32, 32), "model.onnx")
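After export, a quick sanity check is to run the ONNX file with ONNX Runtime and compare its output against PyTorch; a sketch assuming the onnxruntime package is installed:

import onnxruntime as ort

# Run the exported model on a dummy input and compare with the PyTorch output
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
dummy = np.random.randn(1, 3, 32, 32).astype(np.float32)

onnx_out = session.run(["output"], {"input": dummy})[0]
with torch.no_grad():
    torch_out = model(torch.from_numpy(dummy).to(device)).cpu().numpy()

print("Max absolute difference:", np.abs(onnx_out - torch_out).max())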
7.2 Model Quantization (smaller model size, faster inference)
# Dynamic quantization (note: dynamic quantization targets CPU inference)
def quantize_model(model):
    quantized_model = torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )
    return quantized_model

# Usage example
quantized_model = quantize_model(model)
torch.save(quantized_model.state_dict(), "quantized_model.pth")
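To confirm the size reduction, one can simply compare the checkpoint files on disk; a sketch in which the float checkpoint name is illustrative:

# Compare on-disk sizes of the float and quantized checkpoints
torch.save(model.state_dict(), "float_model.pth")
float_mb = os.path.getsize("float_model.pth") / 1e6
quant_mb = os.path.getsize("quantized_model.pth") / 1e6
print(f"Float: {float_mb:.2f} MB, quantized: {quant_mb:.2f} MB")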
7.3 Reducing the Number of Parameters (Pruning)
from torch.nn.utils import prune

# Simple global L1 weight pruning over all Conv2d and Linear layers
def prune_model(model, pruning_percentage=0.2):
    parameters_to_prune = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            parameters_to_prune.append((module, 'weight'))
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=pruning_percentage,
    )
    return model

# Usage example
pruned_model = prune_model(model)
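Global unstructured pruning only attaches a mask; to make the sparsity permanent and drop the weight_orig / weight_mask buffers before saving, prune.remove can be applied afterwards (a short sketch following the function above):

# Make pruning permanent by removing the reparametrization from each pruned layer
for module in pruned_model.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        prune.remove(module, "weight")
torch.save(pruned_model.state_dict(), "pruned_model.pth")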
7.4 Mobile / C++ Deployment (TorchScript / LibTorch)
# Save as a TorchScript model (loadable from LibTorch / C++)
example = torch.rand(1, 3, 32, 32).to(device)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
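The traced module can be reloaded without the Python model class, which is what makes it portable to LibTorch and mobile runtimes; a quick check in Python, assuming the file saved above:

# Reload the traced model and run a test input; no model class definition is needed
loaded = torch.jit.load("model.pt")
loaded.eval()
with torch.no_grad():
    out = loaded(torch.rand(1, 3, 32, 32).to(device))
print(out.shape)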
7.5 Web Deployment (ONNX.js)
# First convert to ONNX, then run the model in the browser with ONNX.js
# or use third-party tools such as https://github.com/onnx/tensorflow-onnx
7.6 Edge Device Deployment (TensorRT, OpenVINO, etc.)
# Optimize with NVIDIA TensorRT (convert to ONNX first)
# or use the Intel OpenVINO toolkit
8. Performance Monitoring and Optimization
# Benchmark model inference speed
def benchmark_model(model, input_size, num_runs=100):
    model.eval()
    input_tensor = torch.randn(1, *input_size).to(device)
    with torch.no_grad():
        # GPU warm-up
        for _ in range(10):
            _ = model(input_tensor)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # Timing
        start_time = time.time()
        for _ in range(num_runs):
            _ = model(input_tensor)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        end_time = time.time()
    avg_time = (end_time - start_time) / num_runs
    fps = 1 / avg_time
    print(f"Average inference time: {avg_time*1000:.2f} ms, FPS: {fps:.2f}")
    return avg_time, fps

# Usage example
benchmark_model(model, (3, 32, 32))
This workflow covers the full pipeline from data preparation to model deployment. The model porting section in particular shows how to move a trained model onto different platforms and devices, which is essential for real-world applications.