使用DQN解决cartpole问题(深度强化学习入门)

使用DQN解决cartpole问题(深度强化学习入门),第1张

使用DQN解决cartpole问题(深度强化学习入门) 使用DQN解决cartpole问题(深度强化学习入门)
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 22 11:16:50 2021

@author: wss
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # 调用relu啥的
import collections
import random
import torch.optim as optim

#放一些参数
Lr = 0.1   #学习率
Buffer_size = 10000 #经验回放的buffer的大小
Eps = 0.1   # eps 贪心算法的随机选择比列
GAMMA = 0.99  # reward的衰减



#用队列存 transition     并定义了采样函数
Transition = collections.namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))
# 用一个类来实现经验回放 去除state的相关性和利用经验
class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = collections.deque([],maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)




#定义DQN 的神经网络部分

class Net(nn.Module):
    def __init__(self,n_in,n_hidden,n_out):
        super(Net,self).__init__()
        self.fc1 = nn.Linear(n_in, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_out)

    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        out = self.fc3(x)
        return out

    
# net属性是神经网络的对象
class DQN(object):
    def __init__(self,n_in,n_hidden,n_out):
#        super(DQN,self).__init__()
        self.net = Net(n_in,n_hidden,n_out)
        self.target_net = Net(n_in,n_hidden,n_out)
        self.optimer = optim.Adam(self.net.parameters(),lr = Lr)
        self.loss_func = nn.MSELoss()
        self.target_net.load_state_dict(self.net.state_dict())
#        self.target_net.eval()     # 解决高估问题  不用训练直接加载policy_net的参数
        self.buffer = ReplayMemory(Buffer_size)
        
        
    #根据state选择 action 
    def select_action(self,state): #返回的action是个数字(不是张量)
        threshold = random.random() 
        Q_actions = self.net(torch.Tensor(state)) #返回不同action对应的Q值
        if  threshold 
刚刚接触深度学习以及强化学习,不知道为什么这个DQN并没有随着训练越来越来越好?

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5595130.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-15
下一篇 2022-12-15

发表评论

登录后才能评论

评论列表(0条)

保存