
Training a U-Net in PyTorch on Your Own Dataset

2022-12-08 20:57


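Image segmentation has two main schools of network design: encoder-decoder networks and dilated convolutions. This article covers U-Net, the most classic encoder-decoder network. As backbone networks have evolved, many derived networks have been built as improvements on U-Net, but the underlying idea has changed little; examples include FC-DenseNet, which combines DenseNet with U-Net, and UNet++.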

1. U-Net Overview

Paper: https://arxiv.org/abs/1505.04597v1 (2015)

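U-Net was designed for medical image segmentation. Because medical imaging datasets are usually small, the method significantly improves what can be achieved when training on few images, and the paper also proposes an effective way (the overlap-tile strategy) to process large images.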

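U-Net's architecture builds on FCN with some modifications. It is organized as an encoder-decoder, which is essentially FCN's idea of first convolving (and downsampling) and then upsampling.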

The figure above shows the U-Net architecture. From it we can see the following.

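The left side is the encoder, i.e., the downsampling path that extracts features. Its basic block is a double convolution: the input passes through two 3x3 convolutions, each followed by a ReLU. The paper uses valid (unpadded) convolutions, each of which shrinks the feature map by 2 pixels; in our implementation we add padding=1 to get same convolutions, so that the encoder and decoder feature maps joined by the skip connections have equal sizes. Downsampling uses 2x2 max pooling, which halves the feature map size.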
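To see why the padding choice matters, here is a small pure-Python sketch that tracks the feature-map size through the five encoder levels. It uses the standard convolution output-size formula, floor((n + 2p - k) / s) + 1, and 2x2 max pooling; the helper names `conv_out` and `encoder_sizes` are hypothetical and not part of the article's code. With valid convolutions the 572x572 input used in the paper loses 4 pixels per double convolution, while with padding=1 only the pooling changes the size.

```python
def conv_out(n, k=3, p=0, s=1):
    """Output size of a convolution along one spatial dimension."""
    return (n + 2 * p - k) // s + 1


def encoder_sizes(n, padding):
    """Feature-map size after the double convolution at each of the 5 encoder levels."""
    sizes = []
    for level in range(5):
        # Double convolution: two 3x3 convs with the given padding
        n = conv_out(conv_out(n, p=padding), p=padding)
        sizes.append(n)
        if level < 4:
            n //= 2  # 2x2 max pooling with stride 2 halves the size
    return sizes


print(encoder_sizes(572, padding=0))  # [568, 280, 136, 64, 28]  (valid convs, as in the paper)
print(encoder_sizes(256, padding=1))  # [256, 128, 64, 32, 16]   (same convs, as in this article)
```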

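The right side is the decoder, which upsamples to recover the image size and produces the prediction. It also uses double-convolution blocks; upsampling is done with transposed convolutions, each of which enlarges the feature map by a factor of 2 (and halves the number of channels).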
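The factor of 2 follows from the output-size formula documented for PyTorch's `nn.ConvTranspose2d` (with dilation 1): out = (n - 1) * stride - 2 * padding + kernel_size + output_padding. With kernel_size=2, stride=2 and padding=0 this is exactly 2n. The sketch below (hypothetical helper names) applies it, followed by a double convolution, through the four decoder levels, once for the paper's valid convolutions and once for the padded version.

```python
def conv_out(n, k=3, p=0, s=1):
    """Output size of a convolution along one spatial dimension."""
    return (n + 2 * p - k) // s + 1


def conv_transpose_out(n, k=2, s=2, p=0, output_padding=0):
    """Output size of a transposed convolution along one dimension (dilation=1)."""
    return (n - 1) * s - 2 * p + k + output_padding


def decoder_sizes(n, padding):
    """Feature-map size after each of the 4 decoder levels, starting from the bottleneck."""
    sizes = []
    for _ in range(4):
        n = conv_transpose_out(n)                        # 2x2 up-convolution: n -> 2n
        n = conv_out(conv_out(n, p=padding), p=padding)  # double convolution
        sizes.append(n)
    return sizes


print(decoder_sizes(28, padding=0))  # [52, 100, 196, 388]  (paper: 388x388 output)
print(decoder_sizes(16, padding=1))  # [32, 64, 128, 256]   (padded: back to the input size)
```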

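The "copy and crop" arrows in the middle are concatenation (cat) operations: the encoder feature map is stacked with the decoder feature map along the channel dimension. In the paper the encoder map is larger because of the valid convolutions, so it is center-cropped first; with same padding no cropping is needed.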
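Both steps of "copy and crop" can be sketched in NumPy on NCHW arrays. The spatial sizes follow the paper's fourth level (a 64x64 encoder map cropped to match a 56x56 upsampled map); channel counts are shrunk from 512 to 8 to keep the example light, and `center_crop` is a hypothetical helper. In PyTorch the concatenation itself is `torch.cat([up, skip], dim=1)`.

```python
import numpy as np


def center_crop(feat, target_h, target_w):
    """Crop an NCHW array to (target_h, target_w) around its spatial center."""
    _, _, h, w = feat.shape
    top = (h - target_h) // 2
    left = (w - target_w) // 2
    return feat[:, :, top:top + target_h, left:left + target_w]


# Channel counts reduced from 512 to 8 to keep the example small
skip = np.random.rand(1, 8, 64, 64).astype(np.float32)  # encoder output (copied)
up = np.random.rand(1, 8, 56, 56).astype(np.float32)    # upsampled decoder output

cropped = center_crop(skip, up.shape[2], up.shape[3])   # "crop"
merged = np.concatenate([up, cropped], axis=1)          # channel-wise concatenation
print(cropped.shape, merged.shape)  # (1, 8, 56, 56) (1, 16, 56, 56)
```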

2. Training U-Net

2.1 U-Net implementation

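As the description above shows, the architecture is simple and symmetric. The implementation, unet.py (the name the training script imports from), is as follows: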

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU.
    padding=1 ("same" convolution) keeps the spatial size unchanged."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Encoder: double conv, then 2x2 max pooling halves the spatial size
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)  # bottleneck
        # Decoder: each transposed conv doubles the size and halves the channels;
        # the following DoubleConv takes the concatenation with the skip connection
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 convolution: one score map per class
        self.output = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        # Encoder: keep each level's output for the skip connections
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        conv5 = self.conv5(pool4)
        # Decoder: upsample, concatenate with the encoder map along channels, double conv
        up6 = self.up6(conv5)
        merge6 = torch.cat([up6, conv4], dim=1)
        conv6 = self.conv6(merge6)
        up7 = self.up7(conv6)
        merge7 = torch.cat([up7, conv3], dim=1)
        conv7 = self.conv7(merge7)
        up8 = self.up8(conv7)
        merge8 = torch.cat([up8, conv2], dim=1)
        conv8 = self.conv8(merge8)
        up9 = self.up9(conv8)
        merge9 = torch.cat([up9, conv1], dim=1)
        conv9 = self.conv9(merge9)
        # Raw logits of shape (N, out_ch, H, W); no softmax because
        # nn.CrossEntropyLoss applies log-softmax itself
        return self.output(conv9)


if __name__ == "__main__":
    model = Unet(3, 21)  # 3-channel RGB input, 21 classes (as in Pascal VOC)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(model)
    # Sanity check: the output keeps the input's spatial size
    model.eval()
    with torch.no_grad():
        print(model(torch.randn(1, 3, 256, 256, device=device)).shape)  # torch.Size([1, 21, 256, 256])
```
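Because this implementation pools four times and concatenates encoder and decoder maps of equal size, the input height and width should be multiples of 2^4 = 16. Otherwise the floor in max pooling makes an upsampled map one pixel smaller than its skip connection (for a 250-pixel side: 125, 62, 31, 15, and then 15 upsampled to 30 versus 31) and `torch.cat` fails with a size mismatch. The 256x256 crops used by the dataset below satisfy this. For other sizes, a sketch like the following (a hypothetical helper, not part of the article) computes how much padding is needed.

```python
def pad_to_multiple(h, w, multiple=16):
    """Return (pad_h, pad_w) so that h + pad_h and w + pad_w are divisible by `multiple`."""
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    return pad_h, pad_w


print(pad_to_multiple(256, 256))    # (0, 0): already valid
print(pad_to_multiple(1280, 1918))  # (0, 2)
print(pad_to_multiple(250, 250))    # (6, 6)
```

The padding can then be applied to an NCHW tensor `x` with `torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))`, whose tuple lists the last dimension first (left, right, then top, bottom); crop the prediction back to the original size afterwards.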

2.2 Dataset preparation

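The data comes from Kaggle's Carvana Image Masking Challenge. It has two classes, car and background, and 5k+ images, which are split into a training set and a validation set by a fixed ratio (every 8th image goes to validation). See carnava.py: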
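The split in carnava.py puts every 8th image into the validation set. Since `os.listdir` returns files in arbitrary order, sorting first makes the split reproducible, and a set makes the membership test fast. A minimal sketch of that split on synthetic file names (`split_every_kth` is a hypothetical name):

```python
def split_every_kth(paths, k=8):
    """Send every k-th path (after sorting) to validation and the rest to training."""
    paths = sorted(paths)   # os.listdir order is arbitrary; sort for reproducibility
    val = paths[::k]
    val_set = set(val)      # O(1) membership checks instead of scanning a list
    train = [p for p in paths if p not in val_set]
    return train, val


files = [f"img_{i:03d}.jpg" for i in range(16)]
train, val = split_every_kth(files)
print(len(train), len(val))  # 14 2
print(val)                   # ['img_000.jpg', 'img_008.jpg']
```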

```python
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T


class Car(Dataset):
    def __init__(self, root, train=True):
        self.root = root
        self.crop_size = (256, 256)
        self.img_path = os.path.join(root, "train_hq")
        self.label_path = os.path.join(root, "train_masks")
        # Sort so that the train/validation split is the same on every run
        img_path_list = [os.path.join(self.img_path, im) for im in sorted(os.listdir(self.img_path))]
        train_path_list, val_path_list = self._split_data_set(img_path_list)
        self.imgs_list = train_path_list if train else val_path_list
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # Image: resize the short side to 256, center-crop to 256x256, to tensor, normalize
        self.transforms = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor(),
            normalize,
        ])
        # Mask: same geometry, but nearest-neighbour resizing so that no blended
        # colours (which would match no entry of color_map) are created
        self.transforms_val = T.Compose([
            T.Resize(256, interpolation=T.InterpolationMode.NEAREST),
            T.CenterCrop(256),
        ])
        # Mask colours: black = background (class 0), white = car (class 1)
        self.color_map = [[0, 0, 0], [255, 255, 255]]
        # Lookup table from an RGB colour encoded as (r*256 + g)*256 + b to its
        # class index; built once here instead of on every __getitem__ call
        self.cm2lb = np.zeros(256 ** 3, dtype=np.uint8)
        for i, cm in enumerate(self.color_map):
            self.cm2lb[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i

    def __getitem__(self, index: int):
        im_path = self.imgs_list[index]
        image = Image.open(im_path).convert("RGB")
        data = self.transforms(image)
        # The mask for image "<name>.<ext>" is "<name>_mask.gif" in train_masks
        filename = os.path.splitext(os.path.basename(im_path))[0]
        label = Image.open(os.path.join(self.label_path, filename + "_mask.gif")).convert("RGB")
        label = np.array(self.transforms_val(label), dtype=np.int64)
        idx = (label[:, :, 0] * 256 + label[:, :, 1]) * 256 + label[:, :, 2]
        label = torch.from_numpy(self.cm2lb[idx].astype(np.int64))  # (H, W) class indices
        return data, label

    def label2img(self, label):
        # Map class indices back to colours for display
        cmap = np.array(self.color_map).astype(np.uint8)
        return cmap[label]

    def __len__(self):
        return len(self.imgs_list)

    def _split_data_set(self, img_path_list):
        # Every 8th image goes to the validation set, the rest to training
        val_path_list = img_path_list[::8]
        val_set = set(val_path_list)
        train_path_list = [p for p in img_path_list if p not in val_set]
        return train_path_list, val_path_list


if __name__ == "__main__":
    root = "../dataset/carvana"
    car_train = Car(root, train=True)
    train_dataloader = DataLoader(car_train, batch_size=8, shuffle=True)
    print(len(car_train))         # number of training images
    print(len(train_dataloader))  # number of batches per epoch
    # for data, label in car_train:
    #     print(data.shape)
    #     print(label.shape)
    #     break
    # Display the mask of one sample as a colour image
    data, label = car_train[190]
    label_im = car_train.label2img(label.numpy())
    plt.figure()
    plt.imshow(label_im)
    plt.show()
```
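carnava.py turns each mask into class indices through a 256^3-entry colour lookup table. For a two-colour mask like this one (black background, white car), thresholding the grey level gives the same labels far more cheaply. The sketch below compares the two on a tiny synthetic mask; the function names are hypothetical, and it assumes the mask was resized with nearest-neighbour interpolation so it contains only pure black and white pixels.

```python
import numpy as np

color_map = [[0, 0, 0], [255, 255, 255]]  # background, car


def rgb_to_label_lut(rgb):
    """Lookup-table approach: encode each RGB triple as one integer and index a table."""
    cm2lb = np.zeros(256 ** 3, dtype=np.uint8)
    for i, (r, g, b) in enumerate(color_map):
        cm2lb[(r * 256 + g) * 256 + b] = i
    rgb = rgb.astype(np.int64)
    idx = (rgb[:, :, 0] * 256 + rgb[:, :, 1]) * 256 + rgb[:, :, 2]
    return cm2lb[idx].astype(np.int64)


def binary_mask_to_label(rgb):
    """Threshold approach for a black/white mask: car pixels -> 1, background -> 0."""
    return (rgb.mean(axis=2) > 127).astype(np.int64)


mask = np.zeros((4, 4, 3), dtype=np.uint8)
mask[1:3, 1:3] = 255  # a 2x2 "car" in the middle
print(binary_mask_to_label(mask))
print(np.array_equal(rgb_to_label_lut(mask), binary_mask_to_label(mask)))  # True
```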

2.3 Training

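Segmentation is really just classifying every pixel, so the loss is still cross-entropy, and accuracy is the number of correctly classified pixels divided by the total number of pixels. The script below also reports mean per-class accuracy and mean IoU, both computed from a confusion matrix.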
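For segmentation, `nn.CrossEntropyLoss` takes logits of shape (N, C, H, W) and integer targets of shape (N, H, W) and, by default, averages the per-pixel loss. The NumPy sketch below (hypothetical function names) spells out that computation, a log-softmax over the class dimension followed by picking the true class at each pixel, together with pixel accuracy.

```python
import numpy as np


def pixel_cross_entropy(logits, target):
    """Mean per-pixel cross-entropy. logits: (N, C, H, W) floats; target: (N, H, W) ints."""
    # Numerically stable log-softmax over the class dimension
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Log-probability of the true class at every pixel
    picked = np.take_along_axis(log_probs, target[:, None, :, :], axis=1)
    return -picked.mean()


def pixel_accuracy(logits, target):
    """Correctly classified pixels / all pixels."""
    pred = logits.argmax(axis=1)
    return (pred == target).mean()


logits = np.zeros((1, 2, 2, 2))  # 2 classes on a 2x2 image
logits[0, 1, 0, 0] = 2.0         # favour class 1 at the top-left pixel
target = np.array([[[1, 0], [0, 1]]])
print(pixel_accuracy(logits, target))                       # 0.75
print(pixel_cross_entropy(np.zeros((1, 2, 2, 2)), target))  # about 0.6931, i.e. log(2)
```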

```python
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

from carnava import Car
from unet import Unet


# Confusion matrix: rows = true class, columns = predicted class
def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)  # skip out-of-range labels
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist


def scores_from_hist(hist):
    """Returns accuracy scores computed from an accumulated confusion matrix:
      - overall (pixel) accuracy
      - mean per-class accuracy
      - mean IoU
    """
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (
            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
        )
    mean_iu = np.nanmean(iu)
    return acc, acc_cls, mean_iu


out_path = "./out"
os.makedirs(out_path, exist_ok=True)
log_path = os.path.join(out_path, "result.txt")
if os.path.exists(log_path):
    os.remove(log_path)
model_path = os.path.join(out_path, "best_model.pth")
root = "../dataset/carvana"
epochs = 5
numclasses = 2  # background + car
train_data = Car(root, train=True)
train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
val_data = Car(root, train=False)
val_dataloader = DataLoader(val_data, batch_size=16, shuffle=False)
net = Unet(3, numclasses)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()  # logits (N, C, H, W) vs. targets (N, H, W)


def train_model():
    best_score = 0.0
    for e in range(epochs):
        # ---- training ----
        net.train()
        train_loss = 0.0
        # Accumulate a confusion matrix per batch instead of keeping every
        # label map of the epoch in memory; the resulting scores are identical
        hist = np.zeros((numclasses, numclasses))
        for data, label in train_dataloader:
            data, label = data.to(device), label.to(device)
            output = net(data)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            pred = output.argmax(dim=1)  # (N, H, W) predicted classes
            hist += _fast_hist(label.cpu().numpy().flatten(),
                               pred.cpu().numpy().flatten(), numclasses)
        train_loss /= len(train_dataloader)
        acc, acc_cls, mean_iu = scores_from_hist(hist)
        msg = '\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
            e + 1, train_loss, acc, acc_cls, mean_iu)
        print(msg)
        with open(log_path, 'a') as f:
            f.write(msg)

        # ---- validation ----
        net.eval()
        val_loss = 0.0
        val_hist = np.zeros((numclasses, numclasses))
        with torch.no_grad():
            for data, label in val_dataloader:
                data, label = data.to(device), label.to(device)
                output = net(data)
                val_loss += criterion(output, label).item()
                pred = output.argmax(dim=1)
                val_hist += _fast_hist(label.cpu().numpy().flatten(),
                                       pred.cpu().numpy().flatten(), numclasses)
        val_loss /= len(val_dataloader)
        val_acc, val_acc_cls, val_mean_iu = scores_from_hist(val_hist)
        msg = '\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
            e + 1, val_loss, val_acc, val_acc_cls, val_mean_iu)
        print(msg)
        with open(log_path, 'a') as f:
            f.write(msg)

        # Keep the checkpoint with the best mean of class accuracy and mean IoU
        score = (val_acc_cls + val_mean_iu) / 2
        if score > best_score:
            best_score = score
            torch.save(net.state_dict(), model_path)


def evaluate():
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()  # use BatchNorm running statistics instead of batch statistics
    index = random.randint(0, len(val_data) - 1)
    val_image, val_label = val_data[index]
    with torch.no_grad():
        out = net(val_image.unsqueeze(0).to(device))
    pred = out.argmax(dim=1).squeeze(0).cpu().numpy()
    label = val_label.numpy()
    img_pred = val_data.label2img(pred)
    img_label = val_data.label2img(label)
    # Min-max rescale the normalized input to 0-255 for display
    temp = val_image.numpy()
    temp = (temp - np.min(temp)) / (np.max(temp) - np.min(temp)) * 255
    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(temp.transpose(1, 2, 0).astype("uint8"))
    ax[1].imshow(img_label)
    ax[2].imshow(img_pred)
    plt.show()


if __name__ == "__main__":
    # train_model()  # run training first to create out/best_model.pth
    evaluate()
```
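`_fast_hist` in the training script encodes each (true, predicted) pair as `n_class * true + pred` and counts the pairs with `np.bincount`, which yields a confusion matrix with true classes as rows and predicted classes as columns. A worked example on six pixels (the function is repeated so the block runs on its own):

```python
import numpy as np


def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    return np.bincount(
        n_class * label_true[mask].astype(int) + label_pred[mask],
        minlength=n_class ** 2,
    ).reshape(n_class, n_class)


true = np.array([0, 0, 0, 1, 1, 1])
pred = np.array([0, 0, 1, 1, 1, 0])
hist = _fast_hist(true, pred, 2)
print(hist)
# [[2 1]
#  [1 2]]

acc = np.diag(hist).sum() / hist.sum()               # 4 correct of 6 pixels
acc_cls = np.mean(np.diag(hist) / hist.sum(axis=1))  # mean of 2/3 and 2/3
# IoU per class: intersection / (true count + predicted count - intersection) = 2 / 4
iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
print(acc, acc_cls, iu.mean())  # about 0.667, 0.667 and 0.5
```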

Final training results: because the data is fairly simple, the mean IoU had already reached 0.97 by epoch 5.

Finally, a quick test of the trained model:

From left to right: the original image, the ground-truth label, and the predicted label.
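`evaluate()` rescales the normalized input to 0-255 with min-max scaling, which makes it displayable but shifts its colours. Since the dataset normalizes with the ImageNet mean and standard deviation, inverting that normalization per channel (x * std + mean) recovers the original colours. A NumPy sketch with a hypothetical helper name:

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def denormalize(chw):
    """Undo T.Normalize on a (3, H, W) array and return an (H, W, 3) uint8 image."""
    img = chw * STD + MEAN          # back to the [0, 1] range produced by ToTensor
    img = np.clip(img, 0.0, 1.0)
    return (img * 255).round().astype(np.uint8).transpose(1, 2, 0)


# Round trip on a random image: normalize it, then recover it
rng = np.random.default_rng(0)
original = rng.integers(0, 256, size=(3, 4, 4)) / 255
restored = denormalize((original - MEAN) / STD)
expected = (original * 255).round().astype(np.uint8).transpose(1, 2, 0)
print(np.array_equal(restored, expected))  # True
```

In `evaluate()` this would replace the min-max step with `ax[0].imshow(denormalize(val_image.numpy()))`.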

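Notes:

I originally trained on the VOC dataset, but the results were very poor and I could not find the cause. After switching to this dataset the results were good. There are two possible reasons:

1. I made a mistake when processing the VOC data and did not catch it.

2. This dataset is simpler and easier to learn, so good results come more easily.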
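One thing worth checking for reason 1 (an assumption; the article does not show its VOC loader): in Pascal VOC segmentation masks, object boundaries and ambiguous regions carry the value 255 ("void"), which is not one of the 21 classes. The `_fast_hist` mask in the training script already drops such pixels from the metrics, but `nn.CrossEntropyLoss` rejects a target of 255 unless it is created with `ignore_index=255`. The NumPy sketch below (hypothetical helper names, made-up per-pixel losses) mimics that behaviour by excluding void pixels before averaging.

```python
import numpy as np

VOID = 255  # VOC "void" label for boundaries and ambiguous pixels


def masked_pixel_accuracy(pred, target, ignore_index=VOID):
    """Pixel accuracy that skips pixels whose target equals ignore_index."""
    valid = target != ignore_index
    return (pred[valid] == target[valid]).mean()


def masked_mean_loss(per_pixel_loss, target, ignore_index=VOID):
    """Average a per-pixel loss over non-void pixels only, as ignore_index does."""
    valid = target != ignore_index
    return per_pixel_loss[valid].mean()


target = np.array([[0, 15, VOID], [15, VOID, 0]])     # 15 = "person" in VOC
pred = np.array([[0, 15, 0], [0, 15, 0]])
loss = np.array([[0.1, 0.2, 9.0], [0.3, 9.0, 0.4]])  # made-up per-pixel losses
print(masked_pixel_accuracy(pred, target))  # 0.75 (3 of 4 valid pixels)
print(masked_mean_loss(loss, target))       # about 0.25; the two 9.0 values are ignored
```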
