boyans_chain.py
#!python
# -*- coding: utf-8 -*-
#
# Copyright 2022 Midden Vexu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @Author: Midden Vexu
# The Boyan's chain environment, used to test the ATD algorithms.
# Reference: https://www.researchgate.net/publication/2621189_Least-Squares_Temporal_Difference_Learning
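#
# Environment sketch, as implemented below (see the reference for the original
# formulation): a chain of 13 states numbered 12 (start) down to 0 (terminal).
# Each step moves 1 or 2 states towards state 0 (only 1 from state 1), with
# reward -3 per step and -2 on the step that reaches the terminal state.
# States are represented by 4-dimensional "hat" features.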
import sys
sys.path.append(".")
from atd import TDAgent, SVDATDAgent, DiagonalizedSVDATDAgent, PlainATDAgent, Backend
import numpy as np
from tqdm import trange
import matplotlib.pyplot as plt
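# Feature vectors for states 0..12: component i of state N is
# max(0, 1 - |(12 - 4 * i) - N| / 4), a piecewise-linear "hat" centred on
# states 12, 8, 4 and 0 for i = 0..3.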
observations = [Backend.create_matrix_func(np.max(
    np.vstack(((1 - np.abs(12 - 4 * np.arange(4) - N) / 4), np.zeros(4))),
    axis=0), dtype=Backend.float32) for N in range(13)]
rng = np.random.default_rng()
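# Ground-truth weights for the error measure: with the hat features above,
# (-24, -16, -8, 0) matches the true values of the standard Boyan's chain, V(s) = -2s.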
w_optimal = Backend.arange(-24, 8, 8, dtype=Backend.float32)
def evaluate(w, w_pi):
    # Mean absolute relative error of the value estimates given by w against
    # the reference weights w_pi, averaged over the non-terminal states.
    observation_count = 0
    absolute_error = 0
    for observation in observations[1:]:  # Skip the terminal state (index 0): its true value is zero.
        absolute_error += abs((w @ observation - w_pi @ observation) / (w_pi @ observation))
        observation_count += 1
    return absolute_error / observation_count
def play_game(agent, total_timesteps=1000, iterations=100):
    # Run the agent on the chain for `total_timesteps` steps, `iterations`
    # times, and return the error curve averaged over the iterations.
    records = []
    for _ in trange(iterations):
        timestep = 0
        episode = 0
        agent.reinit()
        agent.w *= 0  # Start each run from zero weights.
        record = []
        while timestep <= total_timesteps:
            pos = 12  # Every episode starts at state 12.
            observation = observations[pos]
            while timestep <= total_timesteps:
                record.append(evaluate(agent.w, w_optimal))
                # Move 1 or 2 states towards the terminal state 0.
                pos -= rng.choice([1, 2]) if pos > 1 else 1
                next_observation = observations[pos]
                timestep += 1
                if pos == 0:
                    # The step into the terminal state yields reward -2 and ends the episode.
                    agent.learn(observation, next_observation, -2, 0, timestep)
                    episode += 1
                    break
                # Any other step yields reward -3.
                agent.learn(observation, next_observation, -3, 1, timestep)
                observation = next_observation
        records.append(record)
    return Backend.mean(Backend.create_matrix_func(records), 0)
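# Compare TD(0) with the SVD-based and plain ATD agents on the chain, plotting
# the error curve of each agent averaged over 10 independent runs.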
plt.figure(dpi=120, figsize=(8, 6))
plt.plot(play_game(agent=TDAgent(lr=0.1, lambd=0, observation_space_n=4, action_space_n=2),
                   iterations=10), label="TD(0), $\\alpha=0.1$")
plt.plot(play_game(agent=DiagonalizedSVDATDAgent(k=30, eta=1e-4, lambd=0, observation_space_n=4,
                                                 action_space_n=2),
                   iterations=10),
         label="DiagonalizedSVDATD(0), $\\alpha=\\frac{1}{1+t}$, \n$\\eta=1\\times10^{-4}$, $r=30$, Accuracy First")
plt.plot(play_game(agent=DiagonalizedSVDATDAgent(k=30, eta=1e-4, lambd=0, observation_space_n=4,
                                                 action_space_n=2, svd_diagonalizing=False,
                                                 w_update_emphasizes="complexity"),
                   iterations=10),
         label="DiagonalizedSVDATD(0), $\\alpha=\\frac{1}{1+t}$, \n$\\eta=1\\times10^{-4}$, $r=30$, Complexity First")
plt.plot(play_game(agent=DiagonalizedSVDATDAgent(k=30, eta=1e-4, lambd=0, observation_space_n=4,
                                                 action_space_n=2, svd_diagonalizing=True,
                                                 w_update_emphasizes="complexity"),
                   iterations=10),
         label="DiagonalizedSVDATD(0), $\\alpha=\\frac{1}{1+t}$, \n$\\eta=1\\times10^{-4}$, $r=30$, Complexity First"
               "\nUsing SVD to diagonalize")
plt.plot(play_game(agent=SVDATDAgent(eta=1e-4, lambd=0, observation_space_n=4, action_space_n=2),
                   iterations=10), label="SVDATD(0), $\\alpha=\\frac{1}{1+t}$, $\\eta=1\\times10^{-4}$")
plt.plot(play_game(agent=PlainATDAgent(eta=1e-4, lambd=0, observation_space_n=4, action_space_n=2),
                   iterations=10), label="PlainATD(0), $\\alpha=\\frac{1}{1+t}$, $\\eta=1\\times10^{-4}$")
plt.legend()
plt.title("Boyan's Chain")
plt.xlabel("Timesteps")
plt.ylabel("Relative Error")
plt.ylim(0, 1)
plt.xlim(0, 1000)
plt.savefig("./figures/boyans_chain.png", format="png")