【PyTorch】单对象分割项目

对象分割是在图像中找到目标对象的边界的过程。单目标分割的重点是自动勾勒出图像中一个目标对象的边界。对象边界通常由二进制掩码定义。

通过二进制掩码,可以在图像上覆盖轮廓以勾勒出对象边界。例如以下图片描绘了胎儿的超声图像、胎儿头部的二进制掩码以及覆盖在超声图像上的胎儿头部的图像分割:

目录

准备数据集

创建自定义数据集

划分数据集

创建数据加载器

搭建模型

定义损失函数

定义优化器

训练和评估模型


准备数据集

使用胎儿头围数据集Automated measurement of fetal head circumference,在怀孕期间,超声成像用于测量胎儿头围,监测胎儿的生长。数据集包含标准平面的二维(2D)超声图像。Automated measurement of fetal head circumferenceFor more information about this dataset go to: https://hc18.grand-challenge.org/https://zenodo.org/record/1322001#.XcX1jk9KhhE

import os
path2train="./data/training_set/"imgsList=[pp for pp in os.listdir(path2train) if "Annotation" not in pp]
anntsList=[pp for pp in os.listdir(path2train) if "Annotation" in pp]
print("number of images:", len(imgsList))
print("number of annotations:", len(anntsList))import numpy as np
np.random.seed(2024)
rndImgs=np.random.choice(imgsList,4)
rndImgsimport matplotlib.pylab as plt
from PIL import Image
from scipy import ndimage as ndi
from skimage.segmentation import mark_boundaries
from torchvision.transforms.functional import to_tensor, to_pil_image
import torchdef show_img_mask(img, mask):if torch.is_tensor(img):img=to_pil_image(img)mask=to_pil_image(mask)img_mask=mark_boundaries(np.array(img), np.array(mask),outline_color=(0,1,0),color=(0,1,0))plt.imshow(img_mask)
for fn in rndImgs:path2img = os.path.join(path2train, fn)path2annt= path2img.replace(".png", "_Annotation.png")img = Image.open(path2img)annt_edges = Image.open(path2annt)mask = ndi.binary_fill_holes(annt_edges)        plt.figure()plt.subplot(1, 3, 1) plt.imshow(img, cmap="gray")plt.subplot(1, 3, 2) plt.imshow(mask, cmap="gray")plt.subplot(1, 3, 3) show_img_mask(img, mask)

plt.figure()
plt.subplot(1, 3, 1) 
plt.imshow(img, cmap="gray")
plt.axis('off')plt.subplot(1, 3, 2) 
plt.imshow(mask, cmap="gray")
plt.axis('off')    plt.subplot(1, 3, 3) 
show_img_mask(img, mask)
plt.axis('off')

# conda install conda-forge/label/cf202003::albumentations
from albumentations import (HorizontalFlip,VerticalFlip,    Compose,Resize,
)h,w=128,192
transform_train = Compose([ Resize(h,w), HorizontalFlip(p=0.5), VerticalFlip(p=0.5), ])transform_val = Resize(h,w)

创建自定义数据集

from torch.utils.data import Dataset
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_imageclass fetal_dataset(Dataset):def __init__(self, path2data, transform=None):      imgsList=[pp for pp in os.listdir(path2data) if "Annotation" not in pp]anntsList=[pp for pp in os.listdir(path2train) if "Annotation" in pp]self.path2imgs = [os.path.join(path2data, fn) for fn in imgsList] self.path2annts= [p2i.replace(".png", "_Annotation.png") for p2i in self.path2imgs]self.transform = transformdef __len__(self):return len(self.path2imgs)def __getitem__(self, idx):path2img = self.path2imgs[idx]image = Image.open(path2img)path2annt = self.path2annts[idx]annt_edges = Image.open(path2annt)mask = ndi.binary_fill_holes(annt_edges)        image= np.array(image)mask=mask.astype("uint8")        if self.transform:augmented = self.transform(image=image, mask=mask)image = augmented['image']mask = augmented['mask']            image= to_tensor(image)            mask=255*to_tensor(mask)            return image, mask
fetal_ds1=fetal_dataset(path2train, transform=transform_train)
fetal_ds2=fetal_dataset(path2train, transform=transform_val)
img,mask=fetal_ds1[0]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))show_img_mask(img, mask)

