基于Keras实现CGAN用于手写数字生成

"""Conditional GAN (CGAN) implementation in Keras,
trained on the MNIST handwritten-digit dataset."""
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]="2"
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Activation,BatchNormalization,Concatenate,Dense,Embedding,Flatten,Input,Multiply,Reshape
from tensorflow.keras.layers import Conv2D,Conv2DTranspose
from tensorflow.keras.models import Model,Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np

# Image dimensions (MNIST: 28x28 grayscale)
img_rows=28
img_cols=28
channels=1
img_shape=(img_rows,img_cols,channels)

# Dimensionality of the noise (latent) vector fed to the generator
z_dim=100
# Number of class labels (digits 0-9)
num_classes=10


#生成器的构建
def build_generator(z_dim):
    """Build the core (unconditional) generator network.

    Maps a z_dim-dimensional latent vector to a 28x28x1 image whose
    pixel values lie in [-1, 1] (tanh output), via a dense projection
    followed by transposed convolutions.
    """
    return Sequential([
        # Project the latent vector and reshape into a 7x7x256 feature map.
        Dense(256 * 7 * 7, input_dim=z_dim),
        Reshape((7, 7, 256)),
        # Upsample 7x7 -> 14x14.
        Conv2DTranspose(128, kernel_size=3, strides=2, padding="same"),
        BatchNormalization(),
        Activation("relu"),
        # Refine at 14x14 (stride 1, no upsampling).
        Conv2DTranspose(64, kernel_size=3, strides=1, padding="same"),
        BatchNormalization(),
        Activation("relu"),
        # Upsample 14x14 -> 28x28; tanh matches images rescaled to [-1, 1].
        Conv2DTranspose(1, kernel_size=3, strides=2, padding="same"),
        Activation("tanh"),
    ])

#构建条件GAN生成器
def build_cgan_generator(z_dim):
    """Build the conditional generator: (noise, label) -> image.

    The integer label is embedded into a z_dim-dimensional vector and
    fused with the noise by element-wise multiplication before being fed
    to the core generator.
    """
    noise_in = Input(shape=(z_dim,))
    label_in = Input(shape=(1,), dtype="int32")

    # Embed the class label into the same space as the noise vector.
    embedded = Flatten()(Embedding(num_classes, z_dim, input_length=1)(label_in))

    # Condition the noise on the label via element-wise product.
    conditioned = Multiply()([noise_in, embedded])

    image_out = build_generator(z_dim)(conditioned)
    return Model([noise_in, label_in], image_out)


#鉴别器的构建
def build_discriminator(img_shape):
    """Build the core discriminator network.

    Input is the image concatenated channel-wise with the label embedding
    plane, hence `channels + 1` input channels. Outputs a single sigmoid
    real/fake probability.
    """
    in_channels = img_shape[2] + 1  # image channels + one label-embedding plane
    return Sequential([
        Conv2D(64, kernel_size=3, strides=2,
               input_shape=(img_shape[0], img_shape[1], in_channels),
               padding="same"),
        Activation("relu"),
        Conv2D(64, kernel_size=3, strides=2, padding="same"),
        BatchNormalization(),
        Activation("relu"),
        Conv2D(128, kernel_size=3, padding="same"),
        BatchNormalization(),
        Activation("relu"),
        Flatten(),
        Dense(1, activation="sigmoid"),
    ])


def build_cgan_discriminator(img_shape):
    """Build the conditional discriminator: (image, label) -> real/fake score.

    The label is embedded into a vector of size prod(img_shape), reshaped
    into an image-sized plane, and concatenated with the input image along
    the channel axis before classification.
    """
    image_in = Input(shape=img_shape)
    label_in = Input(shape=(1,), dtype="int32")

    # Turn the integer label into a full image-shaped conditioning plane.
    plane = Embedding(num_classes, np.prod(img_shape), input_length=1)(label_in)
    plane = Flatten()(plane)
    plane = Reshape(img_shape)(plane)

    # Stack image and label plane channel-wise.
    stacked = Concatenate(axis=-1)([image_in, plane])

    score = build_discriminator(img_shape)(stacked)
    return Model([image_in, label_in], score)


#整个GAN模型的搭建
def build_cgan(generator, discriminator):
    """Chain generator and discriminator into the combined model.

    This model is used only to train the generator: the discriminator is
    expected to be frozen (trainable=False) before compiling the result.
    """
    noise_in = Input(shape=(z_dim,))
    label_in = Input(shape=(1,))

    fake_image = generator([noise_in, label_in])
    validity = discriminator([fake_image, label_in])

    return Model([noise_in, label_in], validity)

