深度学习与遥感入门(六)|轻量化 MobileNetV2 高光谱分类

系列回顾:
(一)CNN 基础:高光谱图像分类可视化全流程
(二)HybridNet(CNN+Transformer):提升全局感受野
(三)GCN 入门实战:基于光谱 KNN 的图卷积分类与全图预测
(四)空间–光谱联合构图的 GCN:RBF 边权 + 自环 + 早停,得到更稳更自然的全图分类结果
(五)GAT & 构图消融 + 分块全图预测:更稳更快的高光谱图分类(PyTorch Geometric 实战)
合集链接:https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MzkwMTE0MjI4NQ==&action=getalbum&album_id=4007114522736459789#wechat_redirect
本篇(六)聚焦“数据泄露”,采用仅训练集像素拟合 StandardScaler+PCA,并在全图预测中共享同一变换空间;模型选用轻量化 MobileNetV2深度可分离卷积,在显存友好的坐标批推理下实现全图预测

0. 前言:PCA 与高光谱分类中的“数据泄露”

  • 什么是泄露? 训练阶段直接/间接使用了测试数据统计信息(均值、方差、主成分方向等)。
  • 怎么产生? 在整图上 fit 标准化与 PCA,然后再切训练/测试或直接做全图分类。
  • 为什么常见? 历史习惯、样本少时稳定性考虑、对比研究图省事、实现方便。
  • 影响大吗? 小数据集上常为 0.1%~1% 的 OA 差异;但在类分布差异大训练样本极少时,差距可达数个百分点。真实部署场景绝不允许整图 fit

本文做法先分层抽样得到训练/测试索引仅用训练像素 fit 标准化与 PCA用该变换对整图 transform训练与预测均在同一(训练集拟合得到的)特征空间,从源头避免泄露。

1. 任务要点

  1. 严格无泄露预处理:只用训练像素拟合 StandardScaler+PCA;全图在同一变换空间中变换。
  2. 轻量模型:用 MobileNetV2 的深度可分离卷积(3 段)+ GAP + FC。
  3. 全图预测显存友好:按坐标批收集 patch → 堆成 batch → 前向推理。
  4. 评估classification_report、混淆矩阵、OA;可视化支持 Windows 阻塞显示。

2. 方法详解

2.1 严格无泄露的 PCA 流程

  • 先划分后拟合:对有标签像素做分层抽样得到训练/测试索引;仅训练像素拟合 StandardScalerPCA
  • 全图共享空间:将整图 (H×W×Bands) 用训练集拟合的变换进行标准化与降维,得到 (H×W×PCA_DIM)
  • 提取 patch:在 PCA 空间内按坐标提取 (PATCH_SIZE×PATCH_SIZE×PCA_DIM) 的 patch 作为输入。

这样做的关键测试像素从未参与统计,评估更可信。

2.2 轻量化 MobileNetV2(HSI 版)

  • Depthwise Separable Conv:逐通道 3×3 深度卷积 + 1×1 点卷积,大幅降参与算力需求。
  • 网络骨干:3 段深度可分离卷积 → GAP(自适应全局平均池化)→ 全连接输出。
  • 输入通道:这里输入为 PCA 后的通道数(例如 30),以二维 patch 形式输入(C×H×W)。

2.3 全图预测策略(坐标批)

  • 坐标遍历:生成所有像素坐标。
  • 反射填充:边界像素也能提取完整 patch。
  • 批量收集:按 batch_size 组装 patch → 前向 → 填回 pred_map
  • 显存稳定:避免一次性张量过大导致溢出。

3. 代码逐段 + 解释

下面先按逻辑分段展示与解释;最末提供“一键可跑脚本(整合版)”,复制后仅需修改数据路径即可运行。

3.1 全局与可视化设置

import os, time, numpy as np, scipy.io as sio
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
import matplotlib# Windows 下 TkAgg 更稳;Linux/服务器用 Agg(无显示)
if os.name == 'nt':matplotlib.use('TkAgg')
else:matplotlib.use('Agg')import matplotlib.pyplot as plt
import seaborn as snsmatplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 120
sns.set_theme(context="notebook", style="whitegrid", font="SimHei")torch.backends.cudnn.benchmark = True

3.2 随机种子

def set_seeds(seed=42):import randomrandom.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)

固定随机性,保证复现。

3.3 轻量化网络

