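"""Train a Unet with PyTorch on paired noisy / ground-truth DNG images.

Each image is cropped into four quadrants, one dataset per quadrant.
Only the training-data path needs to be changed; the data folders should
contain nothing but .dng files.

主要使用 PyTorch 进行训练；将图片裁剪为四块；
只需要修改训练数据路径，数据里面只保留 dng 文件。
"""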
import torch
from torch.utils.data import Dataset
import torch.utils.data as Data
import numpy as np
import os
import rawpy
from unetTorch import Unet
import torch.nn.functional as F
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU,
# and report the choice once when the module is loaded.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
# Load the dataset.
#
# facadeData1..facadeData4 share a single implementation (_FacadeQuadrantData)
# and differ only in which crop of the packed, half-resolution image they
# return (rows x cols, packed coordinates):
#   facadeData1 -> 0:868    x 0:1156
#   facadeData2 -> 868:1736 x 0:1156
#   facadeData3 -> 0:868    x 1156:2312
#   facadeData4 -> 868:1736 x 1156:2312
# NOTE(review): these windows presumably assume a 1736x2312 packed image
# (3472x4624 raw); confirm against the data.


class _FacadeQuadrantData(Dataset):
    """Paired noisy / ground-truth DNG dataset returning one crop per item.

    Inputs are read from ``<path>/<data_file>/<name>`` and labels from
    ``<path>/<label_file>/<name with 'noise' replaced by 'gt'>``.  Both are
    packed from the Bayer mosaic into 4 channels, normalized with
    ``normalization`` and cropped to ``row_slice`` x ``col_slice`` in packed
    (half-resolution) coordinates.  Subclasses choose the crop window.
    """

    # Crop window in packed coordinates; each public subclass overrides these.
    row_slice = slice(None)
    col_slice = slice(None)

    def __init__(self, path='./data/train', data_file='noisy', label_file='ground'):
        self.path = path
        self.data_file = data_file
        self.label_file = label_file
        self.file_list = os.listdir(os.path.join(path, data_file))

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _pack_bayer(raw_data):
        """Pack a Bayer mosaic of shape (H, W) into (H // 2, W // 2, 4).

        Channel order: (even row, even col), (even row, odd col),
        (odd row, even col), (odd row, odd col).  The result is a new array,
        so it does not alias ``raw_data``.
        """
        return np.stack((raw_data[0::2, 0::2],
                         raw_data[0::2, 1::2],
                         raw_data[1::2, 0::2],
                         raw_data[1::2, 1::2]), axis=2)

    def read_image(self, input_path):
        """Read a DNG file and return ``(packed, height, width)``.

        ``packed`` is the Bayer-packed visible raw area (see ``_pack_bayer``);
        ``height`` and ``width`` are the unpacked raw dimensions.  The RawPy
        handle is closed before returning instead of being leaked.
        """
        with rawpy.imread(input_path) as raw:
            raw_data = raw.raw_image_visible
            height, width = raw_data.shape[0], raw_data.shape[1]
            # raw_image_visible is only valid while the handle is open;
            # _pack_bayer copies the pixels out before the handle is closed.
            packed = self._pack_bayer(raw_data)
        return packed, height, width

    def normalization(self, input_data, black_level=1024, white_level=16383):
        """Linearly map raw values so black_level -> 0.0 and white_level -> 1.0.

        Returns a float64 array; values are not clipped, so pixels below
        ``black_level`` become negative.
        """
        return (input_data.astype(float) - black_level) / (white_level - black_level)

    def _prepare(self, packed, height, width):
        """Crop, normalize and convert one packed image to a float32 tensor.

        ``height``/``width`` are the unpacked raw dimensions used to reshape
        ``packed`` to (1, height // 2, width // 2, 4).  Returns a tensor of
        shape (1, 4, crop_h, crop_w).  Cropping happens before normalization
        so only the selected window is converted to floating point; the
        values are identical to normalizing the full image first.
        """
        batched = packed.reshape(-1, height // 2, width // 2, 4)
        cropped = batched[:, self.row_slice, self.col_slice, :]
        normed = self.normalization(cropped)
        return torch.from_numpy(np.transpose(normed, (0, 3, 1, 2))).float()

    def __getitem__(self, index):
        """Return the ``(image, label)`` crops for the file at ``index``."""
        image_name = self.file_list[index]
        label_name = image_name.replace('noise', 'gt')
        image_path = os.path.join(self.path, self.data_file, image_name)
        label_path = os.path.join(self.path, self.label_file, label_name)
        image, height, width = self.read_image(image_path)
        label, _, _ = self.read_image(label_path)
        # The label is reshaped with the noisy image's dimensions, so a label
        # of a different size fails loudly in reshape.
        return self._prepare(image, height, width), self._prepare(label, height, width)


class facadeData1(_FacadeQuadrantData):
    """Crop rows 0:868, cols 0:1156 (top-left in packed coordinates)."""

    row_slice = slice(0, 868)
    col_slice = slice(0, 1156)


class facadeData2(_FacadeQuadrantData):
    """Crop rows 868:1736, cols 0:1156 (bottom-left in packed coordinates)."""

    row_slice = slice(868, 1736)
    col_slice = slice(0, 1156)


class facadeData3(_FacadeQuadrantData):
    """Crop rows 0:868, cols 1156:2312 (top-right in packed coordinates)."""

    row_slice = slice(0, 868)
    col_slice = slice(1156, 2312)


class facadeData4(_FacadeQuadrantData):
    """Crop rows 868:1736, cols 1156:2312 (bottom-right in packed coordinates)."""

    row_slice = slice(868, 1736)
    col_slice = slice(1156, 2312)
class Trainer(object):
    """Train ``net`` on ``train_data`` with Adam and a mean-squared-error loss.

    Args:
        net: model to train; it is moved to the module-level ``device``.
        lr: Adam learning rate.
        batch_size: DataLoader batch size; batches are reshuffled each epoch.
        num_epoch: number of passes over ``train_data``.
        train_data: dataset yielding ``(input, target)`` pairs.  Each sample is
            assumed to carry a leading singleton dimension, which is squeezed
            away after batching.
        save_dir: directory that receives checkpoints; created if missing.
        save_every: a checkpoint ``model{epoch}.pth`` is written whenever
            ``epoch % save_every == 0``.
    """

    def __init__(self, net, lr=1e-4, batch_size=10, num_epoch=1000, train_data=None,
                 save_dir='./models3', save_every=10):
        self.net = net.to(device)
        self.batch_size = batch_size
        self.lr = lr
        self.num_epoch = num_epoch
        self.train_data = train_data
        self.save_dir = save_dir
        self.save_every = save_every
        self.data_loader = Data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
        # reduction='mean' is the non-deprecated equivalent of
        # reduce=True, size_average=True.
        self.loss = torch.nn.MSELoss(reduction='mean').to(device)

    def train(self):
        """Run the full training loop, checkpointing periodically; returns None."""
        optim = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        # To resume from a checkpoint, load it here, e.g.:
        # self.net.load_state_dict(torch.load('./models2/model30.pth'))
        # Create the checkpoint directory up front so torch.save cannot fail
        # after a whole epoch of work.
        os.makedirs(self.save_dir, exist_ok=True)
        for epoch in range(self.num_epoch):
            for i, (bx, by) in enumerate(self.data_loader):
                # Remove the per-sample singleton dim added by the dataset.
                bx = bx.squeeze(dim=1).to(device)
                by = by.squeeze(dim=1).to(device)
                pre = self.net(bx)
                loss = self.loss(pre, by)  # compute the loss
                optim.zero_grad()
                loss.backward()
                optim.step()
                # .item() logs the scalar value rather than the full tensor repr.
                print('i', i, 'epoch', epoch, 'loss', loss.item())
            if epoch % self.save_every == 0:
                torch.save(self.net.state_dict(),
                           os.path.join(self.save_dir, f'model{epoch}.pth'))
        return None
if __name__ == '__main__':
    # One dataset per image quadrant.
    train_data1 = facadeData1()
    train_data2 = facadeData2()
    train_data3 = facadeData3()
    train_data4 = facadeData4()
    # Concatenate each quadrant exactly once so every region of every image
    # is trained on (train_data4 was previously dropped and train_data3
    # duplicated).
    train_data = torch.utils.data.ConcatDataset(
        [train_data1, train_data2, train_data3, train_data4])
    net = Unet()
    trainer = Trainer(net=net, train_data=train_data)
    trainer.train()
# 欢迎分享，转载请注明来源：内存溢出
# (Reposting is welcome; please credit the source, 内存溢出.)
# 评论列表(0条)