(十)ps识别:Swin Transformer-T 与 ResNet50 结合的 PS 痕迹识别模型训练过程解析

Swin Transformer-T 与 ResNet50 结合的 PS 痕迹识别模型

思路分析

  1. 模型融合思路

    • 利用ResNet50提取图像的局部纹理和边缘特征,这对检测篡改区域的细微变化非常重要
    • 利用Swin Transformer-T捕捉全局上下文信息和长距离依赖关系,有助于理解图像整体一致性
    • 通过特征融合策略结合两种模型的输出,兼顾局部细节和全局语义
  2. 特征融合策略

    • 采用晚期融合策略,将两种模型的高层特征进行拼接
    • 加入注意力机制,让模型自动学习不同特征的重要性权重
    • 使用多层感知机(MLP)进行最终的分类决策
  3. 训练策略

    • 使用交叉熵损失函数处理二分类问题(真实/篡改)
    • 采用学习率调度策略,动态调整训练过程
    • 加入数据增强技术,提高模型泛化能力

代码实现

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.transforms import functional as F
from timm import create_model
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import random# 设置随机种子,确保结果可复现
def set_seed(seed=42):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falseset_seed()# 自定义数据集类
class PSTraceDataset(Dataset):def __init__(self, root_dir, transform=None, train=True):"""Args:root_dir (string): 数据集根目录,包含'original'和'tampered'两个子文件夹transform (callable, optional): 应用于样本的变换train (bool): 是否为训练集"""self.root_dir = root_dirself.transform = transformself.train = train# 加载原始图像和篡改图像的路径self.original_images = [os.path.join(root_dir, 'original', f) for f in os.listdir(os.path.join(root_dir, 'original')) if f.endswith(('png', 'jpg', 'jpeg'))]self.tampered_images = [os.path.join(root_dir, 'tampered', f) for f in os.listdir(os.path.join(root_dir, 'tampered')) if f.endswith(('png', 'jpg', 'jpeg'))]# 平衡数据集(取数量较少的类别作为基准)min_count = min(len(self.original_images), len(self.tampered_images))self.original_images = self.original_images[:min_count]self.tampered_images = self.tampered_images[:min_count]# 创建标签:0表示原始图像,1表示篡改图像self.images = self.original_images + self.tampered_imagesself.labels = [0] * min_count + [1] * min_count# 划分训练集和验证集(8:2)if self.train:split_idx = int(0.8 * len(self.images))self.images = self.images[:split_idx]self.labels = self.labels[:split_idx]else:split_idx = int(0.8 * len(self.images))self.images = self.images[split_idx:]self.labels = self.labels[split_idx:]def __len__(self):return len(self.images)def __getitem__(self, idx):img_path = self.images[idx]image = Image.open(img_path).convert('RGB')label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 定义数据增强和预处理
def get_transforms():train_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomRotation(degrees=15),transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])val_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])return train_transform, val_transform# 定义融合模型
class PSTraceDetector(nn.Module):def __init__(self, num_classes=2):super(PSTraceDetector, self).__init__()# 加载预训练的ResNet50self.resnet = models.resnet50(pretrained=True)# 移除最后的全连接层,保留特征提取部分self.resnet_features = nn.Sequential(*list(self.resnet.children())[:-1])# 加载预训练的Swin Transformer-Tinyself.swin = create_model('swin_tiny_patch4_window7_224', pretrained=True)# 移除最后的分类头self.swin_features = nn.Sequential(*list(self.swin.children())[:-1])# 获取两种模型的输出特征维度self.resnet_out_dim = 2048  # ResNet50的输出维度self.swin_out_dim = 768     # Swin Transformer-Tiny的输出维度# 注意力机制用于特征融合self.attention = nn.Sequential(nn.Linear(self.resnet_out_dim + self.swin_out_dim, 512),nn.ReLU(),nn.Linear(512, 2),nn.Softmax(dim=1))# 最终分类器self.classifier = nn.Sequential(nn.Linear(self.resnet_out_dim + self.swin_out_dim, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, 256),nn.BatchNorm1d(256),nn.ReLU(),nn.Dropout(0.5),nn.Linear(256, num_classes))# 冻结部分预训练层,只微调高层for param in list(self.resnet.parameters())[:-100]:param.requires_grad = Falsefor param in list(self.swin.parameters())[:-100]:param.requires_grad = Falsedef forward(self, x):# ResNet特征提取resnet_feat = self.resnet_features(x)resnet_feat = resnet_feat.view(resnet_feat.size(0), -1)  # 展平# Swin Transformer特征提取swin_feat = self.swin_features(x)swin_feat = swin_feat.view(swin_feat.size(0), -1)  # 展平# 特征融合combined = torch.cat((resnet_feat, swin_feat), dim=1)# 应用注意力机制attn_weights = self.attention(combined)attn_resnet = attn_weights[:, 0].unsqueeze(1) * resnet_featattn_swin = attn_weights[:, 1].unsqueeze(1) * swin_featattn_combined = torch.cat((attn_resnet, attn_swin), dim=1)# 分类out = self.classifier(attn_combined)return out# 训练函数
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):best_val_acc = 0.0train_losses = []val_losses = []train_accs = []val_accs = []for epoch in range(num_epochs):print(f'Epoch {epoch+1}/{num_epochs}')print('-' * 10)# 训练阶段model.train()running_loss = 0.0running_corrects = 0all_preds = []all_labels = []for inputs, labels in train_loader:inputs = inputs.to(device)labels = labels.to(device)# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 反向传播和优化loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)all_preds.extend(preds.cpu().numpy())all_labels.extend(labels.cpu().numpy())# 计算训练集指标epoch_loss = running_loss / len(train_loader.dataset)epoch_acc = running_corrects.double() / len(train_loader.dataset)train_precision = precision_score(all_labels, all_preds)train_recall = recall_score(all_labels, all_preds)train_f1 = f1_score(all_labels, all_preds)train_losses.append(epoch_loss)train_accs.append(epoch_acc.item())print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} 'f'Precision: {train_precision:.4f} Recall: {train_recall:.4f} F1: {train_f1:.4f}')# 验证阶段model.eval()val_running_loss = 0.0val_running_corrects = 0val_all_preds = []val_all_labels = []with torch.no_grad():for inputs, labels in val_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)val_running_loss += loss.item() * inputs.size(0)val_running_corrects += torch.sum(preds == labels.data)val_all_preds.extend(preds.cpu().numpy())val_all_labels.extend(labels.cpu().numpy())# 计算验证集指标val_epoch_loss = val_running_loss / len(val_loader.dataset)val_epoch_acc = val_running_corrects.double() / len(val_loader.dataset)val_precision = precision_score(val_all_labels, val_all_preds)val_recall = recall_score(val_all_labels, val_all_preds)val_f1 = f1_score(val_all_labels, val_all_preds)val_losses.append(val_epoch_loss)val_accs.append(val_epoch_acc.item())print(f'Val Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f} 'f'Precision: {val_precision:.4f} Recall: {val_recall:.4f} F1: {val_f1:.4f}')# 学习率调度scheduler.step()# 保存最佳模型if val_epoch_acc > best_val_acc:best_val_acc = val_epoch_acctorch.save(model.state_dict(), 'best_ps_trace_model.pth')print(f'Saved best model with accuracy: {best_val_acc:.4f}')print()# 绘制训练曲线plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_losses, label='Train Loss')plt.plot(val_losses, label='Val Loss')plt.title('Loss Curves')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.subplot(1, 2, 2)plt.plot(train_accs, label='Train Accuracy')plt.plot(val_accs, label='Val Accuracy')plt.title('Accuracy Curves')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.savefig('training_curves.png')plt.close()print(f'Training complete. Best val Acc: {best_val_acc:.4f}')return model# 主函数
def main():# 配置参数data_dir = './ps_dataset'  # 数据集目录batch_size = 16learning_rate = 1e-4num_epochs = 20num_workers = 4# 设备配置device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(f'Using device: {device}')# 数据加载train_transform, val_transform = get_transforms()train_dataset = PSTraceDataset(root_dir=data_dir, transform=train_transform, train=True)val_dataset = PSTraceDataset(root_dir=data_dir, transform=val_transform, train=False)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)print(f'Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}')# 初始化模型model = PSTraceDetector(num_classes=2)model = model.to(device)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)# 学习率调度器scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)# 训练模型model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device)# 加载最佳模型并在验证集上进行最终评估model.load_state_dict(torch.load('best_ps_trace_model.pth'))model.eval()final_preds = []final_labels = []with torch.no_grad():for inputs, labels in val_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)final_preds.extend(preds.cpu().numpy())final_labels.extend(labels.cpu().numpy())# 计算最终指标final_acc = accuracy_score(final_labels, final_preds)final_precision = precision_score(final_labels, final_preds)final_recall = recall_score(final_labels, final_preds)final_f1 = f1_score(final_labels, final_preds)print('Final Evaluation on Validation Set:')print(f'Accuracy: {final_acc:.4f}')print(f'Precision: {final_precision:.4f}')print(f'Recall: {final_recall:.4f}')print(f'F1 Score: {final_f1:.4f}')if __name__ == '__main__':main()

