中兴捧月RAW图像去噪训练代码

中兴捧月RAW图像去噪训练代码,第1张

主要使用 PyTorch 进行训练。
每张图片打包为四通道后,被裁剪为 2×2 共四块(每块 868×1156),分别作为训练样本。
使用时只需要修改训练数据路径,并确保数据目录中只保留 dng 文件。

import torch
from  torch.utils.data import Dataset
import torch.utils.data as Data
import numpy as np
import os
import rawpy
from unetTorch import Unet
import torch.nn.functional as F
# Select the compute device once at import time: CUDA when PyTorch reports
# it as available, otherwise the CPU. The choice is echoed to stdout.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

#加载数据集
# Load the dataset
class facadeData1(Dataset):
    """Paired noisy / ground-truth RAW dataset yielding one fixed crop.

    Each sample is read from ``<path>/<data_file>/<name>`` and its label from
    ``<path>/<label_file>/<name with 'noise' replaced by 'gt'>``. Both frames
    are unpacked into four half-resolution planes, normalized, and cropped to
    rows 0:868 and columns 0:1156 of the packed planes.
    """

    # Crop window applied to the packed (half-resolution) planes.
    CROP_ROWS = slice(0, 868)
    CROP_COLS = slice(0, 1156)

    def __init__(self, path='./data/train', data_file='noisy', label_file='ground'):
        """Record the directory layout and list the input files.

        Args:
            path: Root directory of the training data.
            data_file: Sub-directory of ``path`` holding the noisy inputs.
            label_file: Sub-directory of ``path`` holding the ground truth.
        """
        self.path = path
        self.data_file = data_file
        self.label_file = label_file
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs.
        self.file_list = sorted(os.listdir(os.path.join(path, data_file)))

    def __len__(self):
        """Return the number of input files found at construction time."""
        return len(self.file_list)

    def read_image(self, input_path):
        """Read a RAW file and unpack every 2x2 cell into four channels.

        Args:
            input_path: Path of the RAW file to open with rawpy.

        Returns:
            A tuple ``(packed, height, width)`` where ``packed`` has shape
            ``(height // 2, width // 2, 4)`` and ``height`` / ``width`` are
            the dimensions of the visible RAW mosaic.
        """
        # The context manager closes the rawpy handle; the original code
        # never released it.
        with rawpy.imread(input_path) as raw:
            mosaic = raw.raw_image_visible
            height, width = mosaic.shape[0], mosaic.shape[1]
            # np.stack copies the data, so ``packed`` stays valid after the
            # underlying RAW buffer is released.
            packed = np.stack((mosaic[0:height:2, 0:width:2],
                               mosaic[0:height:2, 1:width:2],
                               mosaic[1:height:2, 0:width:2],
                               mosaic[1:height:2, 1:width:2]), axis=2)
        return packed, height, width

    def normalization(self, input_data, black_level=1024, white_level=16383):
        """Map ``black_level`` to 0.0 and ``white_level`` to 1.0 linearly."""
        output_data = (input_data.astype(float) - black_level) / (white_level - black_level)
        return output_data

    def _to_cropped_tensor(self, packed):
        """Normalize a packed (H, W, 4) array and return the (1, 4, h, w) crop."""
        normalized = self.normalization(packed)
        # (H, W, 4) -> (1, 4, H, W)
        planes = np.transpose(normalized, (2, 0, 1))[np.newaxis]
        tensor = torch.from_numpy(planes).float()
        return tensor[:, :, self.CROP_ROWS, self.CROP_COLS]

    def __getitem__(self, index):
        """Return the ``(image, label)`` float32 tensor pair for ``index``.

        Raises:
            ValueError: If the noisy frame and its label differ in shape.
        """
        image_name = self.file_list[index]
        label_name = image_name.replace('noise', 'gt')
        image_path = os.path.join(self.path, self.data_file, image_name)
        label_path = os.path.join(self.path, self.label_file, label_name)

        image, _, _ = self.read_image(image_path)
        label, _, _ = self.read_image(label_path)
        # The original reshaped the label with the image's dimensions; a
        # mismatched pair is reported explicitly instead.
        if image.shape != label.shape:
            raise ValueError(
                f'shape mismatch: {image_path} is {image.shape}, '
                f'{label_path} is {label.shape}'
            )

        return self._to_cropped_tensor(image), self._to_cropped_tensor(label)
