信号处理学习——文献精读与code复现之TFN——嵌入时频变换的可解释神经网络(下)

书接上文:

信号处理学习——文献精读与code复现之TFN——嵌入时频变换的可解释神经网络(上)-CSDN博客

接下来是重要的代码复现!!!GitHub - ChenQian0618/TFN: this is the open code of paper entitled "TFN: An Interpretable Neural Network With Time Frequency Transform Embedded for Intelligent Fault Diagnosis".



一. 准备工作

因为我的论文中所使用的数据的样本量是2048,而不是TFN文献中的1024,所以有些地方需要调整一下。

先看TFN-main\Models\BackboneCNN.py中的代码,查看是否需要调整。(不需要)

再看TFN-main\Models\TFconvlayer.py中的代码,也同样的(不需要)。

具体可去github上查看作者大大们的完整代码。

模块是否依赖输入长度?原因
TFconv_* 类中的 forward❌ 不依赖它接受的输入是 [B, C, L],不限制 L(你的2048没问题)
weightforward()❌ 不依赖只与 kernel_sizesuperparams 有关,与你输入的信号长度无关
AdaptiveMaxPool1d in CNN❌ 不依赖输入长度自动调整为固定输出维度,兼容任意长度
T = torch.arange(...)❌ 但与 kernel_size 相关这个是内部卷积核构造,与输入数据 2048 无关

二. 设置

参考文献中的部分设置,分类损失函数采用交叉熵,训练优化器选用Adam,动量参数设为0.9,初始学习率为0.001,总训练周期为50次。总共重复10次平均实验。

三.Code

1.数据划分

这里用到的数据集是比CWRU要稍微难识别一些的变速轴承信号数据集,加拿大渥太华数据集。

import scipy.io as sio
import numpy as np
import random
import os# 定义基础路径
base_path = 'D:/0A_Gotoyourdream/00_BOSS_WHQ/A_Code/A_Data/'# 定义各类别对应的mat文件
file_mapping = {'H': 'H-B-1.mat','I': 'I-B-1.mat','B': 'B-B-1.mat','O': 'O-B-2.mat','C': 'C-B-2.mat'
}# 定义每个类别需要抽取的数量
sample_limit = {'H': 200,'I': 200,'B': 200,'O': 200,'C': 200
}# 保存最终数据
X_list = []
y_list = []# 固定参数
fs = 200000
window_size = 2048
step_size = int(fs * 0.015)  # 步长 0.015秒# 类别编码
#label_mapping = {'H': 0, 'I': 1, 'B': 3, 'O': 2, 'C': 4}  # 注意和你之前保持一致label_mapping = {'H': 0, 'I': 1, 'O': 2, 'B': 3, 'C': 4}
inverse_mapping = {v: k for k, v in label_mapping.items()}
labels = [inverse_mapping[i] for i in range(len(inverse_mapping))]
# 再替换这些缩写为全名
label_fullnames = {'H': 'Health','I': 'F_Inner','O': 'F_Outer','B': 'F_Ball','C': 'F_Combined'
}
labels = [label_fullnames[c] for c in labels]# 创建保存目录(可选)
output_dir = os.path.join(base_path, "ClassBD-Processed_Samples")
os.makedirs(output_dir, exist_ok=True)# 遍历每一类数据
for label_name, file_name in file_mapping.items():print(f"正在处理类别 {label_name}...")mat_path = os.path.join(base_path, file_name)dataset = sio.loadmat(mat_path)# 提取振动信号并去直流分量vib_data = np.array(dataset["Channel_1"].flatten().tolist()[:fs * 10])vib_data = vib_data - np.mean(vib_data)# 滑窗切分样本vib_samples = []start = 0while start + window_size <= len(vib_data):sample = vib_data[start:start + window_size].astype(np.float32)  # 降低内存占用vib_samples.append(sample)start += step_sizevib_samples = np.array(vib_samples)print(f"共切分得到 {vib_samples.shape[0]} 个样本")# 抽样if vib_samples.shape[0] < sample_limit[label_name]:raise ValueError(f"类别 {label_name} 样本不足(仅 {vib_samples.shape[0]}),无法抽取 {sample_limit[label_name]} 个")selected_indices = random.sample(range(vib_samples.shape[0]), sample_limit[label_name])selected_X = vib_samples[selected_indices]selected_y = np.full(sample_limit[label_name], label_mapping[label_name], dtype=np.int64)# 保存save_path_X = os.path.join(output_dir, f"X_{label_name}.mat")save_path_y = os.path.join(output_dir, f"y_{label_name}.mat")sio.savemat(save_path_X, {'X': selected_X})sio.savemat(save_path_y, {'y': selected_y})print(f"已保存类别 {label_name} 的数据:{save_path_X}, {save_path_y}")

2. 存储为dataloder

