参加jittor比赛,热身赛中使用的GAN模型,想起自己还没真正使用过GAN,希望通过这个机会学习下
引入
想要真正地理解世界,就应该能够生成世界的种种组成。因此出现了生成模型;
生成模型是说,我们随机地生成一些图片(以图片任务为例),使得这些图片能够尽可能地描绘真实世界;
然而这个标准很难量化去衡量(怎么样才算真实?),因此提出生成对抗网络GAN,同时设计生成器和判别器两个部分,生成器的任务是努力生成能够以假乱真的图片,而判别器的任务是尽可能区分生成的图片和真实图片;这样评判标准就清晰了,即是真实的训练数据还是生成的数据,同时,生成器和判别器两部分的博弈使得整个系统稳步运行;
然而仅仅生成一些随机的真实图片可能没办法满足我们的需求,我们可能还需要生成有条件的真实图片,比如,我想生成猫的真实图像,那么条件就是猫,则提出Conditional GAN;
这些条件可以是类别标签(上面猫的例子),也可以是用于图像修复的部分数据(图像缺失,那么保留下的那部分数据就可以作为条件),甚至是其他模态的数据(给定一个文本查询,生成文本相关的图片,那么文本查询就可以是条件)
方法整个网络可以通过一个公式来说明,在没有condition的情况下(也就是普通的生成网络),公式是这样的(其中
E
E
E代表期望):
m
i
n
G
m
a
x
D
V
(
D
,
G
)
=
E
x
∼
p
d
a
t
a
(
x
)
[
l
o
g
D
(
x
)
]
+
E
z
∼
p
z
(
z
)
[
l
o
g
(
1
−
D
(
G
(
z
)
)
)
]
min_Gmax_DV(D,G)=E_{x \sim p_{data}(x)}[logD(x)]+E_{z \sim p_z(z)}[log(1-D(G(z)))]
minGmaxDV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))],记为公式 (1)
加上条件之后,公式变为:
m
i
n
G
m
a
x
D
V
(
D
,
G
)
=
E
x
∼
p
d
a
t
a
(
x
∣
y
)
[
l
o
g
D
(
x
)
]
+
E
z
∼
p
z
(
z
)
[
l
o
g
(
1
−
D
(
G
(
z
∣
y
)
)
)
]
min_Gmax_DV(D,G)=E_{x \sim p_{data}(x|y)}[logD(x)]+E_{z \sim p_z(z)}[log(1-D(G(z|y)))]
minGmaxDV(D,G)=Ex∼pdata(x∣y)[logD(x)]+Ez∼pz(z)[log(1−D(G(z∣y)))],记为公式 (2)
对这两个公式的理解(只有条件部分不同,就放在一起说了),可以参考李宏毅的视频对抗生成网络(GAN),这里简单总结一下:
这里的两个目标
m
i
n
G
min_G
minG和
m
a
x
D
max_D
maxD,可以看作是fix
D
/
G
D/G
D/G的参数,更新
G
/
D
G/D
G/D的参数,比如
m
a
x
D
max_D
maxD就代表固定
G
G
G,更新
D
D
D,那么公式对应的详细算法就如下所示:
- 学习判别器 D D D的时候,固定生成器 G G G,先利用生成器 G G G对噪声生成 x ~ \tilde x x~,目标就变为: m a x D V = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( x ~ ) ) ] max_DV=E_{x \sim p_{data}(x)}[logD(x)]+E_{z \sim p_z(z)}[log(1-D(\tilde x))] maxDV=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(x~))],那么需要完成以下任务:
- 判别器 D D D执行假图判别任务, D ( G ( z ) ) D(G(z)) D(G(z)):
向量 | 含义 | 维度 | |
---|---|---|---|
输入 | gen_imgs | 假图 | b s z , 32 , 32 bsz, 32, 32 bsz,32,32 |
gen_labels | 假图的label | b s z , 1 bsz, 1 bsz,1 | |
输出 | validity_fake | 假图的真实性 | b s z , 1 bsz, 1 bsz,1 |
目标 | m a x l o g ( 1 − D ( x ~ ) ) maxlog(1-D(\tilde x)) maxlog(1−D(x~)) | 尽可能将假图判定为假 | 等价于 m i n D ( x ~ ) minD(\tilde x) minD(x~) |
- 判别器 D D D执行真图判别任务, D ( x ) D(x) D(x):
向量 | 含义 | 维度 | |
---|---|---|---|
输入 | real_imgs | 真图 | b s z , 32 , 32 bsz, 32, 32 bsz,32,32 |
labels | 真图的label | b s z , 1 bsz, 1 bsz,1 | |
输出 | validity_real | 真图的真实性 | b s z , 1 bsz, 1 bsz,1 |
目标 | m a x l o g D ( x ) maxlogD(x) maxlogD(x) | 尽可能将真图判定为真 | 等价于 m a x D ( x ) maxD(x) maxD(x) |
- 学习生成器 G G G的时候,固定判别器 D D D,目标就变为: m i n G V = E z ∼ p z ( z ) [ l o g ( 1 − D ( x ~ ) ) ] min_GV=E_{z \sim p_z(z)}[log(1-D(\tilde x))] minGV=Ez∼pz(z)[log(1−D(x~))],也就是等价于 m a x G V = E z ∼ p z ( z ) [ l o g D ( x ~ ) ] max_GV=E_{z \sim p_z(z)}[logD(\tilde x)] maxGV=Ez∼pz(z)[logD(x~)](后者的max是上面那张算法图中的写法,我们还是跟随公式使用前者min,因为当时因为写成max导致好久没看懂,无语…),那么需要完成以下任务:
- 生成器 G G G执行假图生成任务, G ( z ) G(z) G(z),然后判别器 D D D执行假图判断任务, D ( G ( z ) ) D(G(z)) D(G(z)):
生成器 | 向量 | 含义 | 维度 |
---|---|---|---|
输入 | z z z | 随机向量 | b s z , 100 bsz, 100 bsz,100 |
gen_labels | 要生成假图的label,相当于condition | b s z , 1 bsz, 1 bsz,1 | |
输出 | gen_imgs | 生成的假图 | b s z , 32 , 32 bsz, 32, 32 bsz,32,32 |
判别器 | 向量 | 含义 | 维度 |
输入 | gen_imgs | 生成的假图 | b s z , 32 , 32 bsz, 32, 32 bsz,32,32 |
gen_labels | 假图的label | b s z , 1 bsz, 1 bsz,1 | |
输出 | validity | 假图的真实性 | b s z , 1 bsz, 1 bsz,1 |
目标 | m i n l o g ( 1 − D ( G ( z ) ) ) minlog(1-D(G(z))) minlog(1−D(G(z))) | 尽可能令生成器生成能够以假乱真的图片,从而让判别器将生成的图片判别为真实 | 等价于 m a x D ( G ( z ) ) maxD(G(z)) maxD(G(z)) |
import jittor as jt
from jittor import init
import argparse
import os
import numpy as np
import math
from jittor import nn
if jt.has_cuda:
jt.flags.use_cuda = 1
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes for dataset')
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
# sample_interval 暂不知用途
parser.add_argument('--sample_interval', type=int, default=1000, help='interval between image sampling')
opt = parser.parse_args()
print(opt)
# img_shape: 1, 32, 32
img_shape = (opt.channels, opt.img_size, opt.img_size)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# nn.Embedding创建了一个可查询的embedding字典
# 参数1是num,指embedding字典的大小,参数2是dim,embedding变量的大小
# 也就是说会创建一个10*10的查询表
self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
# nn.Linear(in_dim, out_dim)表示全连接层
# in_dim:输入向量维度
# out_dim:输出向量维度
# block定义了一个层,fc+bn+LeakyReLU,以list的形式存在
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
# nn.LeakyReLU代表泄露的ReLU(负值部分泄露,0.2表示泄露的斜率)
layers.append(nn.LeakyReLU(0.2))
return layers
# *用在实参前,相当于对tuple的解构(对应的**是对dict的解构),长度需要跟函数需要的参数对应
# *用在形参前,可以表示一个可变长度的tuple(对应的**表示可变长度的dict,同时有*和**时,*在前,**在后)
# 这里是指把block作为一个tuple传递给nn.Sequential函数
# nn.Sequential需要的参数为*args,也就是说可以传递任意长度的参数
# block(100+10, 128, normalize=False)
# model的维度不断增加,看来应该是个生成器(reshape的 *** 作在后面)
# nn.Linear(1024, 1*32*32=1024)
self.model = nn.Sequential(*block((opt.latent_dim + opt.n_classes), 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh())
# 相当于pytorch中的forward函数,表示网络向量的传递过程
def execute(self, noise, labels):
# label是条件,noise是生成的随机变量
# label: , noise:
gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
img = self.model(gen_input)
# 将img从1024维向量变为32*32矩阵
img = img.view((img.shape[0], *img_shape))
return img
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
# nn.Linear(10+1*32*32=1034, 512)
self.model = nn.Sequential(nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
# TODO: 添加最后一个线性层,最终输出为一个实数
nn.Linear(512, 1)
)
def execute(self, img, labels):
# img: bsz, 1*32*32=1024; label: bsz, 10
d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
# TODO: 将d_in输入到模型中并返回计算结果
validity = self.model(d_in)
return validity
# 损失函数:平方误差
# 调用方法:adversarial_loss(网络输出A, 分类标签B)
# 计算结果:(A-B)^2
adversarial_loss = nn.MSELoss()
generator = Generator()
discriminator = Discriminator()
# 导入MNIST数据集
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
# 以链式的方式组合多个transform
transform = transform.Compose([
transform.Resize(opt.img_size),
# transform.Gray()只在jittor中有的一个函数
# 将任意形式的PIL图像(RGB, HSV, LAB等)转化为灰度图
# 应该有参数啊,为什么没给参数?
transform.Gray(),
# transform.ImageNormalize只在jittor中有的一个函数
transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
# set_attrs只在jittor中有的一个函数
# 顾名思义,为数据集设置一些属性
# 这个函数的确很方便,方便取用一些重要的参数
dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
from PIL import Image
def save_image(img, path, nrow=10, padding=5):
N,C,W,H = img.shape
if (N%nrow!=0):
print("N%nrow!=0")
return
# 共N个图片,把他们分成nrow行ncol列
ncol=int(N/nrow)
img_all = []
for i in range(ncol):
# 每一列都建立一个新的img_ list
img_ = []
for j in range(nrow):
# 将第i*nrow+j个图片存入,其维度为C, W, H
img_.append(img[i*nrow+j])
# 生成一个C, W, padding的0矩阵加入(干嘛用的,猜想是用来放label的,所以padding应该为10)
img_.append(np.zeros((C,W,padding)))
# np.concatenate可以一次完成多个数组的拼接
# 对img_ 这个list在第2维度上进行拼接
# 则生成C, W, (H+padding)*nrow+的矩阵,加入img_all list
img_all.append(np.concatenate(img_, 2))
# 生成C, padding, (H+padding)*nrow的矩阵,也加入img_all list(干嘛用的)
img_all.append(np.zeros((C,padding,img_all[0].shape[2])))
# 在img_all的第1维度上拼接
# img: C, (W+padding)*ncol, (H+padding)*nrow
img = np.concatenate(img_all, 1)
# 生成C, padding, (H+padding)*nrow的0矩阵
# 再与img在第1维度上拼接,得到C, (W+padding)*ncol+padding, (H+padding)*nrow
img = np.concatenate([np.zeros((C,padding,img.shape[2])), img], 1)
# 生成C, (W+padding)*ncol+padding, padding的0矩阵
# 再与img在第2维度上拼接,得到C, (W+padding)*ncol+padding, (H+padding)*nrow+padding
img = np.concatenate([np.zeros((C,img.shape[1],padding)), img], 2)
min_=img.min()
max_=img.max()
# 标准化到0~255
img=(img-min_)/(max_-min_)*255
# (W+padding)*ncol+padding, (H+padding)*nrow+padding, C
img=img.transpose((1,2,0))
if C==3:
# 把C的维度倒序输出?why
img = img[:,:,::-1]
elif C==1:
img = img[:,:,0]
Image.fromarray(np.uint8(img)).save(path)
def sample_image(n_row, batches_done):
# 随机采样输入并保存生成的图片
# 从正态分布中生成随机变量,维度n_row^2, latent_dim=100
# 推测含义是生成nrow^2个向量,每个向量维度是100
# 所以为什么要停止梯度的计算?(又不是叶子节点)
z = jt.array(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))).float32().stop_grad()
# 生成一个array,nrow*nrow的长度
labels = jt.array(np.array([num for _ in range(n_row) for num in range(n_row)])).float32().stop_grad()
# 生成器生成图片32*32
gen_imgs = generator(z, labels)
save_image(gen_imgs.numpy(), "%d.png" % batches_done, nrow=n_row)
# ----------
# 模型训练
# ----------
for epoch in range(opt.n_epochs):
# imgs: 64, 1, 32, 32
# labels: 64
for i, (imgs, labels) in enumerate(dataloader):
batch_size = imgs.shape[0]
# 数据标签,valid=1表示真实的图片,fake=0表示生成的图片
# 表示有batch_size个真实图片和batch_size个生成图片
valid = jt.ones([batch_size, 1]).float32().stop_grad()
fake = jt.zeros([batch_size, 1]).float32().stop_grad()
# 真实图片及其类别
real_imgs = jt.array(imgs)
labels = jt.array(labels)
# -----------------
# 训练生成器
# -----------------
# 采样随机噪声和数字类别作为生成器输入
# z: batch_size, 100
z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
# gen_labels: batch_size(从0~10生成随机整数)
gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()
# 生成一组图片
# 放到生成器中的是生成的label,不是真正的label
# 生成器会返回64, 32, 32的图片
# 此处完成了公式(2)中的G(z|y),在生成标签的条件下,为噪声z生成图片
gen_imgs = generator(z, gen_labels)
# 损失函数衡量生成器欺骗判别器的能力,即希望判别器将生成图片分类为valid
# 判别器对生成的图片和生成的label进行判断,输出一个实数
# validity: 64, 1
# 此处是完成了公式(2)中的D(G(z|y)),获得了生成标签下生成图片的真实性validity
validity = discriminator(gen_imgs, gen_labels)
# 令生成的图片与真实的图片标签进行均方误差
# 此处是完成了公式(2)中的目标:maxD,也就是令D(G(z|y))逼近于1
g_loss = adversarial_loss(validity, valid)
g_loss.sync()
optimizer_G.step(g_loss)
# ---------------------
# 训练判别器
# ---------------------
# 判别器输出真实图片的实数
# validity_real: 64, 1
# 此处是完成公式(2)中的D(x|y),判断真实标签下真实图片的真实性
validity_real = discriminator(real_imgs, labels)
# d_real_loss = adversarial_loss("""TODO: 计算真实类别的损失函数""")
# 此处是完成公式(2)中的
d_real_loss = adversarial_loss(validity_real, valid)
validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
# d_fake_loss = adversarial_loss("""TODO: 计算虚假类别的损失函数""")
# 这里的目标是令对生成图片的预测努力接近生成标签
d_fake_loss = adversarial_loss(validity_fake, fake)
# 总的判别器损失
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.sync()
optimizer_D.step(d_loss)
if i % 50 == 0:
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data)
)
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
sample_image(n_row=10, batches_done=batches_done)
if epoch % 10 == 0:
generator.save("generator_last.pkl")
discriminator.save("discriminator_last.pkl")
generator.eval()
discriminator.eval()
generator.load('generator_last.pkl')
discriminator.load('discriminator_last.pkl')
number = 12312345678#TODO: 写入你注册时绑定的手机号(字符串类型)
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z,labels)
img_array = gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1))
min_=img_array.min()
max_=img_array.max()
img_array=(img_array-min_)/(max_-min_)*255
Image.fromarray(np.uint8(img_array)).save("result.png")
结束语
感觉GAN应该是深度学习中最难理解的网络了(我比较孤陋寡闻),不过最后最算是理解了,开心。姑且就把这篇文章定义为中级文章吧hh(本人的第一篇中级)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)