class facadeData2(Dataset):
    """Paired noisy / ground-truth RAW dataset yielding one fixed crop.

    Each sample is read from ``<path>/<data_file>/<name>`` and its label from
    ``<path>/<label_file>/<name with 'noise' replaced by 'gt'>``. Both frames
    are unpacked into four half-resolution planes, normalized, and cropped to
    rows 868:1736 and columns 0:1156 of the packed planes.
    """

    # Crop window applied to the packed (half-resolution) planes.
    CROP_ROWS = slice(868, 1736)
    CROP_COLS = slice(0, 1156)

    def __init__(self, path='./data/train', data_file='noisy', label_file='ground'):
        """Record the directory layout and list the input files.

        Args:
            path: Root directory of the training data.
            data_file: Sub-directory of ``path`` holding the noisy inputs.
            label_file: Sub-directory of ``path`` holding the ground truth.
        """
        self.path = path
        self.data_file = data_file
        self.label_file = label_file
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs.
        self.file_list = sorted(os.listdir(os.path.join(path, data_file)))

    def __len__(self):
        """Return the number of input files found at construction time."""
        return len(self.file_list)

    def read_image(self, input_path):
        """Read a RAW file and unpack every 2x2 cell into four channels.

        Args:
            input_path: Path of the RAW file to open with rawpy.

        Returns:
            A tuple ``(packed, height, width)`` where ``packed`` has shape
            ``(height // 2, width // 2, 4)`` and ``height`` / ``width`` are
            the dimensions of the visible RAW mosaic.
        """
        # The context manager closes the rawpy handle; the original code
        # never released it.
        with rawpy.imread(input_path) as raw:
            mosaic = raw.raw_image_visible
            height, width = mosaic.shape[0], mosaic.shape[1]
            # np.stack copies the data, so ``packed`` stays valid after the
            # underlying RAW buffer is released.
            packed = np.stack((mosaic[0:height:2, 0:width:2],
                               mosaic[0:height:2, 1:width:2],
                               mosaic[1:height:2, 0:width:2],
                               mosaic[1:height:2, 1:width:2]), axis=2)
        return packed, height, width

    def normalization(self, input_data, black_level=1024, white_level=16383):
        """Map ``black_level`` to 0.0 and ``white_level`` to 1.0 linearly."""
        output_data = (input_data.astype(float) - black_level) / (white_level - black_level)
        return output_data

    def _to_cropped_tensor(self, packed):
        """Normalize a packed (H, W, 4) array and return the (1, 4, h, w) crop."""
        normalized = self.normalization(packed)
        # (H, W, 4) -> (1, 4, H, W)
        planes = np.transpose(normalized, (2, 0, 1))[np.newaxis]
        tensor = torch.from_numpy(planes).float()
        return tensor[:, :, self.CROP_ROWS, self.CROP_COLS]

    def __getitem__(self, index):
        """Return the ``(image, label)`` float32 tensor pair for ``index``.

        Raises:
            ValueError: If the noisy frame and its label differ in shape.
        """
        image_name = self.file_list[index]
        label_name = image_name.replace('noise', 'gt')
        image_path = os.path.join(self.path, self.data_file, image_name)
        label_path = os.path.join(self.path, self.label_file, label_name)

        image, _, _ = self.read_image(image_path)
        label, _, _ = self.read_image(label_path)
        # The original reshaped the label with the image's dimensions; a
        # mismatched pair is reported explicitly instead.
        if image.shape != label.shape:
            raise ValueError(
                f'shape mismatch: {image_path} is {image.shape}, '
                f'{label_path} is {label.shape}'
            )

        return self._to_cropped_tensor(image), self._to_cropped_tensor(label)
