python打卡第37天

知识点回顾:

  1. 过拟合的判断:测试集和训练集同步打印指标
  2. 模型的保存和加载
    1. 仅保存权重
    2. 保存权重和模型
    3. 保存全部信息checkpoint,还包含训练状态
  3. 早停策略

作业:对信贷数据集训练后保存权重,加载权重后继续训练50轮,并采取早停策略

import pandas as pd
import numpy as np
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, roc_auc_score
import matplotlib.pyplot as pltdef set_seed(seed=42):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Trueset_seed(42)# 读取数据
data = pd.read_csv('data.csv')
target_col = 'Credit Default'
data = data.fillna(data.median(numeric_only=True))
data = data.fillna('Unknown')categorical_features = ['Home Ownership', 'Purpose', 'Term', 'Years in current job']
numerical_features = [col for col in data.columns if col not in categorical_features + [target_col]]for col in categorical_features:le = LabelEncoder()data[col] = le.fit_transform(data[col])X = data[categorical_features + numerical_features]
y = data[target_col]scaler = StandardScaler()
X[numerical_features] = scaler.fit_transform(X[numerical_features])X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)class CreditDataset(Dataset):def __init__(self, X, y):self.X = torch.tensor(X.values, dtype=torch.float32)self.y = torch.tensor(y.values, dtype=torch.float32)def __len__(self):return len(self.X)def __getitem__(self, idx):return self.X[idx], self.y[idx]train_dataset = CreditDataset(X_train, y_train)
test_dataset = CreditDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)class CreditNet(nn.Module):def __init__(self, input_dim):super(CreditNet, self).__init__()self.model = nn.Sequential(nn.Linear(input_dim, 64),nn.BatchNorm1d(64),nn.ReLU(),nn.Dropout(0.3),nn.Linear(64, 32),nn.BatchNorm1d(32),nn.ReLU(),nn.Dropout(0.2),nn.Linear(32, 1))def forward(self, x):return self.model(x).squeeze(1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CreditNet(X_train.shape[1]).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)def train(model, loader, criterion, optimizer):model.train()total_loss = 0for X_batch, y_batch in loader:X_batch, y_batch = X_batch.to(device), y_batch.to(device)optimizer.zero_grad()outputs = model(X_batch)loss = criterion(outputs, y_batch)loss.backward()optimizer.step()total_loss += loss.item() * X_batch.size(0)return total_loss / len(loader.dataset)def evaluate(model, loader):model.eval()preds, targets = [], []with torch.no_grad():for X_batch, y_batch in loader:X_batch = X_batch.to(device)outputs = torch.sigmoid(model(X_batch)).cpu().numpy()preds.extend(outputs)targets.extend(y_batch.numpy())preds = np.array(preds)targets = np.array(targets)preds_label = (preds > 0.5).astype(int)auc = roc_auc_score(targets, preds)report = classification_report(targets, preds_label, digits=4)return auc, report# 训练主循环
epochs = 20
train_losses = []
test_aucs = []for epoch in range(epochs):train_loss = train(model, train_loader, criterion, optimizer)auc, _ = evaluate(model, test_loader)train_losses.append(train_loss)test_aucs.append(auc)print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f} - Test AUC: {auc:.4f}")# 可视化训练损失和AUC曲线
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.plot(range(1, epochs+1), train_losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.title('Training Loss Curve')
plt.grid(True)
plt.subplot(1,2,2)
plt.plot(range(1, epochs+1), test_aucs, marker='o', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Test AUC')
plt.title('Test AUC Curve')
plt.grid(True)
plt.tight_layout()
plt.show()# 保存模型权重
torch.save(model.state_dict(), "credit_model.pth")
# 定义早停类
class EarlyStopping:def __init__(self, patience=5, delta=1e-4):self.patience = patienceself.delta = deltaself.best_score = Noneself.counter = 0self.early_stop = Falsedef __call__(self, score):if self.best_score is None or score > self.best_score + self.delta:self.best_score = scoreself.counter = 0else:self.counter += 1if self.counter >= self.patience:self.early_stop = True# 加载权重并继续训练
model.load_state_dict(torch.load("credit_model.pth"))
epochs_continue = 50
early_stopping = EarlyStopping(patience=5, delta=1e-4)
train_losses2 = []
test_aucs2 = []for epoch in range(epochs_continue):train_loss = train(model, train_loader, criterion, optimizer)auc, _ = evaluate(model, test_loader)train_losses2.append(train_loss)test_aucs2.append(auc)print(f"[Continue] Epoch {epoch+1}/{epochs_continue} - Train Loss: {train_loss:.4f} - Test AUC: {auc:.4f}")early_stopping(auc)if early_stopping.early_stop:print("Early stopping triggered!")break# 可视化继续训练的曲线
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.plot(range(1, len(train_losses2)+1), train_losses2, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.title('Continue Training Loss Curve')
plt.grid(True)
plt.subplot(1,2,2)
plt.plot(range(1, len(test_aucs2)+1), test_aucs2, marker='o', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Test AUC')
plt.title('Continue Test AUC Curve')
plt.grid(True)
plt.tight_layout()
plt.show()# 最终评估
auc, report = evaluate(model, test_loader)
print(f"\nFinal Test AUC: {auc:.4f}")
print("Classification Report:\n", report)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/81483.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【洛谷P9303题解】AC- [CCC 2023 J5] CCC Word Hunt

在CCC单词搜索游戏中,单词隐藏在一个字母网格中。目标是确定给定单词在网格中隐藏的次数。单词可以以直线或直角的方式排列。以下是详细的解题思路及代码实现: 传送门: https://www.luogu.com.cn/problem/P9303 解题思路 输入读取与初始化&…

LangGraph + LLM + stream_mode

文章目录 LLM 代码valuesmessagesupdatesmessages updatesmessages updates 2 LLM 代码 from dataclasses import dataclassfrom langchain.chat_models import init_chat_model from langgraph.graph import StateGraph, STARTfrom langchain_openai import ChatOpenAI # 初…

Pydantic 学习与使用

Pydantic 学习与使用 在 Fastapi 的 Web 开发中的数据验证通常都是在使用 Pydantic 来进行数据的校验,本文将对 Pydantic 的使用方法做记录与学习。 **简介:**Pydantic 是一个在 Python 中用于数据验证和解析的第三方库,它现在是 Python 使…

批量文件重命名工具

分享一个自己使用 python 开发的小软件,批量文件重命名工具,主要功能有批量中文转拼音,简繁体转换,大小写转换,替换文件名,删除指定字符,批量添加编号,添加前缀/后缀。同时还有文件时…

多语言视角下的 DOM 操作:从 JavaScript 到 Python、Java 与 C#

多语言视角下的 DOM 操作:从 JavaScript 到 Python、Java 与 C# 在 Web 开发中,文档对象模型(DOM)是构建动态网页的核心技术。它将 HTML/XML 文档解析为树形结构,允许开发者通过编程方式访问和修改页面内容、结构和样…

【C/C++】红黑树学习笔记

文章目录 红黑树1 基本概念1.1 定义1.2 基本特性推理1.3 对比1.4 延伸1.4.1 简单判别是否是红黑树1.4.2 应用 2 插入2.1 插入结点默认红色2.2 插入结点2.2.1 插入结点是根结点2.2.2 插入结点的叔叔是红色2.2.3 插入结点的叔叔是黑色场景分析LL型RR型LR型RL型 3 构建4 示例代码 …

网络通信的基石:深入理解帧与报文

在这个万物互联的时代,我们每天都在享受着网络带来的便利——从早晨查看天气预报,到工作中的视频会议,再到晚上刷着短视频放松。然而,在这些看似简单的网络交互背后,隐藏着精密而复杂的数据传输机制。今天,…

STM32 SPI通信(硬件)

一、SPI外设简介 STM32内部集成了硬件SPI收发电路,可以由硬件自动执行时钟生成、数据收发等功能,减轻CPU的负担 可配置8位/16位数据帧、高位先行/低位先行 时钟频率: fPCLK / (2, 4, 8, 16, 32, 64, 128, 256) 支持多主机模型、主或从操作 可…

尚硅谷redis7-11-redis10大类型之总体概述

前提:我们说的数据类型一般是value的数据类型,key的类型都是字符串。 redis字符串【String】 string类型是二进制安全的,意思是redis的string可以包含任何数据,比如jpg图片或者序列化的对象。 string类型是Redis最基本的数据类型,一个redis中字符串va…

【递归、搜索与回溯算法】专题一 递归

文章目录 0.理解递归、搜索与回溯1.面试题 08.06.汉诺塔问题1.1 题目1.2 思路1.3 代码 2. 合并两个有序链表2.1 题目2.2 思路2.3 代码 3.反转链表3.1 题目3.2 思路3.3 代码 4.两两交换链表中的节点4.1 题目4.2 思路4.3 代码 5. Pow(x, n) - 快速幂5.1 题目5.2 思路5.3 代码 0.理…

C#实现List导出CSV:深入解析完整方案

C#实现List导出CSV:深入解析完整方案 在数据交互场景中,CSV文件凭借其跨平台兼容性和简洁性,成为数据交换的重要载体。本文将基于C#反射机制实现的通用CSV导出方案,结合实际开发中的痛点,从基础实现、深度优化到生产级…

字符串day7

344 反转字符串 字符串理论上也是一个数组&#xff0c;因此只需要用双指针即可 class Solution { public:void reverseString(vector<char>& s) {for(int i0,js.size()-1;i<j;i,j--){swap(s[i],s[j]);}} };541 反转字符串 自己实现一个反转从start到end的字符串…

Grafana XSSOpenRedirectSSRF漏洞复现(CVE-2025-4123)

免责申明: 本文所描述的漏洞及其复现步骤仅供网络安全研究与教育目的使用。任何人不得将本文提供的信息用于非法目的或未经授权的系统测试。作者不对任何由于使用本文信息而导致的直接或间接损害承担责任。如涉及侵权,请及时与我们联系,我们将尽快处理并删除相关内容。 前…

私服 nexus 之间迁移 npm 仓库

本文介绍如何将一个 Nexus 特定仓库中的 npm 包内容迁移到另一个 Nexus 特定仓库。此过程适用于需要重构仓库结构或合并仓库的场景。 迁移脚本 以下是完整的迁移脚本&#xff0c;它会自动完成以下操作&#xff1a; 从源仓库获取所有 npm 包列表下载每个包的 .tgz 文件解压并…

Django ToDoWeb 服务

我们的任务是使用 Django 创建一个简单的 ToDo 应用程序,允许用户添加、查看和删除笔记。我们将通过设置 Django 项目、创建 Todo 模型、设计表单和视图来处理用户输入以及创建模板来显示任务来构建它。我们将逐步实现核心功能以有效地管理 todo 项。 Django ToDoWeb 服务 …

阿里云服务器遭遇DDoS攻击?低成本第三方高防解决方案全解析

阿里云服务器因高性能和稳定性备受青睐&#xff0c;但其DDoS高防服务的价格常让中小企业望而却步。面对动辄每月数万元的防护成本&#xff0c;许多用户不禁疑问&#xff1a;能否通过第三方高防服务保护阿里云服务器&#xff1f;如何实现低成本高效防御&#xff1f; 本文将结合技…

2025山东CCPC补题

2025山东CCPC补题 目录 2025山东CCPC补题K - UNO&#xff01; &#xff08;双端队列的简单应用&#xff09;M - 第九届河北省大学生程序设计竞赛 &#xff08;二进制枚举模拟&#xff09;J - Generate 01 String 感觉这场比赛的题目挺不错的&#xff1b;没有说那些为了算法而算…

体绘制学习

一、基本概念 体绘制是对一个三维物体数据进行采样与拟合的过程。 在体绘制中用vtkVolume渲染数据 渲染数据类数据转换类几何渲染vtkActorvtkPolyDataMapper体渲染vtkVolumevtkVolumeRayCastMapper 体绘制常用算法如下。 光线投射法。 优点是可视化结果质量好。缺点是计算…

告别“盘丝洞”车间:4-20mA无线传输如何重构工厂神经网?

4-20ma无线传输是利用无线模块将传统的温度、压力、液位等4-20mA电流信号转换为无线信号进行传输。这一技术突破了有线传输的限制&#xff0c;使得信号可以在更广泛的范围内进行灵活、快速的传递&#xff0c;无线传输距离可达到50KM。达泰4-20ma无线传输模块在实现工业现场应用…

VB.NET与SQL连接问题解决方案

1.基本连接步骤 使用SqlConnection、SqlCommand和SqlDataReader进行基础操作&#xff1a; vb.net Imports System.Data.SqlClient Public Sub ConnectToDatabase() Dim connectionString As String "ServermyServerAddress;DatabasemyDataBase;Integrated Security…