class DepthwiseSeparableConv(nn.Module):def __init__(self, in_ch, out_ch, stride=1):super().__init__()self.depthwise = nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False)self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)self.bn = nn.BatchNorm2d(out_ch)self.act = nn.ReLU6(inplace=True)def forward(self, x):x = self.depthwise(x)x = self.pointwise(x)x = self.bn(x)return self.act(x)class MobileNetV2_HSI(nn.Module):def __init__(self, in_ch, num_classes, width_mult=1.0):super().__init__()c1, c2, c3 = int(32 * width_mult), int(64 * width_mult), int(128 * width_mult)self.layer1 = DepthwiseSeparableConv(in_ch, c1)self.layer2 = DepthwiseSeparableConv(c1, c2)self.layer3 = DepthwiseSeparableConv(c2, c3)self.gap = nn.AdaptiveAvgPool2d(1)self.fc = nn.Linear(c3, num_classes)def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.gap(x).flatten(1)return self.fc(x)

结构极简但实用:3 段深度可分离卷积 + GAP + FC,对 HSI 的小样本/低算力场景友好。

3.4 数据集与 Patch 封装

class HSIPatchDataset(Dataset):def __init__(self, patches, labels):# patches: (N, H, W, C) → 张量 (N, C, H, W)self.X = torch.tensor(patches, dtype=torch.float32).permute(0, 3, 1, 2)self.y = torch.tensor(labels, dtype=torch.long)def __len__(self): return len(self.y)def __getitem__(self, idx): return self.X[idx], self.y[idx]

3.5 全图预测(坐标批推理)

@torch.inference_mode()
def predict_full_image_by_coords(model, X_img_pca, patch_size, device,batch_size=2048, title="全图预测(坐标批推理)"):model.eval()H, W, C = X_img_pca.shapem = patch_size // 2padded = np.pad(X_img_pca, ((m, m), (m, m), (0, 0)), mode='reflect')coords = np.mgrid[0:H, 0:W].reshape(2, -1).Tpred_map = np.zeros((H, W), dtype=np.int32)t0 = time.time()for i in range(0, len(coords), batch_size):batch_coords = coords[i:i + batch_size]patches = np.empty((len(batch_coords), patch_size, patch_size, C), dtype=np.float32)for k, (r, c) in enumerate(batch_coords):patches[k] = padded[r:r + patch_size, c:c + patch_size, :]tensor = torch.from_numpy(patches).permute(0, 3, 1, 2).to(device)preds = model(tensor).argmax(dim=1).cpu().numpy() + 1  # +1 便于和 GT 对齐for (r, c), p in zip(batch_coords, preds):pred_map[r, c] = pprint(f"全图预测耗时:{time.time() - t0:.2f} 秒")# 可视化(阻塞显示,避免“最后的图没有显示”)try:plt.figure(figsize=(10, 7.5))cmap = matplotlib.colormaps.get_cmap('tab20')vmin, vmax = pred_map.min(), pred_map.max()if vmin == vmax: vmin, vmax = 0, 1im = plt.imshow(pred_map, cmap=cmap, interpolation='nearest', vmin=vmin, vmax=vmax)cbar = plt.colorbar(im, shrink=0.85); cbar.set_label('预测类别', rotation=90)plt.title(title, fontsize=14, weight='bold'); plt.axis('off'); plt.tight_layout()print("尝试显示全图预测结果...")plt.show(block=True)except Exception as e:print(f"显示全图预测结果时出错: {e}")try:plt.savefig("prediction_map.png", bbox_inches='tight')print("已保存为 prediction_map.png")except Exception as se:print(f"保存失败: {se}")return pred_map

3.6 主流程(数据→划分→无泄露预处理→训练→评估→全图预测)

下面是主函数的关键片段(末尾附完整可运行脚本):