import os
import scipy.io as sio
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader# ========== 1. 读取四类数据 ==========
base_path = "D:/0A_Gotoyourdream/00_BOSS_WHQ/A_Code/A_Data/ClassBD-Processed_Samples"def load_data(label):X = sio.loadmat(os.path.join(base_path, f"X_{label}.mat"))["X"]y = sio.loadmat(os.path.join(base_path, f"y_{label}.mat"))["y"].flatten()return X.astype(np.float32), y.astype(np.int64)X_H, y_H = load_data("H")
X_I, y_I = load_data("I")
X_B, y_B = load_data("B")
X_O, y_O = load_data("O")
X_C, y_C = load_data("C")# ========== 2. 合并数据 + reshape ==========
X_all = np.concatenate([X_H, X_I, X_B, X_O, X_C], axis=0)
y_all = np.concatenate([y_H, y_I, y_B, y_O, y_C], axis=0)
X_all = X_all[:, np.newaxis, :]  # (N, 1, 200000)# ========== 3. 划分训练/测试集 ==========
X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=0.4, stratify=y_all, random_state=42)# ========== 4. DataLoader ==========
train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

3. 定义模型及一些设置

需要注意的部分在代码的注释中有写

from Models.TFN import TFN_STTF  # 你也可以换成 TFN_Chirplet、TFN_Morlet
model = TFN_STTF(in_channels=1, out_channels=5, kernel_size=15)  # out_channels = 类别数device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

4. 训练及测试

# 开始训练
for epoch in range(1, 51):model.train()running_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# TFN模型支持返回多个输出(output, _, _)outputs = model(inputs)if isinstance(outputs, tuple):outputs = outputs[0]loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()#scheduler.step()train_acc = correct / total * 100# 测试集评估model.eval()correct_test = 0total_test = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)if isinstance(outputs, tuple):outputs = outputs[0]_, predicted = outputs.max(1)total_test += labels.size(0)correct_test += predicted.eq(labels).sum().item()test_acc = correct_test / total_test * 100print(f"Epoch {epoch:03d}: Loss={running_loss:.4f}, Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

四. 结果

代码复现成功!!!

接着后面就是拿来做对比实验啦~~~

(感恩大佬们提供github代码!!!)

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/89398.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/89398.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

线上故障排查:签单合同提交报错分析-对接e签宝

在企业管理系统中&#xff0c;合同生成与签署环节至关重要&#xff0c;尤其是在使用第三方平台进行电子签署时。本文将通过实际的报错信息&#xff0c;分析如何进行线上故障排查&#xff0c;解决合同生成过程中出现的问题。 #### 1. 错误描述 在尝试生成合同并提交至电子签署…

知攻善防靶机 Linux easy溯源

知攻善防 【护网训练-Linux】应急响应靶场-Easy溯源 小张是个刚入门的程序猿&#xff0c;在公司开发产品的时候突然被叫去应急&#xff0c;小张心想"早知道简历上不写会应急了"&#xff0c;于是call了运维小王的电话&#xff0c;小王说"你面试的时候不是说会应急…

原神八分屏角色展示页面(纯前端html,学习交流)

原神八分屏角色展示页面 - 一个精美的前端交互项目 项目简介 这是一个基于原神游戏角色制作的八分屏展示页面&#xff0c;采用纯前端技术实现&#xff0c;包含了丰富的动画效果、音频交互和视觉设计。项目展示了一些热门原神角色&#xff0c;每个角色都有独立的介绍页面和专属…

华为认证二选一:物联网 VS 人工智能,你的赛道在哪里?

一篇不讲情怀只讲干货的科普指南 一、华为物联网 & 人工智能到底在搞什么&#xff1f; 华为物联网&#xff08;IoT&#xff09; 的核心是 “万物互联”。 通过传感器、通信技术&#xff08;如NB-IoT/5G&#xff09;、云计算平台&#xff08;如OceanConnect&#xff09;&…

CloudLens for PolarDB:解锁数据库性能优化与智能运维的终极指南

随着企业数据规模的爆炸式增长,数据库性能管理已成为技术团队的关键挑战。本文深入探讨如何利用CloudLens for PolarDB实现高级监控、智能诊断和自动化运维,帮助您构建一个自我修复、高效运行的数据库环境。 引言:数据库监控的演进 在云原生时代,传统的数据库监控方式已不…

MySQL中TINYINT/INT/BIGINT的典型应用场景及实例

以下是MySQL中TINYINT/INT/BIGINT的典型应用场景及实例说明&#xff1a; 一、TINYINT&#xff08;1字节&#xff09; 1.状态标识 -- 用户激活状态&#xff08;0未激活/1已激活&#xff09; ALTER TABLE users ADD is_active TINYINT(1) DEFAULT 0; 适用于布尔值存储和状态码…

YOLOv13:最新的YOLO目标检测算法

[2506.17733] YOLOv13: Real-Time Object Detection with Hypergraph-Enhanced Adaptive Visual Perception Github: https://github.com/iMoonLab/yolov13 YOLOv13&#xff1a;利用超图增强型自适应视觉感知进行实时物体检测 主要的创新点提出了HyperACE机制、FullPAD范式、轻…

【深入浅出:计算流体力学(CFD)基础与核心原理--从NS方程到工业仿真实践】

