1. 思路2. 流程3. 代码
1. 思路整个模型的要实现的模型是,我们有一个内容的图片和风格的图片,我们通过迁移学习得到一张新的图片。
读取内容图像和风格图像预处理和后处理抽取图像特征定义损失函数初始化合成图像训练模型 3. 代码
# -*- coding: utf-8 -*- # @Project: zc # @Author: ZhangChu # @File name: style_test # @Create time: 2022/1/7 20:10 # 1. 导入相关数据库 import matplotlib.pyplot as plt import torch import torchvision from torch import nn from d2l import torch as d2l import os # 2. 设置图片 d2l.set_figsize() # path = os.path.join(os.getcwd(), 'img', 'banan.jpg') path = 'D:/zc/img/rainier.jpg' content_img = d2l.Image.open(path) # d2l.plt.imshow(content_img) # plt.show() style_img = d2l.Image.open('D:/zc/img/autumn-oak.jpg') # d2l.plt.imshow(style_img) # plt.show() # 3. 张量均一化处理,rgb_mean :均值,rgb_std:方差 rgb_mean = torch.tensor([0.485, 0.456, 0.406]) rgb_std = torch.tensor([0.229, 0.224, 0.225]) # 4. 预处理,Compose组合 def preprocess(img, image_shape): transforms = torchvision.transforms.Compose([ torchvision.transforms.Resize(image_shape), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)]) return transforms(img).unsqueeze(0) # 5. 后处理,将张量的值进行反归一化处理 def postprocess(img): # 将图片放到 GPU 上 img = img[0].to(rgb_std.device) # 将张量进行反归一化处理,torch.clamp限制张量的值在[0,1]之间 img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) # 将张量变成PILImage图片 return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1)) # 6.将vgg19模型和对应的权重下载作为预处理模型 pretrained_net = torchvision.models.vgg19(pretrained=True) # 7.特征提取 # 选择每个卷积块的第一个卷积层作为风格层style_layers # 选择第四个卷积块的最后一个卷积层作为内容层 content_layers style_layers, content_layers = [0, 5, 10, 19, 28], [25] # 8.从vgg19里面筛选出对应需要的层,具体为我们选择的层 net = nn.Sequential(*[pretrained_net.features[i] for i in range(max(content_layers + style_layers) + 1)]) # 9. 抽取特征,返回内容层和风格层 def extract_features(X, content_layers, style_layers): contents = [] styles = [] for i in range(len(net)): X = net[i](X) if i in style_layers: styles.append(X) if i in content_layers: contents.append(X) return contents, styles # 10.对内容图像抽取内容特征 def get_contents(image_shape, device): """ :param image_shape: 图像的大小 :param device: GPU :return: """ # 将内容图片进行预处理 content_X = preprocess(content_img, image_shape).to(device) # 抽取特征 content_Y, _ = extract_features(content_X, content_layers, style_layers) return content_X, content_Y # 11. 获取风格 def get_styles(image_shape, device): # 对风格图片进行预处理 style_X = preprocess(style_img, image_shape).to(device) # 对风格图片进行特征抽取 _, styles_Y = extract_features(style_X, content_layers, style_layers) return style_X, styles_Y # 12. 均方误差来计算得到内容损失 def content_loss(Y_hat, Y): return torch.square(Y_hat - Y.detach()).mean() # 13.向量克拉姆矩阵, # 主要用来衡量合成图像与风格图像在风格上的差异 # 通过计算向量X_i,X_j的内积,来判断通道i,通道j上风格特征的相关性 def gram(X): num_channels, n = X.shape[1], X.numel() // X.shape[1] X = X.reshape((num_channels, n)) return torch.matmul(X, X.T) / (num_channels * n) # 14.风格损失函数 # 将计算得到的Y_hat 风格[通过gram函数得到]与标签的风格gram_Y求均方误差得到损失 def style_loss(Y_hat, gram_Y): return torch.square(gram(Y_hat) - gram_Y.detach()).mean() # 15.全变分损失,一种常见的去噪方法 def tv_loss(Y_hat): return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() + torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean()) content_weight, style_weight, tv_weight = 1, 1e3, 10 # 16. 分别计算内容损失,风格损失和全变分损失; # 通过调节这些权重超参数,我们可以权衡合成对象在保留内容, # 迁移风格 以及去噪三个方面的相对重要性 def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram): """ :param X: 输入的图片张量 X :param contents_Y_hat: 通过net(x)得到的内容Y_hat :param styles_Y_hat: 通过net(x)得到的风格 Y_hat :param contents_Y: 标签 Y :param styles_Y_gram: 风格标签 Y :return: contents_l : 内容损失; styles_l : 风格损失 tv_l :全变分损失 l :总损失 """ # 计算内容损失 contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip( contents_Y_hat, contents_Y)] # 计算风格损失 styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip( styles_Y_hat, styles_Y_gram)] # 计算全变分损失来去噪 tv_l = tv_loss(X) * tv_weight # 根据一定的比例将三种损失求和得到一个总的损失 l = sum(10 * styles_l + contents_l + [tv_l]) return contents_l, styles_l, tv_l, l # 17. 初始化合成图像,因为权重需要更新,故设置为nn.Parameter类型 class SynthesizedImage(nn.Module): def __init__(self, img_shape, **kwargs): super(SynthesizedImage, self).__init__(**kwargs) self.weight = nn.Parameter(torch.rand(*img_shape)) def forward(self): return self.weight # 18.创建合成图像类的一个模型实例,并将其初始化为图像X def get_inits(X, device, lr, styles_Y): """ :param X: 输入矩阵 X :param device: GPU :param lr: 学习率 :param styles_Y: 风格 Y :return: """ # 实例化一个合成图像类的一个模型实例 gen_img = SynthesizedImage(X.shape).to(device) # 通过 X.data 初始化实例 gen_img 的权重值 gen_img.weight.data.copy_(X.data) # 定义优化器Adam ,需要更新的参数为实例的 gen_img # 设置学习率 lr trainer = torch.optim.Adam(gen_img.parameters(), lr=lr) styles_Y_gram = [gram(Y) for Y in styles_Y] return gen_img(), styles_Y_gram, trainer # 19. 开始训练模型 def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch): """ :param X: 输入图像X的张量 :param contents_Y: 内容_Y :param styles_Y: 风格_Y :param device: GPU :param lr: 学习率 :param num_epochs: 训练迭代的次数 :param lr_decay_epoch: 学习率衰减次数 :return: """ # 初始化一个合成实例X,styles_Y_gram :风格Y矩阵,trainer:优化器 X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y) # 调整学习率 scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8) animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[10, num_epochs], legend=['content', 'style', 'TV'], ncols=2, figsize=(7, 2.5)) # 开始迭代 for epoch in range(num_epochs): # 优化器梯度清零 trainer.zero_grad() # 抽取特征,获得内容 contents_Y_hat, styles_Y_hat = extract_features( X, content_layers, style_layers) # 计算相关损失 contents_l, styles_l, tv_l, l = compute_loss( X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram) # 总损失 l 的梯度回传 l.backward() # 优化器梯度更新 trainer.step() # 学习率衰减优化 scheduler.step() if (epoch + 1) % 10 == 0: animator.axes[1].imshow(postprocess(X)) animator.add(epoch + 1, [float(sum(contents_l)), float(sum(styles_l)), float(tv_l)]) return X # 20. 设置 device = GPU,设置图片的大小 device, image_shape = d2l.try_gpu(), (300, 450) # 21. 将神经网络放到 GPU上 net = net.to(device) # 22.根据图片得到内容特征_X,内容特征_Y content_X, content_Y = get_contents(image_shape, device) # 23.抽取特征风格_Y _, styles_Y = get_styles(image_shape, device) # 24.开始训练模型 output = train(content_X.contens_Y, styles_Y, device, 0.3, 500, 500)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)