pytorch实现 wgan

pytorch实现 wgan,第1张

在网上找到了一份 WGAN 的实现代码,在本地跑了一下,效果还可以。我把它封装成了一个函数,感兴趣的朋友可以直接使用。

不过这个 GAN 生成的是一维数据:每个样本是一个特征向量,输入 data 应为形状为 (样本数, 特征数) 的二维数组。如果要用于图片数据,需要对网络结构等代码做相应修改。

import os
import sys
import warnings

import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils import shuffle

torch.manual_seed(1)
warnings.filterwarnings("ignore")

def _one_nn_accuracy(real, fake, n_splits=5, random_state=42):
    """Return the cross-validated 1-nearest-neighbour accuracy of real vs. generated rows.

    Real rows are labelled 1 and generated rows 0.  A 1-NN classifier is scored
    with stratified k-fold cross-validation, and the pooled out-of-fold accuracy
    is returned.  ``train_model_save_gen`` treats the checkpoint whose accuracy is
    closest to 0.5 as the best one.

    Args:
        real: 2-D array of real samples, shape (n_real, n_features).
        fake: 2-D array of generated samples, shape (n_fake, n_features).
        n_splits: number of stratified folds.
        random_state: seed for the fold shuffling.

    Returns:
        Accuracy as a float in [0, 1].
    """
    x = np.concatenate((real, fake), axis=0)
    y = np.concatenate((np.ones(len(real)), np.zeros(len(fake))), axis=0)
    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    pred_label = np.zeros(x.shape[0])
    # Every row lands in exactly one test fold, so pred_label ends up fully
    # populated and can be compared directly with y.
    for train_index, test_index in kfold.split(x, y):
        knn = KNeighborsClassifier(n_neighbors=1).fit(x[train_index], y[train_index])
        pred_label[test_index] = knn.predict(x[test_index])
    return accuracy_score(y, pred_label)


def train_model_save_gen(data, ITERS=600, iter_ctrl=200, use_cuda=False, name_save='',
                         file_name='./save_gen/', model_dir='./model_file/'):
    """Train a WGAN with gradient penalty on 2-D data and keep the best checkpoint.

    The whole dataset is used as one batch.  Every ``iter_ctrl`` iterations
    (iteration 0 included), the generated rows produced at the start of that
    iteration are scored with ``_one_nn_accuracy``.  Whenever that accuracy is
    closer to 0.5 than at any earlier checkpoint, the following are saved:

    * the generated rows, as a CSV in ``file_name``;
    * the generator and critic ``state_dict``s, in ``model_dir``.

    The generator and critic loss histories are written to ``file_name`` at
    every checkpoint.

    Args:
        data: array-like of shape (n_samples, n_features), converted with
            ``np.asarray``.  The generator's noise dimension equals ``n_features``.
        ITERS: number of training iterations.
        iter_ctrl: evaluate and checkpoint every ``iter_ctrl`` iterations.
        use_cuda: run the networks on the default CUDA device.
        name_save: tag inserted into the generated-sample CSV file names.
        file_name: directory for the sample and loss CSVs (created if missing).
        model_dir: directory for the saved ``state_dict`` files (created when
            first needed).

    Returns:
        ``(best_item, opt_diff_accuracy_05)``, where:

        * ``best_item`` is the checkpoint iteration whose 1-NN accuracy was
          closest to 0.5;
        * ``opt_diff_accuracy_05`` is that accuracy's distance from 0.5.

        Returns ``(0, 0.5)`` when no checkpoint beat the initial distance of 0.5.

    Raises:
        ValueError: if ``data`` is not a non-empty 2-D array.
    """
    data = np.asarray(data)
    if data.ndim != 2 or data.shape[0] == 0:
        raise ValueError(f'data must be a non-empty 2-D array, got shape {data.shape}')

    # os.makedirs (unlike os.mkdir) also creates missing parent directories.
    os.makedirs(file_name, exist_ok=True)

    FIXED_GENERATOR = False  # if True the "generator" just returns noise + real data
    LAMBDA = .1              # gradient-penalty weight
    CRITIC_ITERS = 5         # critic updates per training iteration
    CRITIC_ITERG = 1         # generator updates per training iteration
    BATCH_SIZE = len(data)   # full-batch training
    n_features = data.shape[1]
    device = torch.device('cuda' if use_cuda else 'cpu')

    class Generator(nn.Module):
        """MLP mapping an ``n_features``-dim noise vector to an ``n_features``-dim sample."""

        def __init__(self, shape1):
            super().__init__()
            self.main = nn.Sequential(
                nn.Linear(shape1, 1024),
                nn.ReLU(True),
                nn.Linear(1024, 512),
                nn.ReLU(True),
                nn.Linear(512, 256),
                nn.ReLU(True),
                nn.Linear(256, 512),
                nn.ReLU(True),
                nn.Linear(512, 1024),
                nn.Tanh(),
                nn.Linear(1024, shape1),
            )

        def forward(self, noise, real_data):
            if FIXED_GENERATOR:
                return noise + real_data
            return self.main(noise)

    class Discriminator(nn.Module):
        """MLP critic returning one unbounded score per input row."""

        def __init__(self, shape1):
            super().__init__()
            self.fc1 = nn.Linear(shape1, 512)
            self.relu1 = nn.LeakyReLU(0.2)
            self.fc2 = nn.Linear(512, 256)
            self.relu2 = nn.LeakyReLU(0.2)
            self.fc3 = nn.Linear(256, 256)
            self.relu3 = nn.LeakyReLU(0.2)
            self.fc4 = nn.Linear(256, 128)
            self.relu4 = nn.LeakyReLU(0.2)
            self.fc5 = nn.Linear(128, 1)

        def forward(self, inputs):
            out = self.relu1(self.fc1(inputs))
            out = self.relu2(self.fc2(out))
            out = self.relu3(self.fc3(out))
            out = self.relu4(self.fc4(out))
            return self.fc5(out).view(-1)

    def weights_init(m):
        """Initialise Linear (and BatchNorm, if any) layers with N(., 0.02) weights and zero bias."""
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            m.weight.data.normal_(0.0, 0.02)
            m.bias.data.fill_(0)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def calc_gradient_penalty(netD, real_data, fake_data):
        """Return LAMBDA * mean((||dD/dx at random interpolates||_2 - 1)^2)."""
        # One mixing weight per row, drawn from the CPU generator, then
        # broadcast across the feature dimension.
        alpha = torch.rand(BATCH_SIZE, 1).expand(real_data.size()).to(device)
        # real_data and fake_data carry no graph, so this is a fresh leaf tensor.
        interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
        disc_interpolates = netD(interpolates)
        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones_like(disc_interpolates),
                                  create_graph=True, retain_graph=True)[0]
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

    netG = Generator(n_features)
    netD = Discriminator(n_features)
    netD.apply(weights_init)
    netG.apply(weights_init)
    netD = netD.to(device)
    netG = netG.to(device)

    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))

    # Seeds for backward(): +1 minimises a score, -1 maximises it.
    one = torch.tensor(1, dtype=torch.float, device=device)
    mone = one * -1

    real_data = torch.as_tensor(data, dtype=torch.float32).to(device)

    opt_diff_accuracy_05 = 0.5
    best_item = 0
    loss_list = {'D_loss': [], 'G_loss': []}

    for iteration in range(ITERS):
        # '\r' moves the cursor back to the start of the line, so each progress
        # message overwrites the previous one.
        sys.stdout.write(f'\r进行:{iteration}/{ITERS}')
        sys.stdout.flush()

        # ---- critic update ----
        for p in netD.parameters():
            p.requires_grad = True

        # Samples from the current generator, used for this iteration's evaluation.
        with torch.no_grad():
            noise = torch.randn(BATCH_SIZE, n_features).to(device)
            fake_output = netG(noise, real_data).cpu().numpy()

        for _ in range(CRITIC_ITERS):
            netD.zero_grad()
            D_real = netD(real_data).mean()
            D_real.backward(mone)

            # Generator output is treated as a constant while training the critic.
            with torch.no_grad():
                noise = torch.randn(BATCH_SIZE, n_features).to(device)
                fake = netG(noise, real_data)
            D_fake = netD(fake).mean()
            D_fake.backward(one)

            gradient_penalty = calc_gradient_penalty(netD, real_data, fake)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            loss_list['D_loss'].append(D_cost.item())
            optimizerD.step()

        # ---- generator update ----
        if not FIXED_GENERATOR:
            # Freeze the critic so only the generator accumulates gradients.
            for p in netD.parameters():
                p.requires_grad = False
            for _ in range(CRITIC_ITERG):
                netG.zero_grad()
                noise = torch.randn(BATCH_SIZE, n_features).to(device)
                G = netD(netG(noise, real_data)).mean()
                G.backward(mone)
                loss_list['G_loss'].append((-G).item())
                optimizerG.step()

        # ---- evaluation / checkpoint every iter_ctrl iterations ----
        if iteration % iter_ctrl == 0:
            if iteration % 10000 == 0:
                data = shuffle(data)
                real_data = torch.as_tensor(data, dtype=torch.float32).to(device)
            print()
            print(f'循环{iteration}次..')
            accuracy = _one_nn_accuracy(data, fake_output)
            print(f'计算{iteration}的acc={accuracy}')

            diff_accuracy_05 = abs(accuracy - 0.5)
            if diff_accuracy_05 < opt_diff_accuracy_05:
                opt_diff_accuracy_05 = diff_accuracy_05
                best_item = iteration
                pd.DataFrame(fake_output).to_csv(
                    os.path.join(file_name, f'Iteration3_{name_save}_{iteration}.csv'),
                    index=None)
                # The model directory was never created before, which made the
                # first torch.save fail with FileNotFoundError.
                os.makedirs(model_dir, exist_ok=True)
                torch.save(netG.state_dict(), os.path.join(model_dir, f'netG{iteration}.dict'))
                torch.save(netD.state_dict(), os.path.join(model_dir, f'netD{iteration}.dict'))

            pd.DataFrame(loss_list['G_loss']).to_csv(
                os.path.join(file_name, f'Gloss_{iteration}.csv'), index=None)
            pd.DataFrame(loss_list['D_loss']).to_csv(
                os.path.join(file_name, f'Dloss_{iteration}.csv'), index=None)

    return best_item, opt_diff_accuracy_05

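调用上述函数即可,例如 `best_iter, best_diff = train_model_save_gen(data, ITERS=600, iter_ctrl=200)`。

返回值包含两项:
- `best_iter`:1-NN 分类准确率最接近 0.5 的迭代次数;
- `best_diff`:该准确率与 0.5 的差值。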

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/570251.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存