day53

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import load_iris
import warnings

# 忽略不必要的警告信息
warnings.filterwarnings("ignore")

# --------------------------
# 1. 配置训练参数与设备
# --------------------------

# 潜在空间维度(生成器的输入维度)
latent_dim = 10  
# 训练总轮数(GAN通常需要较多迭代才能收敛)
train_epochs = 10000  
# 批次大小(根据数据集规模调整)
batch_size = 32  
# 学习率(控制参数更新幅度)
learning_rate = 0.0002  
# Adam优化器的动量参数(影响收敛稳定性)
beta1 = 0.5  

# 自动选择运算设备(优先GPU,没有则用CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前使用设备: {device}")

# --------------------------
# 2. 数据加载与预处理
# --------------------------

# 加载鸢尾花数据集
iris_dataset = load_iris()
# 提取特征数据和标签
features = iris_dataset.data
labels = iris_dataset.target

# 只选取Setosa类别(标签为0)的数据进行训练
setosa_features = features[labels == 0]

# 将数据缩放到[-1, 1]区间(配合生成器的Tanh输出激活)
scaler = MinMaxScaler(feature_range=(-1, 1))
scaled_features = scaler.fit_transform(setosa_features)

# 转换为PyTorch张量并创建数据加载器
# 注意:必须转为float类型才能与模型参数兼容
data_tensor = torch.from_numpy(scaled_features).float()
dataset = TensorDataset(data_tensor)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 打印数据基本信息
print(f"训练样本数量: {len(scaled_features)}")
print(f"特征维度: {scaled_features.shape[1]}")  # 鸢尾花数据集固定为4维特征

# --------------------------
# 3. 定义生成器和判别器
# --------------------------

class Generator(nn.Module):
    """生成器:将随机噪声转换为模拟的鸢尾花特征数据"""
    def __init__(self):
        super(Generator, self).__init__()
        # 简单的全连接网络结构
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 16),  # 从潜在空间映射到16维
            nn.ReLU(),  # 激活函数增加非线性
            nn.Linear(16, 32),  # 进一步映射到32维
            nn.ReLU(),
            nn.Linear(32, 4),  # 输出4维特征(与真实数据一致)
            nn.Tanh()  # 确保输出在[-1, 1]范围内
        )
    
    def forward(self, x):
        # 前向传播:输入噪声,输出生成的数据
        return self.net(x)

class Discriminator(nn.Module):
    """判别器:区分输入数据是真实样本还是生成器伪造的"""
    def __init__(self):
        super(Discriminator, self).__init__()
        # 简单的全连接网络结构
        self.net = nn.Sequential(
            nn.Linear(4, 32),  # 输入4维特征
            nn.LeakyReLU(0.2),  # LeakyReLU避免梯度消失问题
            nn.Linear(32, 16),  # 压缩到16维
            nn.LeakyReLU(0.2),
            nn.Linear(16, 1),  # 输出单个概率值
            nn.Sigmoid()  # 将输出压缩到[0,1](表示真实数据的概率)
        )
    
    def forward(self, x):
        # 前向传播:输入数据,输出判断概率
        return self.net(x)

# 初始化模型并移动到运算设备
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 打印模型结构
print("\n生成器结构:")
print(generator)
print("\n判别器结构:")
print(discriminator)

# --------------------------
# 4. 配置训练组件
# --------------------------

# 定义损失函数(二元交叉熵,适合二分类问题)
criterion = nn.BCELoss()

# 定义优化器(分别优化生成器和判别器)
gen_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
dis_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))

# --------------------------
# 5. 开始训练
# --------------------------