划分数据集

按照8:2的比例划分训练数据集和验证数据集

from sklearn.model_selection import ShuffleSplitsss = ShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
indices=range(len(fetal_ds1))
for train_index, val_index in sss.split(indices):print(len(train_index))print("-"*10)print(len(val_index))

from torch.utils.data import Subsettrain_ds=Subset(fetal_ds1,train_index)
print(len(train_ds))
val_ds=Subset(fetal_ds2,val_index)
print(len(val_ds))

展示训练数据集示例图像 

plt.figure(figsize=(5,5))
for img,mask in train_ds:show_img_mask(img,mask)break

展示验证数据集示例图像 

plt.figure(figsize=(5,5))
for img,mask in val_ds:show_img_mask(img,mask)break

创建数据加载器

from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=16, shuffle=False) for img_b, mask_b in train_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)breakfor img_b, mask_b in val_dl:print(img_b.shape,img_b.dtype)print(mask_b.shape, mask_b.dtype)breaktorch.max(img_b)

搭建模型

基于编码器-解码器模型encoder–decoder model搭建分割任务模型

import torch.nn as nn
import torch.nn.functional as Fclass SegNet(nn.Module):def __init__(self, params):super(SegNet, self).__init__()C_in, H_in, W_in=params["input_shape"]init_f=params["initial_filters"] num_outputs=params["num_outputs"] self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3,stride=1,padding=1)self.conv2 = nn.Conv2d(init_f, 2*init_f, kernel_size=3,stride=1,padding=1)self.conv3 = nn.Conv2d(2*init_f, 4*init_f, kernel_size=3,padding=1)self.conv4 = nn.Conv2d(4*init_f, 8*init_f, kernel_size=3,padding=1)self.conv5 = nn.Conv2d(8*init_f, 16*init_f, kernel_size=3,padding=1)self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv_up1 = nn.Conv2d(16*init_f, 8*init_f, kernel_size=3,padding=1)self.conv_up2 = nn.Conv2d(8*init_f, 4*init_f, kernel_size=3,padding=1)self.conv_up3 = nn.Conv2d(4*init_f, 2*init_f, kernel_size=3,padding=1)self.conv_up4 = nn.Conv2d(2*init_f, init_f, kernel_size=3,padding=1)self.conv_out = nn.Conv2d(init_f, num_outputs , kernel_size=3,padding=1)    def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv3(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv4(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv5(x))x=self.upsample(x)x = F.relu(self.conv_up1(x))x=self.upsample(x)x = F.relu(self.conv_up2(x))x=self.upsample(x)x = F.relu(self.conv_up3(x))x=self.upsample(x)x = F.relu(self.conv_up4(x))x = self.conv_out(x)return x params_model={"input_shape": (1,h,w),"initial_filters": 16, "num_outputs": 1,}model = SegNet(params_model)import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=model.to(device)

打印模型结构

print(model)

获取模型摘要 

from torchsummary import summary
summary(model, input_size=(1, h, w))

定义损失函数

def dice_loss(pred, target, smooth = 1e-5):intersection = (pred * target).sum(dim=(2,3))union= pred.sum(dim=(2,3)) + target.sum(dim=(2,3)) dice= 2.0 * (intersection + smooth) / (union+ smooth)    loss = 1.0 - dicereturn loss.sum(), dice.sum()import torch.nn.functional as Fdef loss_func(pred, target):bce = F.binary_cross_entropy_with_logits(pred, target,  reduction='sum')pred= torch.sigmoid(pred)dlv, _ = dice_loss(pred, target)loss = bce  + dlvreturn lossfor img_v,mask_v in val_dl:mask_v= mask_v[8:]breakfor img_t,mask_t in train_dl:breakprint(dice_loss(mask_v,mask_v))
loss_func(mask_v,torch.zeros_like(mask_v))

import torchvisiondef metrics_batch(pred, target):pred= torch.sigmoid(pred)_, metric=dice_loss(pred, target)return metricdef loss_batch(loss_func, output, target, opt=None):   loss = loss_func(output, target)with torch.no_grad():pred= torch.sigmoid(output)_, metric_b=dice_loss(pred, target)if opt is not None:opt.zero_grad()loss.backward()opt.step()return loss.item(), metric_b

定义优化器

from torch import optim
opt = optim.Adam(model.parameters(), lr=3e-4)from torch.optim.lr_scheduler import ReduceLROnPlateau
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)def get_lr(opt):for param_group in opt.param_groups:return param_group['lr']current_lr=get_lr(opt)
print('current lr={}'.format(current_lr))

