Sarsa更新方式
Sarsa 的决策部分和 Q learning 一样, 使用的是 Q 表的形式决策, 在 Q 表中挑选值较大的动作值施加在环境中来换取奖惩. 但是不同的地方在于 Sarsa 的更新方式是不一样的.
-
Q learning, 在 s2 上选取哪一个动作会带来最大的奖励, 但是在真正要做决定时, 却不一定会选取到那个带来最大奖励的动作,在这一步只是估计了一下接下来的动作值.
-
Sarsa 说到做到, 在 s2 这一步估算的动作也是接下来要做的动作. 所以 Q(s1, a2) 现实的计算值, 去掉maxQ,取而代之的是在 s2 上我们实实在在选取的 a2 的 Q 值. 最后像 Q learning 一样, 求出现实和估计的差距 并更新 Q表里的 Q(s1, a2).
-
QL先用新的状态选择的最大Q值更新q表之后选择动作,sarsa在更新Q表之前就选择好了动作
两者对比
Sarsa 是on-policy, 在线学习, 学着自己在做的事情. 而 Q learning 是说到但并不一定做到, 所以它也叫作 Off-policy, 离线学习. 而因为有了 maxQ, Q-learning 也是一个特别勇敢的算法.
整个算法还是一直不断更新 Q table 里的值, 然后再根据新的值来判断要在某个 state 采取怎样的 action.
不过于 Qlearning 不同之处:
- 两者选择action的方式是一样的,更新Q值的方式不一样,去learning用max(s_)更新,sarsa用下一个动作的Q值更新
- 他在当前 state 已经想好了 state 对应的 action, 而且想好了 下一个 state_ 和下一个 action_(Qlearning 还没有想好下一个 action_)
- 更新 Q(s,a) 的时候基于的是下一个 Q(s_, a_) (Qlearning 是基于 maxQ(s_))
sarsa实例
- RL_brain中不同
class SarsaTable:
# 初始化 (与之前一样)
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
# 选行为 (与之前一样)
def choose_action(self, observation):
# 学习更新参数 (有改变)
def learn(self, s, a, r, s_):
# 检测 state 是否存在 (与之前一样)
def check_state_exist(self, state):
-
run_this中update方法略不同
RL_brain.py
## SarsaL如何决策
import numpy as np
import pandas as pd
#class RL父 class QLearningTable(RL),SarsaTable(RL)子
class RL(object):
#learn两者是不同的
def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
self.actions = action_space # a list
self.lr = learning_rate
self.gamma = reward_decay
self.epsilon = e_greedy
self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)
def check_state_exist(self, state):#检查Q表中是否有此state,没有加上
if state not in self.q_table.index:
# append new state to q table
self.q_table = self.q_table.append(
pd.Series(
[0] * len(self.actions),
index=self.q_table.columns,
name=state,
)
)
def choose_action(self, observation):
self.check_state_exist(observation)
# action selection
if np.random.rand() < self.epsilon:
# choose best action
state_action = self.q_table.loc[observation, :]
# some actions may have the same value, randomly choose on in these actions
action = np.random.choice(state_action[state_action == np.max(state_action)].index)
else:
# choose random action
action = np.random.choice(self.actions)
return action
def learn(self, *args):
pass
# off-policy
class QLearningTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
def learn(self, s, a, r, s_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, :].max() # next state is not terminal
else:
q_target = r # next state is terminal
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # update
# on-policy
class SarsaTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # 与QLearning不同之处直接选下一个状态动作
else:
q_target = r # next state is terminal
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # update
run_this.py
from maze_env import Maze
from RL_brain import SarsaTable
def update():
for episode in range(100):
# initial observation
observation = env.reset()
# RL choose action based on observation
action = RL.choose_action(str(observation))#QL是在while中选action
while True:
# fresh env
env.render()
# RL采取行动得到的反馈
observation_, reward, done = env.step(action)
# RL choose action based on next observation
action_ = RL.choose_action(str(observation_))
# RL learn from this transition (s, a, r, s, a) ==> Sarsa
RL.learn(str(observation), action, reward, str(observation_), action_)
# sarsal和ql不同处 直接采取行动
observation = observation_
action = action_
# break while loop when end of this episode
if done:
break
# end of game
print('game over')
env.destroy()
if __name__ == "__main__":
env = Maze()
RL = SarsaTable(actions=list(range(env.n_actions)))
env.after(100, update)
env.mainloop()
版权声明:本文为komorebi6原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。