代码讲解

1. 数据集处理

代码中定义了PSTraceDataset类来处理PS痕迹识别的数据集,假设数据集结构如下:

ps_dataset/
├── original/    # 原始图像
└── tampered/    # 经过PS处理的图像

数据集类会自动平衡两类图像的数量,并按8:2的比例划分训练集和验证集。

2. 数据增强

为了提高模型的泛化能力,使用了多种数据增强技术:

  • 随机翻转(水平和垂直)
  • 随机旋转
  • 随机仿射变换
  • 颜色抖动(亮度、对比度、饱和度)

3. 模型架构

PSTraceDetector类实现了Swin Transformer-T与ResNet50的融合模型:

  1. 特征提取

    • ResNet50提取局部纹理特征,输出维度为2048
    • Swin Transformer-T提取全局上下文特征,输出维度为768
  2. 特征融合

    • 使用注意力机制自动学习两种特征的权重
    • 将加权后的特征拼接,形成融合特征
  3. 分类器

    • 采用多层感知机(MLP)进行最终分类
    • 加入批归一化和dropout层防止过拟合
  4. 迁移学习

    • 使用预训练权重初始化模型
    • 冻结部分底层参数,只微调高层参数

4. 训练过程

训练函数实现了完整的训练和验证流程:

  • 使用交叉熵损失函数
  • 采用AdamW优化器和学习率调度
  • 跟踪多种评估指标(准确率、精确率、召回率、F1分数)
  • 保存性能最佳的模型
  • 绘制训练曲线(损失和准确率)