训练和评估模型

def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):running_loss=0.0running_metric=0.0len_data=len(dataset_dl.dataset)for xb, yb in dataset_dl:xb=xb.to(device)yb=yb.to(device)output=model(xb)loss_b, metric_b=loss_batch(loss_func, output, yb, opt)running_loss += loss_bif metric_b is not None:running_metric+=metric_bif sanity_check is True:breakloss=running_loss/float(len_data)metric=running_metric/float(len_data)return loss, metric
import copy
def train_val(model, params):num_epochs=params["num_epochs"]loss_func=params["loss_func"]opt=params["optimizer"]train_dl=params["train_dl"]val_dl=params["val_dl"]sanity_check=params["sanity_check"]lr_scheduler=params["lr_scheduler"]path2weights=params["path2weights"]loss_history={"train": [],"val": []}metric_history={"train": [],"val": []}    best_model_wts = copy.deepcopy(model.state_dict())best_loss=float('inf')    for epoch in range(num_epochs):current_lr=get_lr(opt)print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))   model.train()train_loss, train_metric=loss_epoch(model,loss_func,train_dl,sanity_check,opt)loss_history["train"].append(train_loss)metric_history["train"].append(train_metric)model.eval()with torch.no_grad():val_loss, val_metric=loss_epoch(model,loss_func,val_dl,sanity_check)loss_history["val"].append(val_loss)metric_history["val"].append(val_metric)   if val_loss < best_loss:best_loss = val_lossbest_model_wts = copy.deepcopy(model.state_dict())torch.save(model.state_dict(), path2weights)print("Copied best model weights!")lr_scheduler.step(val_loss)if current_lr != get_lr(opt):print("Loading best model weights!")model.load_state_dict(best_model_wts) print("train loss: %.6f, dice: %.2f" %(train_loss,100*train_metric))print("val loss: %.6f, dice: %.2f" %(val_loss,100*val_metric))print("-"*10) model.load_state_dict(best_model_wts)return model, loss_history, metric_history        
opt = optim.Adam(model.parameters(), lr=3e-4)# 定义学习率调度器,当验证集上的损失不再下降时,将学习率降低为原来的0.5倍,等待20个epoch后再次降低学习率
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)path2models= "./models/"# 判断path2models路径是否存在,如果不存在则创建该路径
if not os.path.exists(path2models):os.mkdir(path2models)params_train={"num_epochs": 100,"optimizer": opt,"loss_func": loss_func,"train_dl": train_dl,"val_dl": val_dl,"sanity_check": False,"lr_scheduler": lr_scheduler,"path2weights": path2models+"weights.pt",
}model,loss_hist,metric_hist=train_val(model,params_train)

打印训练验证损失

num_epochs=params_train["num_epochs"]plt.title("Train-Val Loss")
plt.plot(range(1,num_epochs+1),loss_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),loss_hist["val"],label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

 

打印训练验证精度

# plot accuracy progress
plt.title("Train-Val Accuracy")
plt.plot(range(1,num_epochs+1),metric_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),metric_hist["val"],label="val")
plt.ylabel("Accuracy")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/919847.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/919847.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

esp dl