# print(build_cgan(build_cgan_generator(z_dim),build_cgan_discriminator(img_shape)).summary())
#输出样本图像
def sample_images(img_grid_row=2, img_grid_col=5):
    """Generate and display a grid of label-conditioned samples.

    Uses the module-level `generator`. Labels cycle through
    0..num_classes-1, so the default 2x5 grid shows one sample per digit.

    Args:
        img_grid_row: number of rows in the plot grid.
        img_grid_col: number of columns in the plot grid.
    """
    n = img_grid_row * img_grid_col
    z = np.random.normal(0, 1, (n, z_dim))
    # Generalized from the hard-coded np.arange(0, 10): cycle labels so any
    # grid size works (default 2x5 still yields digits 0..9).
    labels = (np.arange(n) % num_classes).reshape(-1, 1)
    gen_images = generator.predict([z, labels])
    # Rescale from the generator's tanh range [-1, 1] to [0, 1] for imshow.
    gen_images = 0.5 * gen_images + 0.5
    fig, axs = plt.subplots(img_grid_row, img_grid_col, figsize=(10, 4),
                            sharex=True, sharey=True)
    cnt = 0
    for i in range(img_grid_row):
        for j in range(img_grid_col):
            axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap="gray")
            axs[i, j].axis("off")
            # Bug fix: set_title must be called on the per-cell Axes, not on
            # the ndarray of axes (`axs.set_title` raises AttributeError).
            axs[i, j].set_title("Digit: %d" % labels[cnt][0])
            cnt += 1
    plt.show()



# Build and compile the discriminator (trained directly on real/fake batches).
discriminator=build_cgan_discriminator(img_shape)
discriminator.compile(loss="binary_crossentropy",optimizer=Adam(),metrics=["acc"])

# Build the generator; freeze the discriminator before wiring it into the
# combined model so cgan.train_on_batch updates only generator weights.
generator=build_cgan_generator(z_dim)
discriminator.trainable=False
cgan=build_cgan(generator,discriminator)

cgan.compile(loss="binary_crossentropy",optimizer=Adam())


# Model training
# History buffers, appended to every `sample_interval` iterations by train().
accuracies=[]
losses=[]

def train(iterations, batch_size, sample_interval):
    """Adversarial training loop for the CGAN.

    Args:
        iterations: total number of training steps.
        batch_size: minibatch size for both networks.
        sample_interval: log losses and plot generated samples every this
            many iterations.
    """
    (X_train, y_train), (_, _) = mnist.load_data()
    # Rescale pixels from [0, 255] to [-1, 1] to match the generator's tanh
    # output, and add the channel axis: (N, 28, 28) -> (N, 28, 28, 1).
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)

    # Ground-truth targets for the discriminator.
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for iteration in range(iterations):
        # ---- Train the discriminator ----
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        # Reshape labels to a column vector to match the model's (1,) Input.
        imgs, labels = X_train[idx], y_train[idx].reshape(-1, 1)

        z = np.random.normal(0, 1, (batch_size, z_dim))
        # Bug fix: use predict() to obtain a numpy batch instead of calling
        # the model object directly on numpy inputs.
        gen_imgs = generator.predict([z, labels])

        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch([imgs, labels], real)
        d_loss_fake = discriminator.train_on_batch([gen_imgs, labels], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        discriminator.trainable = False

        # ---- Train the generator (discriminator frozen inside `cgan`) ----
        z = np.random.normal(0, 1, (batch_size, z_dim))
        labels = np.random.randint(0, num_classes, (batch_size)).reshape(-1, 1)
        # Bug fix: the generator is trained to make the discriminator output
        # "real" (1) for its fakes — training against `fake` inverts the loss.
        g_loss = cgan.train_on_batch([z, labels], real)

        # Bug fix: test the loop variable `iteration`, not the constant
        # `iterations` — the original condition never fired with
        # iterations=12000 and sample_interval=1000.
        if (iteration + 1) % sample_interval == 0:
            print("%d [D loss: %f,acc.: %.2f] [G loss: %f]" % (iteration + 1, d_loss[0], d_loss[1] * 100, g_loss))
            losses.append((d_loss[0], g_loss))
            accuracies.append(d_loss[1] * 100)
            sample_images()
# Hyperparameters for the training run.
iterations=12000
batch_size=32
sample_interval=1000

train(iterations,batch_size,sample_interval)

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/langs/570454.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存