def main():set_seeds(42)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f"使用设备: {device}")# ---- 路径与超参(按需修改)----DATA_DIR = r"your_path"X_FILE, Y_FILE = "KSC.mat", "KSC_gt.mat"PCA_DIM, PATCH_SIZE, TRAIN_RATIO = 30, 5, 0.30BATCH_SIZE, EPOCHS, LR, WEIGHT_DECAY = 64, 30, 1e-3, 1e-4NUM_WORKERS = 0 if os.name == 'nt' else min(4, os.cpu_count() or 0)PIN_MEMORY = (device.type == 'cuda')PREDICT_BATCH_SIZE = 4096# ---- 读取数据 ----def load_data():X = sio.loadmat(os.path.join(DATA_DIR, X_FILE))Y = sio.loadmat(os.path.join(DATA_DIR, Y_FILE))x_key = [k for k in X.keys() if not k.startswith("__")][0]y_key = [k for k in Y.keys() if not k.startswith("__")][0]return X[x_key], Y[y_key]X_img, Y_img = load_data()h, w, bands = X_img.shapeprint(f"数据尺寸: {h}×{w}, 波段: {bands}")# ---- 有标签索引 + 分层划分 ----labeled_idx_rc = np.array([(i, j) for i in range(h) for j in range(w) if Y_img[i, j] != 0])labels_all = np.array([Y_img[i, j] - 1 for i, j in labeled_idx_rc], dtype=np.int64)num_classes = len(np.unique(labels_all))print(f"有标签样本: {len(labeled_idx_rc)},类别数: {num_classes}")train_ids, test_ids = train_test_split(np.arange(len(labeled_idx_rc)),test_size=1 - TRAIN_RATIO, stratify=labels_all, random_state=42)# ---- 仅训练像素拟合 Scaler+PCA(无泄露)----print("拟合 StandardScaler/PCA(仅训练像素)...")train_pixels = np.array([X_img[i, j] for i, j in labeled_idx_rc[train_ids]], dtype=np.float32)scaler = StandardScaler().fit(train_pixels)pca = PCA(n_components=PCA_DIM, random_state=42).fit(scaler.transform(train_pixels))# 整图进入同一空间(float32)X_pca_img = pca.transform(scaler.transform(X_img.reshape(-1, bands).astype(np.float32))).astype(np.float32)X_pca_img = X_pca_img.reshape(h, w, PCA_DIM)# ---- 提取训练/测试 patch ----def extract_patches(sel_ids):m = PATCH_SIZE // 2padded = np.pad(X_pca_img, ((m, m), (m, m), (0, 0)), mode='reflect')patches = np.empty((len(sel_ids), PATCH_SIZE, PATCH_SIZE, PCA_DIM), dtype=np.float32)labs = np.empty((len(sel_ids),), dtype=np.int64)for n, k in enumerate(sel_ids):i, j = labeled_idx_rc[k]patches[n] = padded[i:i + PATCH_SIZE, j:j + PATCH_SIZE, :]labs[n] = labels_all[k]return patches, labsX_train, y_train = extract_patches(train_ids)X_test, y_test = extract_patches(test_ids)# ---- DataLoader ----train_loader = DataLoader(HSIPatchDataset(X_train, y_train), batch_size=BATCH_SIZE,shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)test_loader = DataLoader(HSIPatchDataset(X_test, y_test), batch_size=BATCH_SIZE,shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)# ---- 模型与优化器 ----model = MobileNetV2_HSI(PCA_DIM, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)criterion = nn.CrossEntropyLoss()scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)# ---- 评估函数 ----@torch.no_grad()def evaluate(loader):model.eval()all_y, all_pred = [], []for xb, yb in loader:xb = xb.to(device)pred = model(xb).argmax(dim=1).cpu().numpy()all_pred.extend(pred); all_y.extend(yb.numpy())return accuracy_score(all_y, all_pred), np.array(all_y), np.array(all_pred)# ---- 训练循环 ----print("开始训练...")best_acc, model_path = 0.0, "best_mnv2_hsi.pth"for epoch in range(1, EPOCHS + 1):model.train(); total_loss = 0.0for xb, yb in train_loader:xb, yb = xb.to(device), yb.to(device)optimizer.zero_grad()loss = criterion(model(xb), yb)loss.backward(); optimizer.step()total_loss += loss.item() * xb.size(0)test_acc, _, _ = evaluate(test_loader)scheduler.step(test_acc)print(f"Epoch {epoch:02d}/{EPOCHS} | 损失: {total_loss/len(train_loader.dataset):.4f} | 测试准确率: {test_acc:.4f}")if test_acc > best_acc:best_acc = test_acctorch.save(model.state_dict(), model_path)print(f"训练完成,最佳测试准确率:{best_acc:.4f}")# ---- 安全加载最佳权重 ----try:state = torch.load(model_path, map_location=device, weights_only=True)except TypeError:state = torch.load(model_path, map_location=device)model.load_state_dict(state)# ---- 测试报告 & 混淆矩阵 ----test_acc, y_true, y_pred = evaluate(test_loader)print("\n测试集分类报告:")print(classification_report(y_true, y_pred, digits=4, zero_division=0))plt.figure(figsize=(10, 7))class_names = [f"类{i + 1}" for i in range(num_classes)]sns.heatmap(confusion_matrix(y_true, y_pred),annot=True, fmt='d', cmap="Blues",xticklabels=class_names, yticklabels=class_names,cbar=False, square=True)plt.xlabel("预测标签"); plt.ylabel("真实标签")plt.title("MobileNetV2 测试集混淆矩阵", fontsize=14, weight='bold')plt.tight_layout(); plt.show(block=True)# ---- 全图预测 ----print("全图预测中(坐标→收集 patch→堆成 batch→前向)...")pred_map = predict_full_image_by_coords(model, X_pca_img, patch_size=PATCH_SIZE, device=device,batch_size=PREDICT_BATCH_SIZE, title="MobileNetV2 全图预测(坐标批推理)")print(f"预测图统计: min={pred_map.min()}, max={pred_map.max()}, mean={pred_map.mean():.3f}")print("完成。")