放下了好多年 又回到了dl 该忘的也忘的差不多了 其实没啥复杂的 只是不习惯 熟悉而已 好吧 现代的人工智能体 还是存在着很大的问题 眼睛 耳朵 思考 虽然功能是正常的 但距离&#xff02;真正&#xff02;(&#xff09;意思上的独立意识个体 还是差别很大 再等个几十年 看看…

基于django/python的服装销售系统平台/服装购物系统/基于django/python的服装商城

基于django/python的服装销售系统平台/服装购物系统/基于django/python的服装商城

详解ThreadLocal<HttpServletRequest> requestThreadLocal

public static ThreadLocal<HttpServletRequest> requestThreadLocal ThreadLocal.withInitial(() -> null);一、代码逐部分详解 1. public static public&#xff1a;表示这个变量是公开的&#xff0c;其他类可以访问。static&#xff1a;表示这是类变量&#xff0c…

Vue2 响应式系统设计原理与实现

文章目录Vue2 响应式系统设计原理与实现Vue2 响应式系统设计原理与实现 Vue2 的响应式原理主要基于以下几点&#xff1a; 使用 Object.defineProperty () 方法对数据对象的属性进行劫持 当数据发生变化时&#xff0c;通知依赖该数据的视图进行更新 实现一个发布 - 订阅模式&a…

探索 JUC:Java 并发编程的神奇世界

探索 JUC&#xff1a;Java 并发编程的神奇世界 在 Java 编程领域&#xff0c;随着多核处理器的普及和应用场景复杂度的提升&#xff0c;并发编程变得愈发重要。Java 并发包&#xff08;JUC&#xff0c;Java.util.concurrent&#xff09;就像是一座宝藏库&#xff0c;为开发者提…

selenium采集数据怎么应对反爬机制?

selenium是一个非常强大的浏览器自动化工具&#xff0c;通过操作浏览器来抓取动态网页内容&#xff0c;可以很好的处理JavaScript和AJAX加载的网页。 它能支持像点击按钮、悬停元素、填写表单等各种自动化操作&#xff0c;所以很适合自动化测试和数据采集。 selenium与各种主流…

指定文件夹上的压缩图像格式tiff转换为 jpg 批量脚本

文章大纲 背景简介 代码 背景简介 随着数字成像技术在科研、医学影像和遥感等领域的广泛应用,多页TIFF(Tag Image File Format)文件因其支持多维数据存储和高位深特性,成为存储序列图像、显微镜切片或卫星遥感数据的首选格式。然而在实际应用中,这类文件存在以下显著痛点…

Docker 部署 MySQL 8.0 完整指南:从拉取镜像到配置远程访问

目录前言一、拉取镜像二、查看镜像三、运行容器命令参数说明&#xff1a;四、查看运行容器五、进入容器内部六、修改 MySQL 配置1. 创建配置文件2. 配置内容七、重启 MySQL 服务八、设置 Docker 启动时自动启动 MySQL九、再次重启 MySQL十、授权远程访问1. 进入容器内部2. 登录…

IntelliJ IDEA 常用快捷键笔记(Windows)

前言&#xff1a;特别标注的快捷键&#xff08;Windows&#xff09;快捷键功能说明Ctrl Alt M将选中代码提取成方法Ctrl Alt T包裹选中代码块&#xff08;try/catch、if、for 等&#xff09;Ctrl H查看类的继承层次Alt 7打开项目结构面板Ctrl F12打开当前文件结构视图Ct…

疏老师-python训练营-Day54Inception网络及其思考

浙大疏锦行 DAY54 一、 inception网络介绍 今天我们介绍inception&#xff0c;也就是GoogleNet 传统计算机视觉的发展史 从上面的链接&#xff0c;可以看到其实inceptionnet是在resnet之前的&#xff0c;那为什么我今天才说呢&#xff1f;因为他要引出我们后面的特征融合和…

LeetCode第3304题 - 找出第 K 个字符 I

