import numpy as np
from tqdm import tqdm


class CliffWalkingEnv:
    def __init__(self, ncol, nrow):
        self.nrow = nrow
        self.ncol = ncol
        self.x = 0              # agent's current column
        self.y = self.nrow - 1  # agent's current row; start at the bottom-left corner

    def step(self, action):
        # Actions: 0 = up, 1 = down, 2 = left, 3 = right
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        self.x = min(self.ncol - 1, max(0, self.x + change[action][0]))
        self.y = min(self.nrow - 1, max(0, self.y + change[action][1]))
        next_state = self.y * self.ncol + self.x
        reward = -1
        done = False
        # Every cell in the bottom row except the start is either the cliff
        # or the goal; both end the episode, and the cliff costs -100.
        if self.y == self.nrow - 1 and self.x > 0:
            done = True
            if self.x != self.ncol - 1:
                reward = -100
        return reward, next_state, done

    def reset(self):
        self.x = 0
        self.y = self.nrow - 1
        return self.y * self.ncol + self.x


class Sarsa:
    """Sarsa algorithm."""

    def __init__(self, env, epsilon, alpha, gamma, n_action=4):
        self.env = env
        self.Q_table = np.zeros([self.env.nrow * self.env.ncol, n_action])
        self.n_action = n_action
        self.alpha = alpha      # learning rate
        self.gamma = gamma      # discount factor
        self.epsilon = epsilon  # exploration probability for epsilon-greedy

    def take_action(self, state):
        """Select an action with an epsilon-greedy policy."""
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.n_action)
        else:
            action = np.argmax(self.Q_table[state])
        return action

    def show_result(self):
        """Print the greedy policy read off the Q-table."""
        print("Policy:")
        actions = ['^', 'v', '<', '>']
        for i in range(self.env.nrow):
            for j in range(self.env.ncol):
                if i == self.env.nrow - 1 and j == self.env.ncol - 1:
                    print("goal".center(5), end="")
                elif i == self.env.nrow - 1 and self.env.ncol - 1 > j > 0:
                    print("x".center(5), end="")  # cliff cells
                else:
                    qsa = self.Q_table[i * self.env.ncol + j]
                    max_qsa = np.max(qsa)
                    # Mark every greedy action with its arrow, 'o' otherwise.
                    a_str = ''.join(
                        'o' if qsa[k] < max_qsa else actions[k]
                        for k in range(len(qsa))
                    )
                    print(a_str.center(5), end="")
            print()

    def update_Q(self, s0, a0, r, s1, a1):
        """Single-step temporal-difference update: compute the TD error
        td_error and update the action-value function Q(s, a)."""
        td_error = r + self.gamma * self.Q_table[s1, a1] - self.Q_table[s0, a0]
        self.Q_table[s0, a0] += self.alpha * td_error

    def run(self, episodes_num=1000):
        for episode in tqdm(range(episodes_num)):
            s0 = self.env.reset()
            a0 = self.take_action(s0)
            while True:
                r, s1, done = self.env.step(a0)
                a1 = self.take_action(s1)
                self.update_Q(s0, a0, r, s1, a1)
                if done:
                    break
                s0 = s1
                a0 = a1


if __name__ == '__main__':
    np.random.seed(0)
    env = CliffWalkingEnv(12, 4)
    agent = Sarsa(env, epsilon=0.1, alpha=0.1, gamma=0.9)
    agent.run(500)
    agent.show_result()
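For reference, update_Q implements exactly the one-step Sarsa update its docstring describes. Writing the TD error out explicitly:

$$ Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ r_t + \gamma \, Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \right] $$

The bootstrap target uses the action a_{t+1} actually chosen by the epsilon-greedy behavior policy (a1 from take_action), which is what makes Sarsa on-policy.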
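Beyond the printed policy grid, you can check what the agent learned by rolling out its greedy policy once. The following is a minimal sketch, not part of the original listing: it assumes the classes above are in scope, and evaluate_greedy and max_steps are names introduced here for illustration.

def evaluate_greedy(env, agent, max_steps=200):
    # Hypothetical helper (not in the original listing): roll out the
    # greedy policy once, with a step cap in case the policy loops.
    state = env.reset()
    total_reward = 0
    for _ in range(max_steps):
        action = np.argmax(agent.Q_table[state])  # greedy, no exploration
        reward, state, done = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward

print("greedy return:", evaluate_greedy(env, agent))

Because Sarsa's targets are computed under the exploratory behavior policy, it usually settles on a route that keeps away from the cliff, so the greedy return tends to be somewhat worse than the -13 of the optimal cliff-hugging path.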