3.7 Windows 入口保护(多进程/显示更稳)

if __name__ == "__main__":try:import multiprocessing as mpmp.set_start_method("spawn", force=True)mp.freeze_support()except Exception:passmain()

4. 结果展示

在这里插入图片描述
在这里插入图片描述
欢迎大家关注下方我的公众获取更多内容!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/93193.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/93193.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第4节 神经网络从公式简化到卷积神经网络(CNN)的进化之路

🧙 深度学习的"玄学进化史" 从CNN用卷积层池化层处理图片,循环网络RNN如何利用上下文处理序列数据,到注意力机制让Transformer横空出世,现在的大语言模型已经能写能画能决策!每个新技巧都让人惊呼"还能这么玩",难怪说深度学习像玄学——但这玄学,…

最新去水印小程序系统 前端+后端全套源码 多套模版 免授权(源码下载)

最新去水印小程序系统 前端后端全套源码 多套模版 免授权 源码下载:https://download.csdn.net/download/m0_66047725/91669468 更多资源下载:关注我

TCP Socket 编程实战:实现简易英译汉服务

前言:TCP(传输控制协议)是一种面向连接、可靠的流式传输协议,与 UDP 的无连接特性不同,它通过三次握手建立连接、四次挥手断开连接,提供数据确认、重传机制,保证数据有序且完整传输。本文将基于…

CF566C Logistical Questions Solution

Description 给定一棵 nnn 个点的树 TTT,点有点权 aia_iai​,边有边权 www. 定义 dist⁡(u,v)\operatorname{dist}(u,v)dist(u,v) 为 u→vu\to vu→v 的简单路径上的边权和. 找到一个节点 uuu,使得 W∑i1ndist⁡(u,i)32aiW\sum\limits_{i1}^n…

聊天室全栈开发-保姆级教程(Node.js+Websocket+Redis+HTML+CSS)

前言 最近在学习websocket全双工通信,想要做一个联机小游戏,做游戏之前先做一个聊天室练练手。 跟着本篇博客,可以从0搭建一个属于你自己的聊天室。 准备阶段 什么人适合学习本篇文章? 答:前端开发者,有一…

后台管理系统-2-vue3之路由配置和Main组件的初步搭建布局

文章目录1 路由搭建1.1 路由创建(router/index.js)1.2 路由组件(views/Main.vue)1.3 路由引入并注册(main.js)1.4 路由渲染(App.vue)2 element-plus的应用2.1 完整引入并注册(main.js)2.2 示例应用(App.vue)3 ElementPlusIconsVue的应用3.1 图标引入并注册(main.js)3.2 示例应用…

使用 Let’s Encrypt 免费申请泛域名 SSL 证书,并实现自动续期

使用 Let’s Encrypt 免费申请泛域名 SSL 证书,并实现自动续期 目录 使用 Let’s Encrypt 免费申请泛域名 SSL 证书,并实现自动续期 🛠️ 环境准备💡 什么是 Let’s Encrypt?🧠 Let’s Encrypt 证书颁发原…

一键自动化:Kickstart无人值守安装指南

Kickstart文件实现自动安装1. Kickstart文件概述1.1 定义与作用Kickstart文件是Red Hat系Linux发行版(如RHEL、CentOS、Fedora)用于实现自动化安装的配置文件,采用纯文本格式保存。它通过预设安装参数的方式,使系统安装过程无需人…

深度解读 Browser-Use:让 AI 驱动浏览器自动化成为可能