题目 解答 class Solution {public char kthCharacter(int k) {int n 0;int v 1;while (v < k) {v << 1;n;}String target kthCharacterString(n);return target.charAt(k - 1);}public String kthCharacterString(int n) {if (n 0) {return "a";}Str…

Codeforces Round 1043 (Div. 3) D-F 题解

D. From 1 to Infinity 题意 有一个无限长的序列&#xff0c;是把所有正整数按次序拼接&#xff1a;123456789101112131415...\texttt{123456789101112131415...}123456789101112131415...。求这个序列前 k(k≤1015)k(k\le 10^{15})k(k≤1015) 位的数位和。 思路 二分出第 …

【C语言16天强化训练】从基础入门到进阶:Day 7

&#x1f525;个人主页&#xff1a;艾莉丝努力练剑 ❄专栏传送门&#xff1a;《C语言》、《数据结构与算法》、C语言刷题12天IO强训、LeetCode代码强化刷题、洛谷刷题、C/C基础知识知识强化补充、C/C干货分享&学习过程记录 &#x1f349;学习方向&#xff1a;C/C方向学习者…

【AI基础:神经网络】16、神经网络的生理学根基:从人脑结构到AI架构,揭秘道法自然的智能密码

“道法自然,久藏玄冥”——人工神经网络(ANN)的崛起并非偶然,而是对自然界最精妙的智能系统——人脑——的深度模仿与抽象。从单个神经元的信号处理到大脑皮层的层级组织,从突触可塑性的学习机制到全脑并行计算的高效能效,生物大脑的“玄冥”智慧为AI提供了源源不断的灵感…

容器安全实践(一):概念篇 - 从“想当然”到“真相”

在容器化技术日益普及的今天&#xff0c;许多开发者和运维人员都将应用部署在 Docker 或 Kubernetes 中。然而&#xff0c;一个普遍存在的误解是&#xff1a;“容器是完全隔离的&#xff0c;所以它是安全的。” 如果你也有同样的想法&#xff0c;那么你需要重新审视容器安全了。…

腾讯开源WeKnora:新一代文档理解与检索框架

引言&#xff1a;文档智能处理的新范式 在数字化时代&#xff0c;企业和个人每天都面临着海量文档的处理需求&#xff0c;从产品手册到学术论文&#xff0c;从合同条款到医疗报告&#xff0c;非结构化文档的高效处理一直是技术痛点。2025年8月&#xff0c;腾讯正式开源了基于大…

C++之list类的代码及其逻辑详解 (中)

接下来我会依照前面所说的一些接口以及list的结构来进行讲解。1. list_node的结构1.1 list_node结构体list由于其结构为双向循环链表&#xff0c;所以我们在这里要这么初始化_next&#xff1a;指向链表中下一个节点的指针_prev&#xff1a;指向链表中上一个节点的指针_val&…

新能源汽车热管理仿真:蒙特卡洛助力神经网络训练

研究背景在新能源汽车的热管理仿真研究中&#xff0c;神经网络训练技术常被应用于系统降阶建模。通过这一方法&#xff0c;可以构建出高效准确的代理模型&#xff0c;进而用于控制策略的优化、系统性能的预测与评估&#xff0c;以及实时仿真等任务&#xff0c;有效提升开发效率…

第十九讲:C++11第一部分

目录 1、C11简介 2、列表初始化 2.1、{}初始化 2.2、initializer_list 2.2.1、成员函数 2.2.2、应用 3、变量类型推导 3.1、auto 3.2、decltype 3.3、nullptr 4、范围for 5、智能指针 6、STL的一些变化 7、右值引用和移动语义 7.1、右值引用 7.2、右值与左值引…

书写本体论视域下的文字学理论重构

在符号学与哲学的交叉领域&#xff0c;文字学&#xff08;Grammatologie&#xff09;作为一门颠覆性学科始终处于理论风暴的中心。自德里达1967年发表《论文字学》以来&#xff0c;传统语言学中"语音中心主义"的霸权地位遭遇根本性动摇&#xff0c;文字不再被视为语言…