关键词&#xff1a;#CFD、#Navier-Stokes方程、#有限体积法、#湍流模型、#网格收敛性、#工业仿真验证 一、CFD是什么&#xff1f;为何重要&#xff1f; 计算流体力学&#xff08;Computational Fluid Dynamics, CFD&#xff09; 是通过数值方法求解流体流动控制方程&#xff0…

qt常用控件--04

文章目录 qt常用控件labelLCD NumberProgressBar结语 很高兴和大家见面&#xff0c;给生活加点impetus&#xff01;&#xff01;开启今天的编程之路&#xff01;&#xff01; 今天我们进一步c11中常见的新增表达 作者&#xff1a;٩( ‘ω’ )و260 我的专栏&#xff1a;qt&am…

Redmine:一款基于Web的开源项目管理软件

Redmine 是一款基于 Ruby on Rails 框架开发的开源、跨平台、基于 Web 的项目管理、问题跟踪和文档协作软件。 Redmine 官方网站自身就是基于它构建的一个 Web 应用。 功能特性 Redmine 的主要特点和功能包括&#xff1a; 多项目管理&#xff1a; Redmine 可以同时管理多个项…

FPGA FMC 接口

1 FMC 介绍 FMC 接口即 FPGA Mezzanine Card 接口,中文名为 FPGA 中间层板卡接口。以下是对它的详细介绍: 标准起源:2008 年 7 月,美国国家标准协会(ANSI)批准和发布了 VITA 57 FMC 标准。该标准由从 FPGA 供应商到最终用户的公司联盟开发,旨在为位于基板(载卡)上的 …

C++中std::atomic_bool详解和实战示例

std::atomic_bool 是 C 标准库中提供的一种 原子类型&#xff0c;用于在多线程环境下对布尔值进行 线程安全的读写操作&#xff0c;避免使用 std::mutex 带来的性能开销。 1. 基本作用 在多线程环境中&#xff0c;多个线程同时访问一个 bool 类型变量可能会出现 竞态条件&…

深度学习之分类手写数字的网络

面临的问题 定义神经⽹络后&#xff0c;我们回到⼿写识别上来。我们可以把识别⼿写数字问题分成两个⼦问题&#xff1a; 把包含许多数字的图像分成⼀系列单独的图像&#xff0c;每个包含单个数字&#xff1b; 也就是把图像 &#xff0c;分成6个单独的图像 分类单独的数字 我们将…

nginx基本使用 linux(mac下的)

目录结构 编译后会有&#xff1a;conf html logs sbin 四个文件 &#xff08;其他两个是之前下载的安装包&#xff09; conf&#xff1a;配置文件html&#xff1a;页面资源logs&#xff1a;日志sbin&#xff1a;启动文件&#xff0c;nginx主程序 运行后多了文件&#xff1a;&l…

基于大众点评的重庆火锅在线评论数据挖掘分析(情感分析、主题分析、EDA探索性数据分析)

文章目录 有需要本项目的代码或文档以及全部资源&#xff0c;或者部署调试可以私信博主项目介绍数据采集数据预处理EDA探索性数据分析关键词提取算法情感分析LDA主题分析总结每文一语 有需要本项目的代码或文档以及全部资源&#xff0c;或者部署调试可以私信博主 项目介绍 本…

鸿蒙系统(HarmonyOS)应用开发之经典蓝色风格登录页布局、图文验证码

一、项目概述 本项目是一款基于鸿蒙 ArkTS&#xff08;ETS&#xff09;开发的用户登录页面&#xff0c;集成了图文验证码功能&#xff0c;旨在为应用提供安全、便捷的用户身份验证入口。项目采用现代化 UI 设计&#xff0c;兼顾用户体验与安全性&#xff0c;适用于多种需要用户…

0.96寸OLED显示屏 江协科技学习笔记(36个知识点)

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 32 33 34 35 36

Flutter SnackBar 控件详细介绍

文章目录 Flutter SnackBar 控件详细介绍基本特性基本用法1. 显示简单 SnackBar2. 自定义持续时间 主要属性高级用法1. 带操作的 SnackBar2. 自定义样式3. 浮动式 SnackBar SnackBarAction 属性实际应用场景注意事项完整示例建议 Flutter SnackBar 控件详细介绍 SnackBar 是 F…

【C++】头文件的能力与禁忌

在C中&#xff0c;​头文件&#xff08;.h/.hpp&#xff09;​​ 的主要作用是声明接口和共享代码&#xff0c;但如果不规范使用&#xff0c;会导致编译或链接错误。以下是详细总结&#xff1a; 一、头文件中可以做的事情 1.1 声明 函数声明&#xff08;无需inline&#xff…

腾讯 iOA 零信任产品:安全远程访问的革新者

在当今数字化时代&#xff0c;企业面临着前所未有的挑战与机遇。随着远程办公、多分支运营以及云计算的广泛应用&#xff0c;传统的网络安全架构逐渐暴露出诸多不足。腾讯 iOA 零信任产品凭借其创新的安全理念和强大的功能特性&#xff0c;为企业提供了一种全新的解决方案&…