class facadeData3(Dataset):
    """Paired noisy / ground-truth RAW dataset yielding one fixed crop.

    Each sample is read from ``<path>/<data_file>/<name>`` and its label from
    ``<path>/<label_file>/<name with 'noise' replaced by 'gt'>``. Both frames
    are unpacked into four half-resolution planes, normalized, and cropped to
    rows 0:868 and columns 1156:2312 of the packed planes.
    """

    # Crop window applied to the packed (half-resolution) planes.
    CROP_ROWS = slice(0, 868)
    CROP_COLS = slice(1156, 2312)

    def __init__(self, path='./data/train', data_file='noisy', label_file='ground'):
        """Record the directory layout and list the input files.

        Args:
            path: Root directory of the training data.
            data_file: Sub-directory of ``path`` holding the noisy inputs.
            label_file: Sub-directory of ``path`` holding the ground truth.
        """
        self.path = path
        self.data_file = data_file
        self.label_file = label_file
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs.
        self.file_list = sorted(os.listdir(os.path.join(path, data_file)))

    def __len__(self):
        """Return the number of input files found at construction time."""
        return len(self.file_list)

    def read_image(self, input_path):
        """Read a RAW file and unpack every 2x2 cell into four channels.

        Args:
            input_path: Path of the RAW file to open with rawpy.

        Returns:
            A tuple ``(packed, height, width)`` where ``packed`` has shape
            ``(height // 2, width // 2, 4)`` and ``height`` / ``width`` are
            the dimensions of the visible RAW mosaic.
        """
        # The context manager closes the rawpy handle; the original code
        # never released it.
        with rawpy.imread(input_path) as raw:
            mosaic = raw.raw_image_visible
            height, width = mosaic.shape[0], mosaic.shape[1]
            # np.stack copies the data, so ``packed`` stays valid after the
            # underlying RAW buffer is released.
            packed = np.stack((mosaic[0:height:2, 0:width:2],
                               mosaic[0:height:2, 1:width:2],
                               mosaic[1:height:2, 0:width:2],
                               mosaic[1:height:2, 1:width:2]), axis=2)
        return packed, height, width

    def normalization(self, input_data, black_level=1024, white_level=16383):
        """Map ``black_level`` to 0.0 and ``white_level`` to 1.0 linearly."""
        output_data = (input_data.astype(float) - black_level) / (white_level - black_level)
        return output_data

    def _to_cropped_tensor(self, packed):
        """Normalize a packed (H, W, 4) array and return the (1, 4, h, w) crop."""
        normalized = self.normalization(packed)
        # (H, W, 4) -> (1, 4, H, W)
        planes = np.transpose(normalized, (2, 0, 1))[np.newaxis]
        tensor = torch.from_numpy(planes).float()
        return tensor[:, :, self.CROP_ROWS, self.CROP_COLS]

    def __getitem__(self, index):
        """Return the ``(image, label)`` float32 tensor pair for ``index``.

        Raises:
            ValueError: If the noisy frame and its label differ in shape.
        """
        image_name = self.file_list[index]
        label_name = image_name.replace('noise', 'gt')
        image_path = os.path.join(self.path, self.data_file, image_name)
        label_path = os.path.join(self.path, self.label_file, label_name)

        image, _, _ = self.read_image(image_path)
        label, _, _ = self.read_image(label_path)
        # The original reshaped the label with the image's dimensions; a
        # mismatched pair is reported explicitly instead.
        if image.shape != label.shape:
            raise ValueError(
                f'shape mismatch: {image_path} is {image.shape}, '
                f'{label_path} is {label.shape}'
            )

        return self._to_cropped_tensor(image), self._to_cropped_tensor(label)
class facadeData4(Dataset):
    """Paired noisy / ground-truth RAW dataset yielding one fixed crop.

    Each sample is read from ``<path>/<data_file>/<name>`` and its label from
    ``<path>/<label_file>/<name with 'noise' replaced by 'gt'>``. Both frames
    are unpacked into four half-resolution planes, normalized, and cropped to
    rows 868:1736 and columns 1156:2312 of the packed planes.
    """

    # Crop window applied to the packed (half-resolution) planes.
    CROP_ROWS = slice(868, 1736)
    CROP_COLS = slice(1156, 2312)

    def __init__(self, path='./data/train', data_file='noisy', label_file='ground'):
        """Record the directory layout and list the input files.

        Args:
            path: Root directory of the training data.
            data_file: Sub-directory of ``path`` holding the noisy inputs.
            label_file: Sub-directory of ``path`` holding the ground truth.
        """
        self.path = path
        self.data_file = data_file
        self.label_file = label_file
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs.
        self.file_list = sorted(os.listdir(os.path.join(path, data_file)))

    def __len__(self):
        """Return the number of input files found at construction time."""
        return len(self.file_list)

    def read_image(self, input_path):
        """Read a RAW file and unpack every 2x2 cell into four channels.

        Args:
            input_path: Path of the RAW file to open with rawpy.

        Returns:
            A tuple ``(packed, height, width)`` where ``packed`` has shape
            ``(height // 2, width // 2, 4)`` and ``height`` / ``width`` are
            the dimensions of the visible RAW mosaic.
        """
        # The context manager closes the rawpy handle; the original code
        # never released it.
        with rawpy.imread(input_path) as raw:
            mosaic = raw.raw_image_visible
            height, width = mosaic.shape[0], mosaic.shape[1]
            # np.stack copies the data, so ``packed`` stays valid after the
            # underlying RAW buffer is released.
            packed = np.stack((mosaic[0:height:2, 0:width:2],
                               mosaic[0:height:2, 1:width:2],
                               mosaic[1:height:2, 0:width:2],
                               mosaic[1:height:2, 1:width:2]), axis=2)
        return packed, height, width

    def normalization(self, input_data, black_level=1024, white_level=16383):
        """Map ``black_level`` to 0.0 and ``white_level`` to 1.0 linearly."""
        output_data = (input_data.astype(float) - black_level) / (white_level - black_level)
        return output_data

    def _to_cropped_tensor(self, packed):
        """Normalize a packed (H, W, 4) array and return the (1, 4, h, w) crop."""
        normalized = self.normalization(packed)
        # (H, W, 4) -> (1, 4, H, W)
        planes = np.transpose(normalized, (2, 0, 1))[np.newaxis]
        tensor = torch.from_numpy(planes).float()
        return tensor[:, :, self.CROP_ROWS, self.CROP_COLS]

    def __getitem__(self, index):
        """Return the ``(image, label)`` float32 tensor pair for ``index``.

        Raises:
            ValueError: If the noisy frame and its label differ in shape.
        """
        image_name = self.file_list[index]
        label_name = image_name.replace('noise', 'gt')
        image_path = os.path.join(self.path, self.data_file, image_name)
        label_path = os.path.join(self.path, self.label_file, label_name)

        image, _, _ = self.read_image(image_path)
        label, _, _ = self.read_image(label_path)
        # The original reshaped the label with the image's dimensions; a
        # mismatched pair is reported explicitly instead.
        if image.shape != label.shape:
            raise ValueError(
                f'shape mismatch: {image_path} is {image.shape}, '
                f'{label_path} is {label.shape}'
            )

        return self._to_cropped_tensor(image), self._to_cropped_tensor(label)