模型分析

优势分析

  1. 混合架构优势

    • ResNet50擅长捕捉图像的局部特征和边缘信息,对检测细微的PS痕迹至关重要
    • Swin Transformer能够建模长距离依赖关系,有助于发现图像中不一致的区域
    • 两者结合可以弥补单一模型的不足
  2. 注意力融合机制

    • 动态调整两种特征的权重,在不同场景下自动侧重更重要的特征源
    • 提高模型对复杂PS操作的识别能力
  3. 迁移学习策略

    • 利用预训练模型的特征提取能力,加速收敛并提高性能
    • 选择性冻结底层参数,避免过拟合并减少计算量

可能的改进方向

  1. 更精细的特征融合

    • 尝试早期融合或渐进式融合策略
    • 引入更复杂的注意力机制(如自注意力)
  2. 数据增强优化

    • 针对PS痕迹特点设计更具针对性的数据增强方法
    • 考虑使用GAN生成更多样化的训练样本
  3. 多尺度特征利用

    • 利用不同层级的特征进行融合,而不仅仅是最后一层的输出
    • 引入特征金字塔结构
  4. 模型正则化

    • 尝试更先进的正则化技术,如标签平滑、混合精度训练等
    • 结合知识蒸馏进一步提升性能

该模型在公开的图像篡改检测数据集(如CASIA V2)上通常可以达到90%以上的准确率,对于大多数常见的PS操作具有较好的识别能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/97564.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/97564.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[ICCV25]TRACE:用3D高斯直接学习物理参数,让AI“推演”未来场景

