深度学习环境配置结束后第一个程序运行问题记录

深度学习环境配置结束后第一个程序运行问题记录,第1张

深度学习环境配置结束后第一个程序运行问题记录

深度学习环境:

CUDA11.2

CUDnn8.1

Anaconda4.10

python3.8

tensorflow2.6

keras2.6

一、GAN网络的运行
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import adam_v2

import matplotlib.pyplot as plt

import sys
import os
import numpy as np

class GAN():
    def __init__(self):
        # --------------------------------- #
        #   行28,列28,也就是mnist的shape
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # 28,28,1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # adam优化器
        optimizer = adam_v2.Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        self.generator = self.build_generator()
        
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # 在训练generate的时候不训练discriminator
        self.discriminator.trainable = False
        # 对生成的假图片进行预测
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


    def build_generator(self):
        # --------------------------------- #
        #   生成器,输入一串随机数字
        # --------------------------------- #
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        # ----------------------------------- #
        #   评价器,对输入进来的图片进行评价
        # ----------------------------------- #
        model = Sequential()
        # 输入一张图片
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # 判断真伪
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        # 获得数据
        (X_train, _), (_, _) = mnist.load_data()

        # 进行标准化
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # 创建标签
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # --------------------------- #
            #   随机选取batch_size个图片
            #   对discriminator进行训练
            # --------------------------- #
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            gen_imgs = self.generator.predict(noise)

            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # --------------------------- #
            #  训练generator
            # --------------------------- #
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=30000, batch_size=256, sample_interval=200)

代码来自Bilibili up主Bubbliiiing提供的github仓库下载的。送你过去:Keras 搭建自己的GAN生成对抗网络平台(Bubbliiiing 深度学习 教程)_哔哩哔哩_bilibili

在problems中会有一些问题提示:

①库没有安装全,添加库就可以解决;

②导入提示错误,可能在不同的Tensorflow和Keras版本之间,有一些差别,比如:

from keras.optimizers import adam_v2

optimizer = adam_v2.Adam(0.0002, 0.5)

from keras.optimizers import Adam

 类似这种问题在网上能够好找到解决方法。

二、运行CycleGAN

       CycleGAN的代码也是在  Bubbliiiing  博主的github里面找到的,博主的B站,CSDN,github是一个名字,很好找的。

1、tensorflow显存不足

        因为CycleGAN网络中包含两个生成器,两个鉴别器,所以网络参数比较多。

 解决方法请查看:tensorflow出现显存不足的提示_楠仔码头的博客-CSDN博客

该博客中提到的方法:①减少batchsize,即减少了GPU内存分配需求

                                    ②如果内存比较大的话,可以切换CPU

                                    ③重置输入图片尺寸,即通过减小图片的大小来减少对显存的消耗

我尝试了第三种,把输入图片尺寸改成了64*64*3,问题能够解决,而且也因为代码中的batchsize已经是1了,没办法再小了。

2、Anaconda中Qt5的路径问题

 解决方法请查看:python出现This application failed to stat could not find or load the Qt platform plugin "windows"_昆兰.沃斯 的博客-CSDN博客

        博客中指出需要添加环境变量,在用户变量中新建:

 ​​​​

 输入变量名QT_QPA_PLATFORM_PLUGIN_PATH

 变量值输入Anaconda安装目录下的.......plugins的路径,按照上面的路径应该能找到自己电脑上的plugins。我的就是按照上面博客中的路径找到的。

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/3971896.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-10-21
下一篇 2022-10-21

发表评论

登录后才能评论

评论列表(0条)

保存