print("\n--- 训练开始 ---")
for epoch in range(train_epochs):
    # 遍历数据加载器中的每一批次
    for batch_idx, (real_data,) in enumerate(data_loader):
        # 将真实数据移动到运算设备
        real_data = real_data.to(device)
        current_batch_size = real_data.size(0)  # 获取当前批次的实际样本数(最后一批可能不满)
        
        # 创建标签:真实数据标为1,生成数据标为0
        real_labels = torch.ones(current_batch_size, 1).to(device)
        fake_labels = torch.zeros(current_batch_size, 1).to(device)
        
        # --------------------
        # 训练判别器
        # --------------------
        dis_optimizer.zero_grad()  # 清空判别器的梯度缓存
        
        # 1. 用真实数据训练
        real_output = discriminator(real_data)
        # 计算真实数据的损失(希望判别器能认出真实数据)
        loss_real = criterion(real_output, real_labels)
        
        # 2. 用生成的数据训练
        # 生成随机噪声(作为生成器的输入)
        noise = torch.randn(current_batch_size, latent_dim).to(device)
        # 生成假数据,并阻断梯度流向生成器(避免影响生成器参数)
        fake_data = generator(noise).detach()
        fake_output = discriminator(fake_data)
        # 计算假数据的损失(希望判别器能认出假数据)
        loss_fake = criterion(fake_output, fake_labels)
        
        # 总损失反向传播并更新判别器参数
        dis_loss = loss_real + loss_fake
        dis_loss.backward()
        dis_optimizer.step()
        
        # --------------------
        # 训练生成器
        # --------------------
        gen_optimizer.zero_grad()  # 清空生成器的梯度缓存
        
        # 重新生成假数据(这次需要计算生成器的梯度)
        noise = torch.randn(current_batch_size, latent_dim).to(device)
        fake_data = generator(noise)
        fake_output = discriminator(fake_data)
        
        # 生成器的损失:希望判别器把假数据当成真的(所以标签用real_labels)
        gen_loss = criterion(fake_output, real_labels)
        gen_loss.backward()
        gen_optimizer.step()
    
    # 每1000轮打印一次训练状态
    if (epoch + 1) % 1000 == 0:
        print(
            f"轮次 [{epoch+1}/{train_epochs}], "
            f"判别器损失: {dis_loss.item():.4f}, "
            f"生成器损失: {gen_loss.item():.4f}"
        )

print("\n--- 训练完成 ---")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/87696.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/87696.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++ python 共享内存

一、目的 是为了c来读取并解码传递给python,Python做测试非常方便,c 和 python之间必须定好协议,整体使用c 来解码,共享内存传递给python 二、主类 主类,串联decoder,注意decoder并没有直接在显存里面穿…

react函数组件的props,ref,state。

react开发我们会把页面分为一个个组件,组件是独立而且可复用的重复代码片段。具体来说组件可以是一个按钮,一个输入框。react组件有两种定义方法,一种是函数组件,一种是类组件。我们这里说一下函数组件之间父子之间如何传递props参…

基于ARM+FPGA实现的BISS-C协议解决方案,适用于高精度光栅位移传感器等

模块简介 本资源提供了专为FPGA设计的BISS-C接口协议发送模块源码。BISS-C模式作为一种高速、同步的串行通信协议,广泛应用于高精度光栅位移传感器的数据传输中,特别适用于需要精确位置信息的应用场景。此模式遵循主从架构,其中FPGA作为主控制…

spring中@Transactional注解和事务的实战理解附代码

文章目录 前言一、事务是什么?二、事务的特性2.1隔离性2.2事务的隔离级别 三、Transactional注解Transactional注解简介基本用法常用属性配置事务传播行为事务隔离级别异常处理与回滚性能优化建议 四、 事务不生效的可能原因方法访问权限非public自调用问题异常被捕…

替代进口SCA7606【智芯微】国产高精度电流传感器 工业新能源电网专用

SCA7606(智芯微)产品解析与推广文案一、产品概述SCA7606 是 智芯微电子(ZXMICRO) 推出的一款 高精度数字隔离式电流传感器芯片,采用 霍尔效应数字输出 技术,专为 工业控制、新能源、智能电网 等领域的电流检…

Java 与 Vue 全栈开发:“一课一得“ 学习笔记系统实战

一、项目背景与核心价值 "一课一得" 是一个面向学习者的笔记管理平台,旨在帮助用户系统化记录、整理和回顾学习内容。项目采用前后端分离架构:前端基于 Vue.js 构建交互式界面,后端使用 Java Spring Boot 实现业务逻辑&#xff0c…

百度文心大模型 4.5 开源深度测评:技术架构、部署实战与生态协同全解析

声明:本文只做实际测评,并非广告 1.前言 2025 年 6 月 30 日,百度做出一项重大举措,将文心大模型 4.5 系列正式开源,并选择国内领先的开源平台 GitCode 作为首发平台。该模型也是百度在2025年3月16日发布的自研的新一…

力扣_链表_python版本

