基于深度强化学习的绘画智能体 代码分析(六)

基于深度强化学习的绘画智能体 代码分析(六),第1张

基于深度强化学习的绘画智能体 代码分析(六)
  1. env.py
import sys
import json
import torch
import numpy as np
import argparse
import torchvision.transforms as transforms
import cv2
from DRL.ddpg import decode
from utils.util import *
from PIL import Image
from torchvision import transforms, utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

aug = transforms.Compose( #一般用Compose把多个步骤整合到一起
            [transforms.ToPILImage(), #a tensor 转换为PIL image
             transforms.RandomHorizontalFlip(), #以0.5的概率水平翻转给定的PIL图像
             ])

width = 128
convas_area = width * width

img_train = []
img_test = []
train_num = 0
test_num = 0

class Paint:
    def __init__(self, batch_size, max_step):
        self.batch_size = batch_size
        self.max_step = max_step
        self.action_space = (13)
        self.observation_space = (self.batch_size, width, width, 7)
        self.test = False
        
    def load_data(self):
        # CelebA
        global train_num, test_num
        for i in range(200000):            
            img_id = '%06d' % (i + 1) #6d 是指占6位,因为200000是六位
            try:
                img = cv2.imread('../data/img_align_celeba/' + img_id + '.jpg', cv2.IMREAD_UNCHANGED )
                
                img = cv2.resize(img, (width, width))
                if i > 2000:                
                    train_num += 1
                    img_train.append(img)
                else:
                    test_num += 1
                    img_test.append(img)
            finally:
                if (i + 1) % 10000 == 0:
                    print('loaded {} images'.format(i + 1)) #每读了10000张显示一下读了多少
        print('finish loading data, {} training images, {} testing images'.format(str(train_num), str(test_num)))

cv2.imread(filename, flags)

  • filepath:读入imge的完整路径
  • flags:标志位,{cv2.IMREAD_COLOR,cv2.IMREAD_GRAYSCALE,cv2.IMREAD_UNCHANGED}

cv2.IMREAD_COLOR:默认参数,读入一副彩色图片,忽略alpha通道,可用1作为实参替代
cv2.IMREAD_GRAYSCALE:读入灰度图片,可用0作为实参替代
cv2.IMREAD_UNCHANGED:顾名思义,读入完整图片,包括alpha通道,可用-1作为实参替代

alpha通道,又称A通道,是一个8位的灰度通道,该通道用256级灰度来记录图像中的透明度复信息,定义透明、不透明和半透明区域,其中黑表示全透明,白表示不透明,灰表示半透明

    def pre_data(self, id, test):
        if test:
            img = img_test[id]
        else:
            img = img_train[id]
        if not test:
            img = aug(img)
        img = np.asarray(img) #输入数据,可以转换为数组的任何形式。 这包括列表,元组列表,元组,元组元组,列表元组和ndarray
        return np.transpose(img, (2, 0, 1)) #转置
    
    def reset(self, test=False, begin_num=False):
        self.test = test
        self.imgid = [0] * self.batch_size
        self.gt = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)
        for i in range(self.batch_size):
            if test:
                id = (i + begin_num)  % test_num
            else:
                id = np.random.randint(train_num) #randint(a, b) 随机生成整数:[a-b]区间的整数(包含两端),0~train_num
            self.imgid[i] = id
            self.gt[i] = torch.tensor(self.pre_data(id, test))
        self.tot_reward = ((self.gt.float() / 255) ** 2).mean(1).mean(1).mean(1)
        self.stepnum = 0
        self.canvas = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)
        self.lastdis = self.ini_dis = self.cal_dis()
        return self.observation()
    
    def observation(self):
        # canvas B * 3 * width * width
        # gt B * 3 * width * width
        # T B * 1 * width * width
        ob = []
        T = torch.ones([self.batch_size, 1, width, width], dtype=torch.uint8) * self.stepnum
        return torch.cat((self.canvas, self.gt, T.to(device)), 1) # canvas, img, T

    def cal_trans(self, s, t):
        return (s.transpose(0, 3) * t).transpose(0, 3) #转置
    
    def step(self, action):
        self.canvas = (decode(action, self.canvas.float() / 255) * 255).byte()
        self.stepnum += 1
        ob = self.observation()
        done = (self.stepnum == self.max_step)
        reward = self.cal_reward() # np.array([0.] * self.batch_size)
        return ob.detach(), reward, np.array([done] * self.batch_size), None

    def cal_dis(self):
        return (((self.canvas.float() - self.gt.float()) / 255) ** 2).mean(1).mean(1).mean(1)
    
    def cal_reward(self):
        dis = self.cal_dis()
        reward = (self.lastdis - dis) / (self.ini_dis + 1e-8)
        self.lastdis = dis
        return to_numpy(reward)

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/4655070.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-11-06
下一篇 2022-11-06

发表评论

登录后才能评论

评论列表(0条)

保存