自定义数据集 (Custom dataset — img_segData.py)
# -*- I Love Python!!! And You? -*-
# @Time : 2022/3/27 12:25
# @Author : sunao
# @Email : 939419697@qq.com
# @File : img_segData.py
# @Software: PyCharm
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
import torch.nn.functional as F
import os
class img_segData(Dataset):
    """Image-segmentation dataset pairing each image with a "-profile.jpg" label.

    For an image ``<path>/<data_file>/<stem>.<ext>`` the label is read from
    ``<path>/<label_file>/<stem>-profile.jpg``.
    """

    def __init__(self, img_h=256, img_w=256, path="./data/img_seg", data_file="images", label_file="profiles",
                 preprocess=True):
        """
        Initialise the dataset.

        :param img_h: target image height after resizing
        :param img_w: target image width after resizing
        :param path: dataset root directory
        :param data_file: name of the sub-folder holding the input images
        :param label_file: name of the sub-folder holding the label images
        :param preprocess: if true, resize/convert samples to tensors (images are also normalised)
        """
        super(img_segData, self).__init__()
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same file.
        self.file_list = sorted(os.listdir(os.path.join(path, data_file)))
        self.data_file = data_file
        self.label_files = label_file
        self.path = path
        self.img_h = img_h
        self.img_w = img_w
        self.preprocess = preprocess
        # Built once here rather than on every __getitem__ call.
        # NOTE: transforms.Resize expects (height, width).
        self.trans_img = transforms.Compose([
            transforms.Resize(size=(img_h, img_w)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        self.trans_label = transforms.Compose([
            transforms.Resize(size=(img_h, img_w)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.file_list)

    @staticmethod
    def _label_name(img_name):
        """Return the label file name for ``img_name``: ``<stem>-profile.jpg``."""
        # splitext strips only the final extension, so stems containing dots
        # (e.g. "a.b.png") are kept intact, unlike split(".")[0].
        return os.path.splitext(img_name)[0] + "-profile.jpg"

    def __getitem__(self, item):
        """Return the ``(img, label)`` pair at index ``item``.

        With ``preprocess`` enabled both are tensors resized to (img_h, img_w);
        the image is normalised to [-1, 1] and the label is left in [0, 1].
        Otherwise the PIL images are returned unchanged.
        """
        img_name = self.file_list[item]
        img_path = os.path.join(self.path, self.data_file, img_name)
        label_path = os.path.join(self.path, self.label_files, self._label_name(img_name))
        # Image.open is lazy and keeps the file open; copy the pixel data and
        # close the file right away instead of waiting for garbage collection.
        with Image.open(img_path) as im:
            img = im.copy()
        with Image.open(label_path) as im:
            label = im.copy()
        if self.preprocess:
            # Normalize uses 3-channel statistics, so force RGB input
            # (grayscale / RGBA / palette images would otherwise fail).
            img = self.trans_img(img.convert("RGB"))
            label = self.trans_label(label)
        return img, label
if __name__ == '__main__':
    # Quick visual check: load one sample and show the image with the
    # label-covered region blanked out.
    trans_data = img_segData()
    img, label = trans_data[5]
    print(img.size(), label.size())
    # label > 0.5 -> 0 (drop pixel), otherwise 1 (keep pixel). A threshold is
    # used instead of `label == 1` because resizing and JPEG decoding can leave
    # fractional values in the mask.
    mask = torch.where(label > 0.5, torch.zeros_like(label), torch.ones_like(label))
    # Undo Normalize(mean=0.5, std=0.5) so pixel values are back in [0, 1] for imshow.
    seg = mask * (img * 0.5 + 0.5)
    plt.imshow(seg.numpy().transpose([1, 2, 0]))
    plt.show()
模型 (Model — model.py)
# -*- I Love Python!!! And You? -*-
# @Time : 2022/3/27 13:02
# @Author : sunao
# @Email : 939419697@qq.com
# @File : model.py
# @Software: PyCharm
import torch
import torch.nn as nn
import torch.nn.functional as F
class conv_block(nn.Module):
    """Two stacked 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    Stride 1 with padding 1 keeps the spatial size; channels go ch_in -> ch_out.
    """

    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        layers = []
        # First stage maps ch_in -> ch_out, second stage ch_out -> ch_out.
        for stage_in in (ch_in, ch_out):
            layers += [
                nn.Conv2d(in_channels=stage_in, out_channels=ch_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class up_block(nn.Module):
    """2x upsampling followed by one 3x3 Conv2d -> BatchNorm2d -> ReLU stage.

    Doubles the spatial size and maps channels ch_in -> ch_out.
    """

    def __init__(self, ch_in, ch_out):
        super(up_block, self).__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=3, stride=1, padding=1)
        self.conv = nn.Sequential(upsample, conv, nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)
class U_Net(nn.Module):
    """U-Net encoder/decoder that outputs a per-pixel probability map.

    Four 2x max-pool downsampling stages and four upsampling stages joined by
    skip connections; encoder widths are ndf, 2*ndf, 4*ndf, 8*ndf, 16*ndf.

    :param img_ch: number of input channels
    :param output_ch: number of output channels
    :param ndf: channel width of the first encoder stage (64 by default)
    """

    def __init__(self, img_ch=3, output_ch=1, ndf=64):
        super(U_Net, self).__init__()
        self.ndf = ndf
        self.Maxpool = nn.MaxPool2d(2, 2)
        # Encoder
        self.conv1 = conv_block(ch_in=img_ch, ch_out=ndf)
        self.conv2 = conv_block(ch_in=ndf, ch_out=ndf * 2)
        self.conv3 = conv_block(ch_in=ndf * 2, ch_out=ndf * 4)
        self.conv4 = conv_block(ch_in=ndf * 4, ch_out=ndf * 8)
        self.conv5 = conv_block(ch_in=ndf * 8, ch_out=ndf * 16)
        # Decoder: up_block halves the channels, then conv_block fuses the
        # result with the concatenated skip connection.
        self.up4 = up_block(ch_in=ndf * 16, ch_out=ndf * 8)
        self.up_conv4 = conv_block(ch_in=ndf * 16, ch_out=ndf * 8)
        self.up3 = up_block(ch_in=ndf * 8, ch_out=ndf * 4)
        self.up_conv3 = conv_block(ch_in=ndf * 8, ch_out=ndf * 4)
        self.up2 = up_block(ch_in=ndf * 4, ch_out=ndf * 2)
        self.up_conv2 = conv_block(ch_in=ndf * 4, ch_out=ndf * 2)
        self.up1 = up_block(ch_in=ndf * 2, ch_out=ndf)
        self.up_conv1 = conv_block(ch_in=ndf * 2, ch_out=ndf)
        # Output head: a plain 1x1 convolution producing unbounded logits.
        # A conv_block head ends in ReLU, which makes every logit >= 0 so the
        # sigmoid output could never fall below 0.5.
        # NOTE: checkpoints saved with the old conv_block head have different
        # state_dict keys for conv1_1 and will not load into this model.
        self.conv1_1 = nn.Conv2d(ndf, output_ch, kernel_size=1)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` so its height/width equal those of ``skip``.

        Max-pooling floors odd sizes, so after upsampling the decoder map can
        be smaller than the skip connection; when sizes already match, ``up``
        is returned unchanged.
        """
        diff_h = skip.size(2) - up.size(2)
        diff_w = skip.size(3) - up.size(3)
        if diff_h == 0 and diff_w == 0:
            return up
        return F.pad(up, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x):
        """Return sigmoid probabilities of shape [N, output_ch, H, W] for x of shape [N, img_ch, H, W]."""
        # Shapes noted for x = [N, img_ch, 256, 256] and ndf = 64.
        x1 = self.conv1(x)                                   # [N, 64, 256, 256]
        x2 = self.conv2(self.Maxpool(x1))                    # [N, 128, 128, 128]
        x3 = self.conv3(self.Maxpool(x2))                    # [N, 256, 64, 64]
        x4 = self.conv4(self.Maxpool(x3))                    # [N, 512, 32, 32]
        x5 = self.conv5(self.Maxpool(x4))                    # [N, 1024, 16, 16]

        u4_ = self._match_size(self.up4(x5), x4)             # [N, 512, 32, 32]
        u4 = self.up_conv4(torch.cat([x4, u4_], dim=1))      # [N, 512, 32, 32]
        u3_ = self._match_size(self.up3(u4), x3)             # [N, 256, 64, 64]
        u3 = self.up_conv3(torch.cat([x3, u3_], dim=1))      # [N, 256, 64, 64]
        u2_ = self._match_size(self.up2(u3), x2)             # [N, 128, 128, 128]
        u2 = self.up_conv2(torch.cat([x2, u2_], dim=1))      # [N, 128, 128, 128]
        u1_ = self._match_size(self.up1(u2), x1)             # [N, 64, 256, 256]
        u1 = self.up_conv1(torch.cat([x1, u1_], dim=1))      # [N, 64, 256, 256]

        out = self.conv1_1(u1)                               # [N, output_ch, 256, 256]
        return torch.sigmoid(out)
if __name__ == '__main__':
    # Print the module structure of a U-Net built with the default arguments.
    print(U_Net())
训练模型 (Model training — img2seg.py)
# -*- I Love Python!!! And You? -*-
# @Time : 2022/3/27 15:34
# @Author : sunao
# @Email : 939419697@qq.com
# @File : img2seg.py
# @Software: PyCharm
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from img_segData import img_segData
from model import U_Net
from torch.utils import data
import torch
import os
from torchvision.utils import save_image
class Trainer(object):
    """Train a U_Net with BCE loss on (image, mask) batches and save sample visualisations."""

    def __init__(self, img_ch=3, out_ch=3, lr=0.005,
                 batch_size=16, num_epoch=60, train_set=None,
                 model_path="./model"):
        """
        Initialise the trainer.

        :param img_ch: number of input image channels
        :param out_ch: number of network output channels
        :param lr: learning rate
        :param batch_size: mini-batch size
        :param num_epoch: number of training epochs
        :param train_set: training dataset (required)
        :param model_path: directory holding the Unet.pkl checkpoint
        :raises ValueError: if train_set is None
        """
        if train_set is None:
            raise ValueError("train_set must be provided")
        self.img_ch = img_ch
        self.out_ch = out_ch
        self.lr = lr
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.model_path = model_path
        self.data_loader = data.DataLoader(dataset=train_set,
                                           batch_size=self.batch_size,
                                           shuffle=True, num_workers=0)
        # Build the model and move it to the GPU when one is available.
        self.unet = U_Net(self.img_ch, output_ch=self.out_ch)
        # (sic) attribute name kept as "divice" for compatibility.
        self.divice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.unet.to(self.divice)
        self.loss = torch.nn.BCELoss()
        self.optim = torch.optim.Adam(self.unet.parameters(), lr=self.lr, betas=(0.5, 0.99))

    def train(self):
        """Run the training loop, resuming from ``<model_path>/Unet.pkl`` if that file exists.

        After every epoch a sample grid is written via ``save_img``. The weights
        are saved whenever the summed epoch loss beats the best value so far.
        """
        ckpt_path = os.path.join(self.model_path, "Unet.pkl")
        # Check the checkpoint file itself: the directory may exist without it.
        if os.path.isfile(ckpt_path):
            # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
            self.unet.load_state_dict(torch.load(ckpt_path, map_location=self.divice))
            print("模型导入成功", ckpt_path)
        best_loss = float("inf")
        for epoch in range(self.num_epoch):
            # save_img switches the model to eval mode; re-enable training mode each epoch.
            self.unet.train(True)
            epoch_loss = 0.0
            for bx, by in self.data_loader:
                bx = bx.to(self.divice)
                by = by.to(self.divice)
                loss = self.loss(self.unet(bx), by)
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                epoch_loss += loss.item()
            print("| epoch %d/%d | loss %f |" % (epoch, self.num_epoch, epoch_loss))
            self.save_img(save_name="epoch" + str(epoch) + ".png")
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                os.makedirs(self.model_path, exist_ok=True)
                torch.save(self.unet.state_dict(), ckpt_path)

    def save_img(self, save_path="./saved/Unet", save_name="result.png"):
        """Save a grid visualising the model's output on one batch from the data loader.

        Grid rows (up to 5 samples each): input image, predicted keep-mask,
        image with the predicted mask applied, and image with the ground-truth
        mask applied. Pixels whose prediction / label value is > 0.5 are
        dropped and shown white. The model's train/eval mode is restored afterwards.
        """
        img, labels = next(iter(self.data_loader))
        was_training = self.unet.training
        self.unet.eval()
        with torch.no_grad():
            gen = self.unet(img.to(self.divice)).cpu()[:5]
        self.unet.train(was_training)
        # Assumes the dataset normalised images with mean=0.5, std=0.5 (as
        # img_segData does); map them back to [0, 1] for save_image.
        img = img.cpu()[:5] * 0.5 + 0.5
        labels = labels.cpu()[:5]
        # 1 = keep pixel, 0 = drop pixel. Adding zeros_like(img) broadcasts a
        # 1-channel mask to the image's channel count, whatever the image size.
        gen_label = torch.where(gen > 0.5, torch.zeros_like(gen), torch.ones_like(gen)) + torch.zeros_like(img)
        true_label = torch.where(labels > 0.5, torch.zeros_like(labels), torch.ones_like(labels)) + torch.zeros_like(img)
        white = torch.ones_like(img)
        # Select by the mask itself rather than testing `img * mask == 0`, which
        # would also whiten pixels whose value happens to be exactly zero.
        seg_img = torch.where(gen_label > 0, img, white)
        seg_img2 = torch.where(true_label > 0, img, white)
        save_tensor = torch.cat([img, gen_label, seg_img, seg_img2], 0)
        os.makedirs(save_path, exist_ok=True)
        save_image(save_tensor, os.path.join(save_path, save_name), nrow=5)
if __name__ == '__main__':
    # Release unused memory held by PyTorch's CUDA caching allocator.
    torch.cuda.empty_cache()
    # Load the dataset. preprocess must be a bool: any non-empty string,
    # including "False", is truthy.
    train_data = img_segData(img_h=256, img_w=256, path="./data/img_seg", data_file="images",
                             label_file="profiles", preprocess=True)
    # Build and train the model.
    trainer = Trainer(img_ch=3, out_ch=1, lr=0.01, batch_size=16, num_epoch=50, train_set=train_data)
    trainer.train()
版权声明:本文为ALL_BYA原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。