一、206. 反转链表代码: class Solution:def reverseList(self, head):dummy ListNode()cur headwhile cur:last cur.nextcur.next dummy.nextdummy.next curcur lastreturn dummy.next二、92. 反转链表 IIclass Solution:def reverseBetween(self, head: Opt…

[netty5: WebSocketProtocolHandler]-源码分析

在阅读这篇文章前,推荐先阅读:[netty5: MessageToMessageCodec & MessageToMessageEncoder & MessageToMessageDecoder]-源码分析 WebSocketProtocolHandler WebSocketProtocolHandler 是 WebSocket 处理的基础抽象类,负责管理 Web…

[2025CVPR]一种新颖的视觉与记忆双适配器(Visual and Memory Dual Adapter, VMDA)

引言 多模态目标跟踪(Multi-modal Object Tracking)旨在通过结合RGB模态与其他辅助模态(如热红外、深度、事件数据)来增强可见光传感器的感知能力,尤其在复杂场景下显著提升跟踪鲁棒性。然而,现有方法在频…

理想汽车6月交付36279辆 第二季度共交付111074辆

理想汽车-W(02015)发布公告,2025年6月,理想汽车交付新车36279辆,第二季度共交付111074辆。截至2025年6月30日,理想汽车历史累计交付量为133.78万辆。 在成立十周年之际,理想汽车已连续两年成为人民币20万元以上中高端市…

MobileNets: 高效的卷积神经网络用于移动视觉应用

摘要 我们提出了一类高效的模型,称为MobileNets,专门用于移动和嵌入式视觉应用。MobileNets基于一种简化的架构,利用深度可分离卷积构建轻量级的深度神经网络。我们引入了两个简单的全局超参数,能够有效地在延迟和准确性之间进行…

SDP服务发现协议:动态查询设备能力的底层逻辑(面试深度解析)

SDP的底层逻辑揭示了物联网设备交互的本质——先建立认知,再开展协作。 一、SDP 核心知识点高频考点解析 1.1 SDP 的定位与作用 考点:SDP 在蓝牙协议栈中的位置及核心功能 解析:SDP(Service Discovery Protocol,服务发现协议)位于蓝牙协议栈的中间层,依赖 L2CAP 协议传…

CppCon 2018 学习:GIT, CMAKE, CONAN

提到的: “THE MOST COMMON C TOOLSET” VERSION CONTROL SYSTEM BUILDING PACKAGE MANAGEMENT 这些是 C 项目开发中最核心的工具链组成部分。下面我将逐一解释每部分的作用、常见工具,以及它们如何协同构建现代 C 项目。 1. VERSION CONTROL SYSTEM&am…

使用tensorflow的线性回归的例子(五)

我们使用Iris数据,Sepal length为y值而Petal width为x值。import matplotlib.pyplot as pltimport numpy as npimport tensorflow as tffrom sklearn import datasetsfrom tensorflow.python.framework import opsops.reset_default_graph()# Load the data# iris.d…

虚幻基础:动作——蒙太奇

能帮到你的话,就给个赞吧 😘 文章目录 动作——蒙太奇如果动作被打断,则后续的动画通知不会执行 动作——蒙太奇 如果动作被打断,则后续的动画通知不会执行

[工具系列] 开源的 API 调试工具 Postwoman

介绍 随着 Web 应用的复杂性增加,API 测试已成为开发中不可或缺的一部分,无论是前端还是后端开发,确保 API 正常运行至关重要。 Postman 长期以来是开发者进行 API 测试的首选工具,但是很多基本功能都需要登陆才能使用&#xff…

【力扣 简单 C】746. 使用最小花费爬楼梯

目录 题目 解法一 题目 解法一 int min(int a, int b) {return a < b ? a : b; }int minCostClimbingStairs(int* cost, int costSize) {const int n costSize; // 楼顶&#xff0c;第n阶// 爬到第n阶的最小花费 // 爬到第n-1阶的最小花费从第n-1阶爬上第n阶的花费…

python+django开发带auth接口

pythondjango开发带auth接口 # coding utf-8 import base64 from django.contrib import auth as django_authfrom django.core.exceptions import ObjectDoesNotExist from django.http import JsonResponsefrom sign.models import Eventdef user_auth(request):"&quo…

RBAC权限模型如何让API访问控制既安全又灵活?

url: /posts/9f01e838545ae8d34016c759ef461423/ title: RBAC权限模型如何让API访问控制既安全又灵活? date: 2025-07-01T04:52:07+08:00 lastmod: 2025-07-01T04:52:07+08:00 author: cmdragon summary: RBAC权限模型通过用户、角色和权限的关联实现访问控制,核心组件包括用…