目录 一、什么是 Browser-Use? 二、Browser-Use 的核心功能 1. AI 与浏览器的链接桥梁 2. 无代码 / 低代码操作界面 3. 支持多家 LLM 4. 开发体验简洁 可快速上手 三、核心价值与适用场景 四、与 Playwright 的结合使用 五、总结与展望 https://github.com…

React.memo、useMemo 和 React.PureComponent的区别

useMemo 和 React.memo 都是 React 提供的性能优化工具,但它们的作用和使用场景有显著不同。以下是两者的全面对比: 一、核心区别总结特性useMemoReact.memo类型React Hook高阶组件(HOC)作用对象缓存计算结果缓存组件渲染结果优化目标避免重复计算避免不…

Lumerical INTERCONNECT ------ CW Laser 和 OPWM 组成的系统

Lumerical INTERCONNECT ------ CW Laser 和 OPWM 组成的系统 引言 正文 引言 这里我们来简单介绍一下 CW Laser 与 OSA 组成的简单系统结构的仿真。 正文 我们构建一个如下图所示的仿真结构。 我们将 CWL 中的 power 设置为 1 W。 然后直接运行仿真查看结果如下: 虽然 …

想涨薪30%?别只盯着大厂了!转型AI产品经理的3个通用方法,人人都能学!

在AI产品经理刚成为互联网公司香饽饽的时候,刚做产品1年的月月就规划了自己的转型计划,然后用3个月时间成功更换赛道,转战AI产品经理,涨薪30%。 问及她有什么上岸秘诀?她也复盘总结了3个踩坑经验和正确路径&#xff0c…

基于Hadoop的全国农产品批发价格数据分析与可视化与价格预测研究

文章目录有需要本项目的代码或文档以及全部资源,或者部署调试可以私信博主项目介绍每文一语有需要本项目的代码或文档以及全部资源,或者部署调试可以私信博主 项目介绍 随着我国农业数字化进程的加快,农产品批发市场每天都会产生海量的价格…

STM32在使用DMA发送和接收时的模式区别

在STM32的DMA传输中,发送使用DMA_Mode_Normal而接收使用DMA_Mode_Circular的设计基于以下关键差异:1. ‌触发机制的本质区别‌‌发送方向(TX)‌:由USART的‌TXE标志(发送寄存器空)触发‌&#x…

【秋招笔试】2025.08.15饿了么秋招机考-第三题

📌 点击直达笔试专栏 👉《大厂笔试突围》 💻 春秋招笔试突围在线OJ 👉 笔试突围在线刷题 bishipass.com 03. A先生的商贸网络投资 问题描述 A先生是一位精明的商人,他计划在 n n n 个城市之间建立商贸网络。目前有 m m

Socket 套接字的学习--UDP

上次我们大概介绍了一些关于网络的基础知识,这次我们利用编程来深入学习一下一:套接字Socket1.1什么是Socketsocket API 是一层抽象的网络编程接口,适用于各种底层网络协议,如 IPv4、IPv6,. 然而, 各种网络协议的地址格式并不相同。1.2套接字的分类套接字…

AI - MCP 协议(一)

AI应用开发的高级特性——MCP模型上下文协议,打通AI与外部服务的边界。 ************************************************************************************************************** 一、需求分析 当你的AI具备了RAG的能力,具备了调用工具的…

在es中安装kibana

一 安装 1.1 验证访问https的连通性 # 测试 80 端口(HTTP) curl -I -m 5 http://目标IP:端口号 说明: -I:仅获取 HTTP 头部(Head 请求),不下载正文,减少数据传输。 -m 5&#x…

嵌入式开发学习———Linux环境下网络编程学习(二)

UDP服务器客户端搭建UDP服务器代码#include <stdio.h> #include <stdlib.h> #include <string.h> #include <sys/socket.h> #include <netinet/in.h>#define PORT 8080 #define BUFFER_SIZE 1024int main() {int sockfd;char buffer[BUFFER_SIZE…

UVa1465/LA4841 Searchlights

UVa12345 UVa1465/LA4841 Searchlights题目链接题意输入格式输出格式分析AC 代码题目链接 本题是2010年icpc亚洲区域赛杭州赛区的I题 题意 在一个 n 行 m 列&#xff08;n≤100&#xff0c;m≤10 000&#xff09;的网格中有一些探照灯&#xff0c;每个探照灯有一个最大亮度 k&…