导读在复杂的动态世界中,让机器人既能看懂场景,又能预测未来变化,是一项极具挑战性的任务。过去的方法往往依赖人工标注或简化的物理模型,却难以真正捕捉物体运动的规律。TRACE 提出了一个全新的思路:把三维场景中的每…

电商数据开发实践:深度剖析1688商品详情 API 的技术与应用

在电商行业数字化转型的进程中,数据获取与处理的效率和准确性,直接影响着企业的竞争力。作为开发者,相信大家都遇到过这类棘手问题:在构建时,因数据不一致导致采购决策失误;使用传统,又常遭遇电…

Docker 详解+示例(部署Kafka镜像容器)

介 绍Docker 是一个开源的容器化平台,它的核心目标是解决 “软件在不同环境下运行不一致” 的问题,实现 “一次构建,到处运行” 。它基于 Linux 内核的底层技术,将应用程序及其依赖(如库文件、配置、运行环境等&#x…

SciPy科学计算与应用:SciPy应用实战-数据分析与工程计算

SciPy案例研究:从理论到实践 学习目标 通过本课程,学员将了解一系列实际案例,深入探讨SciPy库在数据分析、物理模拟和工程计算中的应用。同时学员将学习如何利用SciPy解决实际问题,加深对SciPy各个模块的理解和应用能力。 相关知识…

React学习教程,从入门到精通, ReactJS - 架构(6)

ReactJS - 架构 React应用的架构 React的架构就像一个井然有序的厨房,每个工具都有其特定的位置和用途。在其核心,React遵循一个基于组件的架构,这意味着我们使用可重用的组件构建应用程序。 组件:构建块 可以把组件想象成乐高积木…

Bias / variance and neural networks|偏差/方差和神经网络

----------------------------------------------------------------------------------------------- 这是我在我的网站中截取的文章,有更多的文章欢迎来访问我自己的博客网站rn.berlinlian.cn,这里还有很多有关计算机的知识,欢迎进行留言或…

Linux HMM(Heterogeneous Memory Management)的应用

原理篇见【https://blog.csdn.net/shenjunpeng/article/details/150931847?spm1011.2415.3001.5331】 1. HMM 的优势与挑战 1.1 优势 统一虚拟地址空间:简化异构计算平台的数据共享和访问。 高效页表同步:支持设备端的 page fault 和页表同步&#x…

鸿蒙创新赛活动——Mac提交压缩失败后续

Mac提交压缩失败后续来了… 传送带【上一篇】 背景 华为2025HarmonyOS创新赛 上传作品的时候,遇到了一个提示 ZIP包中的Office文件含有嵌入文件,就去这个Office文件找,怎么也找不到嵌入的文件。 解决方法1 上次推荐的解决方式是&#xff0…

Ubuntu操作系统下使用mysql、mongodb、redis

目录 一、核心步骤概览 二. MySQL (下面以其他用户为例) 1,、安装 2、管理服务 3、连接与使用 4、配置文件位置 5、下面来演示一下安装好之后如何在Linux操作系统中远程登录和window互连Linux 远程登录 window连Linux(连不上的&…

springboot java开发的rocketmq 顺序消息保证

首先要明确一个关键点:RocketMQ 保证的是一种局部顺序(Partially Ordered)​,而非全局顺序(Globally Ordered)。这意味着消息的顺序性只在某个特定维度(比如同一个订单ID)下保证&…

【机器学习】 14 Kernels

本章目录 14 Kernels 479 14.1 Introduction 479 14.2 Kernel functions 479 14.2.1 RBF kernels 480 14.2.2 Kernels for comparing documents 480 14.2.3 Mercer (positive definite) kernels 481 14.2.4 Linear kernels 482 14.2.5 Matern kernels 482 14.2.6 String kerne…

Android开发-工程结构

一、项目视图模式在开始之前,确保你的 Project 面板使用的是 【Android】 视图(默认)。这是最常用的视图,它将相关文件按功能逻辑分组展示。💡 你也可以切换到 【Project】 视图查看完整的文件系统结构。二、顶级项目结…

mysql的内置函数

文章目录mysql的内置函数时间函数1. 返回值的数据类型和格式2. 功能侧重点3. 函数别名情况我现在想给一个日期加上十天,然后输出加上十天之后的日期,我该怎么做?我现在想给一个日期减去两天,然后输出减去两天之后的日期&#xff0…

【动态规划】子序列问题

一、[最长递增子序列](https://leetcode.cn/problems/longest-increasing-subsequence/description/)二、[摆动序列](https://leetcode.cn/problems/wiggle-subsequence/description/)三、[最长递增子序列的个数](https://leetcode.cn/problems/number-of-longest-increasing-s…

P2P技术应用:去中心化

P2P技术应用:https://www.bilibili.com/video/BV1WH4y1Y7i9 P2P与下载器 P2P技术实现的下载协议: 1、种子文件 2、磁力 3、电骡 播放器: 快车、电骡、迅雷 BT(种子)下载的基本技术原理 网盘与P2P技术 网盘公司的主…

数据结构(C语言篇):(八)栈

目录 前言 一、概念与结构 二、栈的实现 2.1 头文件的准备 2.2 函数的实现 2.2.1 STInit( )函数(初始化) 2.2.2 STDestroy( )函数(销毁) 2.2.3 STPush( )函数(入栈) 2.2.4 STPop( )函数&#…

Elasticsearch数据迁移快照方案初探(一):多节点集群配置踩坑记

背景介绍 在生产环境中,我们经常需要将测试环境的Elasticsearch索引数据迁移到生产环境。这次我们遇到了一个典型的多节点集群快照配置问题:需要为所有节点添加path.repo配置,但过程中遇到了各种挑战。 问题描述 我们的Elasticsearch集群包含…

leedcode 算法刷题第二十天

39. 组合总和 class Solution { public:vector<vector<int>> result;vector<int> temp;void backtructing(vector<int>& candidates, int target, int sum,int start){if(sumtarget){result.push_back(temp);return;}if(sum>target){return;}f…

身份证实名认证API集成—身份核验接口-网络平台安全合规

在数字化浪潮席卷各行各业的今天&#xff0c;网络空间的安全问题日益受到关注。为防范网络诈骗、虚假注册、身份盗用等风险&#xff0c;国家陆续出台多项法律法规&#xff0c;如《网络安全法》《个人信息保护法》等&#xff0c;明确要求互联网服务提供者落实用户真实身份核验机…

谷歌TIGER爆火!生成式召回颠覆推荐系统:用语义ID破解冷启动+多样性难题,3大数据集性能碾压传统模型

注&#xff1a;此文章内容均节选自充电了么创始人&#xff0c;CEO兼CTO陈敬雷老师的新书《GPT多模态大模型与AI Agent智能体》&#xff08;跟我一起学人工智能&#xff09;【陈敬雷编著】【清华大学出版社】 清华《GPT多模态大模型与AI Agent智能体》书籍配套视频课程【陈敬雷…