class Trainer(object):
    """Train a network with Adam and mean-squared-error loss.

    The network and loss are moved to the module-level ``device``. Batches
    are drawn from ``train_data`` through a shuffling ``DataLoader``.
    """

    def __init__(self, net, lr=1e-4, batch_size=10, num_epoch=1000, train_data=None):
        """Store hyper-parameters and build the data loader and loss.

        Args:
            net: Model to train; it is moved to ``device``.
            lr: Learning rate passed to Adam.
            batch_size: Number of samples per batch.
            num_epoch: Number of passes over ``train_data``.
            train_data: Dataset yielding ``(input, target)`` pairs.
        """
        self.net = net.to(device)
        self.batch_size = batch_size
        self.lr = lr
        self.num_epoch = num_epoch
        self.train_data = train_data
        self.data_loader = Data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
        # reduction='mean' is the current spelling of the deprecated
        # reduce=True, size_average=True pair and computes the same value.
        self.loss = torch.nn.MSELoss(reduction='mean').to(device)

    def train(self):
        """Run the training loop, saving a checkpoint every 10th epoch.

        Checkpoints are written to ``./models3/model<epoch>.pth``.
        """
        optim = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        # To resume from a checkpoint, load it here, e.g.:
        #   self.net.load_state_dict(torch.load('./models2/model30.pth'))

        # Create the checkpoint directory up front so the first save cannot
        # fail after a full epoch of training.
        os.makedirs('./models3', exist_ok=True)

        for epoch in range(self.num_epoch):
            for i, (bx, by) in enumerate(self.data_loader):
                # Drop the size-1 axis at dim 1 of each batch before it is
                # fed to the network.
                bx = bx.squeeze(dim=1).to(device)
                by = by.squeeze(dim=1).to(device)
                pre = self.net(bx)
                loss = self.loss(pre, by)  # compute the loss
                optim.zero_grad()
                loss.backward()
                optim.step()
                # .item() logs the plain number instead of the tensor repr.
                print('i', i, 'epoch', epoch, 'loss', loss.item())
            if epoch % 10 == 0:
                torch.save(self.net.state_dict(), f'./models3/model{epoch}.pth')

        return None

if __name__ == '__main__':
    # One dataset per crop region; concatenating them makes every crop of
    # every file a training sample. The original listed the third dataset
    # twice and never used the fourth.
    train_data1 = facadeData1()
    train_data2 = facadeData2()
    train_data3 = facadeData3()
    train_data4 = facadeData4()
    train_data = torch.utils.data.ConcatDataset(
        [train_data1, train_data2, train_data3, train_data4]
    )

    net = Unet()
    trainer = Trainer(net=net, train_data=train_data)
    trainer.train()

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/790598.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-05
下一篇 2022-05-05

发表评论

登录后才能评论

评论列表(0条)

保存