python——实现简单的强化学习

python——实现简单的强化学习,第1张

python——实现简单的强化学习

文章目录
      • 强化学习
      • Q-Learning算法
      • 代码实现
          • 算法参数
          • 状态集
          • 动作集
          • 奖励集
          • Q table
          • Q-learning算法实现
          • 更新状态
      • 完整代码

强化学习

强化学习(Reinforcement Learning, RL),又称再励学习、评价学习或增强学习,是机器学习的范式和方法论之一,用于描述和解决智能体(agent)在与环境的交互过程中通过学习策略以达成回报最大化或实现特定目标的问题。

强化学习是智能体(Agent)以“试错”的方式进行学习,通过与环境进行交互获得的奖赏指导行为,目标是使智能体获得最大的奖赏,强化学习不同于连接主义学习中的监督学习,主要表现在强化信号上,强化学习中由环境提供的强化信号是对产生动作的好坏作一种评价(通常为标量信号),而不是告诉强化学习系统RLS(reinforcement learning system)如何去产生正确的动作。由于外部环境提供的信息很少,RLS必须靠自身的经历进行学习。通过这种方式,RLS在行动-评价的环境中获得知识,改进行动方案以适应环境。

Q-Learning算法

伪代码

Q value的更新是根据贝尔曼方程

代码实现 算法参数
epsilon = 0.9  # 贪婪度
alpha = 0.1  # 学习率
gamma = 0.8  # 奖励递减值
状态
states = range(6)  # 状态集
def get_next_state(state, action):
    '''对状态执行动作后,得到下一状态'''
    global states

    # l,r,n = -1,+1,0
    if action == 'right' and state != states[-1]:  # 除非最后一个状态(位置),向右就+1
        next_state = state + 1
    elif action == 'left' and state != states[0]:  # 除非最前一个状态(位置),向左就-1
        next_state = state - 1
    else:
        next_state = state
    return next_state
动作集
actions = ['left', 'right']  # 动作集
def get_valid_actions(state):
    '''取当前状态下的合法动作集合,与reward无关!'''
    global actions  # ['left', 'right']

    valid_actions = set(actions)
    if state == states[-1]:  # 最后一个状态(位置),则
        valid_actions -= set(['right'])  # 不能向右
    if state == states[0]:  # 最前一个状态(位置),则
        valid_actions -= set(['left'])  # 不能向左
    return list(valid_actions)
奖励集
rewards = [0, 0, 0, 0, 0, 1]  # 奖励集
Q table

Q table是一种记录状态-行为值 (Q value) 的表。常见的q-table都是二维的,但是也有3维的Q table。

q_table = pd.Dataframe(data=[[0 for _ in actions] for _ in states], index=states, columns=actions)
Q-learning算法实现
for i in range(13):
    # current_state = random.choice(states)
    current_state = 0

    update_env(current_state)  # 环境相关
    total_steps = 0  # 环境相关

    while current_state != states[-1]:
        if (random.uniform(0, 1) > epsilon) or ((q_table.loc[current_state] == 0).all()):  # 探索
            current_action = random.choice(get_valid_actions(current_state))
        else:
            current_action = q_table.loc[current_state].idxmax()  # 利用(贪婪)

        next_state = get_next_state(current_state, current_action)
        next_state_q_values = q_table.loc[next_state, get_valid_actions(next_state)]
        q_table.loc[current_state, current_action] += alpha * (
                    rewards[next_state] + gamma * next_state_q_values.max() - q_table.loc[current_state, current_action])
        current_state = next_state

        update_env(current_state)  # 环境相关
        total_steps += 1  # 环境相关

    print('rEpisode {}: total_steps = {}'.format(i, total_steps), end='')  # 环境相关
    time.sleep(2)  # 环境相关
    print('r                                ', end='')  # 环境相关

print('nq_table:')
print(q_table)
更新状态
def update_env(state):
    global states

    env = list('-----T')
    if state != states[-1]:
        env[state] = '0'
    print('r{}'.format(''.join(env)), end='')
    time.sleep(0.1)
完整代码
import pandas as pd
import random
import time

#########参数
epsilon = 0.9  # 贪婪度
alpha = 0.1  # 学习率
gamma = 0.8  # 奖励递减值

#####探索者的状态,即可到达的位置
states = range(6)  # 状态集
actions = ['left', 'right']  # 动作集
rewards = [0, 0, 0, 0, 0, 1]  # 奖励集

q_table = pd.Dataframe(data=[[0 for _ in actions] for _ in states], index=states, columns=actions)


def update_env(state):
    global states

    env = list('-----T')
    if state != states[-1]:
        env[state] = '0'
    print('r{}'.format(''.join(env)), end='')
    time.sleep(0.1)


def get_next_state(state, action):
    '''对状态执行动作后,得到下一状态'''
    global states

    # l,r,n = -1,+1,0
    if action == 'right' and state != states[-1]:  # 除非最后一个状态(位置),向右就+1
        next_state = state + 1
    elif action == 'left' and state != states[0]:  # 除非最前一个状态(位置),向左就-1
        next_state = state - 1
    else:
        next_state = state
    return next_state


def get_valid_actions(state):
    '''取当前状态下的合法动作集合,与reward无关!'''
    global actions  # ['left', 'right']

    valid_actions = set(actions)
    if state == states[-1]:  # 最后一个状态(位置),则
        valid_actions -= set(['right'])  # 不能向右
    if state == states[0]:  # 最前一个状态(位置),则
        valid_actions -= set(['left'])  # 不能向左
    return list(valid_actions)


for i in range(13):
    # current_state = random.choice(states)
    current_state = 0

    update_env(current_state)  # 环境相关
    total_steps = 0  # 环境相关

    while current_state != states[-1]:
        if (random.uniform(0, 1) > epsilon) or ((q_table.loc[current_state] == 0).all()):  # 探索
            current_action = random.choice(get_valid_actions(current_state))
        else:
            current_action = q_table.loc[current_state].idxmax()  # 利用(贪婪)

        next_state = get_next_state(current_state, current_action)
        next_state_q_values = q_table.loc[next_state, get_valid_actions(next_state)]
        q_table.loc[current_state, current_action] += alpha * (
                    rewards[next_state] + gamma * next_state_q_values.max() - q_table.loc[current_state, current_action])
        current_state = next_state

        update_env(current_state)  # 环境相关
        total_steps += 1  # 环境相关

    print('rEpisode {}: total_steps = {}'.format(i, total_steps), end='')  # 环境相关
    time.sleep(2)  # 环境相关
    print('r                                ', end='')  # 环境相关

print('nq_table:')
print(q_table)

欢迎留言指出错误

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5659822.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存