-
Notifications
You must be signed in to change notification settings - Fork 728
/
environment.py
142 lines (121 loc) · 4.69 KB
/
environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import time
import numpy as np
import tkinter as tk
from PIL import ImageTk, Image
# Seed NumPy's global random generator with a fixed value.
# NOTE(review): nothing in this file draws from np.random; presumably this is
# for code that imports the module (e.g. an agent sampling actions) -- confirm.
np.random.seed(1)
# Module-level alias for Pillow's Tk-compatible image wrapper.
PhotoImage = ImageTk.PhotoImage
UNIT = 100  # side length of one grid cell, in pixels
HEIGHT = 5  # grid height, in cells (canvas height is HEIGHT * UNIT)
WIDTH = 5  # grid width, in cells (canvas width is WIDTH * UNIT)
class Env(tk.Tk):
    """Tkinter grid-world environment with an agent, two traps and a goal.

    The window shows a WIDTH x HEIGHT grid of UNIT-pixel cells.  The agent
    (rectangle sprite) starts in the top-left cell; stepping onto the circle
    ends the episode with reward +100, stepping onto either triangle ends it
    with reward -100.  States are ``[column, row]`` lists of cell indices.

    Actions are integers: 0 = up, 1 = down, 2 = left, 3 = right.
    """

    # Pixel offsets (vertical, horizontal) inside a cell at which the value
    # for each action is drawn; any other action index is treated as "right".
    _TEXT_OFFSETS = {0: (7, 42), 1: (85, 42), 2: (42, 5)}
    _TEXT_OFFSET_DEFAULT = (42, 77)

    def __init__(self):
        super().__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('SARSA')
        # Tk geometry strings are "<width>x<height>".
        self.geometry('{0}x{1}'.format(WIDTH * UNIT, HEIGHT * UNIT))
        self.shapes = self.load_images()
        self.canvas = self._build_canvas()
        self.texts = []

    @staticmethod
    def _cell_center(col, row):
        """Return the pixel coordinates of the centre of cell (col, row)."""
        return UNIT * col + UNIT / 2, UNIT * row + UNIT / 2

    @staticmethod
    def _load_icon(path, size=(65, 65)):
        """Open the image at *path*, resize it and wrap it for Tk."""
        # resize() returns an independent copy, so the source file can be
        # closed as soon as the block exits.
        with Image.open(path) as img:
            return PhotoImage(img.resize(size))

    def _build_canvas(self):
        """Create the canvas, draw the grid lines and place the sprites.

        Sets ``self.rectangle`` (agent), ``self.triangle1``/``self.triangle2``
        (traps) and ``self.circle`` (goal) to their canvas item ids.

        Returns:
            The packed ``tk.Canvas``.
        """
        canvas_width = WIDTH * UNIT
        canvas_height = HEIGHT * UNIT
        canvas = tk.Canvas(self, bg='white',
                           height=canvas_height,
                           width=canvas_width)

        # Vertical grid lines, one every UNIT pixels across the width.
        for x in range(0, canvas_width, UNIT):
            canvas.create_line(x, 0, x, canvas_height)
        # Horizontal grid lines, one every UNIT pixels down the height.
        for y in range(0, canvas_height, UNIT):
            canvas.create_line(0, y, canvas_width, y)

        # Place sprites on cell centres: agent top-left, traps next to goal.
        rectangle_img, triangle_img, circle_img = self.shapes
        self.rectangle = canvas.create_image(
            *self._cell_center(0, 0), image=rectangle_img)
        self.triangle1 = canvas.create_image(
            *self._cell_center(2, 1), image=triangle_img)
        self.triangle2 = canvas.create_image(
            *self._cell_center(1, 2), image=triangle_img)
        self.circle = canvas.create_image(
            *self._cell_center(2, 2), image=circle_img)

        canvas.pack()
        return canvas

    def load_images(self):
        """Load the three sprites from ``../img`` (relative to the cwd).

        Returns:
            A ``(rectangle, triangle, circle)`` tuple of PhotoImage objects.
        """
        rectangle = self._load_icon("../img/rectangle.png")
        triangle = self._load_icon("../img/triangle.png")
        circle = self._load_icon("../img/circle.png")
        return rectangle, triangle, circle

    def text_value(self, row, col, contents, action, font='Helvetica', size=10,
                   style='normal', anchor="nw"):
        """Draw *contents* in cell (row, col) at the slot for *action*.

        The created canvas text item is remembered in ``self.texts`` so that
        ``print_value_all`` can remove it later.  Returns ``None``.
        """
        offset_y, offset_x = self._TEXT_OFFSETS.get(
            action, self._TEXT_OFFSET_DEFAULT)
        x = offset_x + UNIT * col
        y = offset_y + UNIT * row
        text = self.canvas.create_text(x, y, fill="black", text=contents,
                                       font=(font, str(size), style),
                                       anchor=anchor)
        self.texts.append(text)

    def print_value_all(self, q_table):
        """Redraw the per-action values of every visited state.

        Args:
            q_table: mapping from ``str([col, row])`` to a sequence holding
                one value per action.  States missing from it are skipped.
        """
        for item in self.texts:
            self.canvas.delete(item)
        self.texts.clear()

        # States are [col, row], so columns range over WIDTH, rows over HEIGHT.
        for col in range(WIDTH):
            for row in range(HEIGHT):
                key = str([col, row])
                # Membership test first: avoids inserting entries into a
                # defaultdict-style table for unvisited states.
                if key not in q_table:
                    continue
                values = q_table[key]
                for action in range(4):
                    self.text_value(row, col, round(values[action], 2),
                                    action)

    def coords_to_state(self, coords):
        """Convert canvas pixel coordinates of a cell centre to [col, row]."""
        col = int((coords[0] - UNIT / 2) / UNIT)
        row = int((coords[1] - UNIT / 2) / UNIT)
        return [col, row]

    def reset(self):
        """Move the agent back to the top-left cell and return its state."""
        self.update()
        time.sleep(0.5)
        x, y = self.canvas.coords(self.rectangle)
        self.canvas.move(self.rectangle, UNIT / 2 - x, UNIT / 2 - y)
        self.render()
        # return observation
        return self.coords_to_state(self.canvas.coords(self.rectangle))

    def step(self, action):
        """Apply *action* to the agent and report the outcome.

        Moves that would leave the grid keep the agent in place.

        Args:
            action: 0 = up, 1 = down, 2 = left, 3 = right.

        Returns:
            ``(next_state, reward, done)`` where ``next_state`` is the
            agent's ``[col, row]`` cell, ``reward`` is 100 on the goal,
            -100 on a trap and 0 otherwise, and ``done`` is True on either
            terminal cell.
        """
        x, y = self.canvas.coords(self.rectangle)
        self.render()

        dx = dy = 0
        if action == 0:  # up
            if y > UNIT:
                dy = -UNIT
        elif action == 1:  # down
            if y < (HEIGHT - 1) * UNIT:
                dy = UNIT
        elif action == 2:  # left
            if x > UNIT:
                dx = -UNIT
        elif action == 3:  # right
            if x < (WIDTH - 1) * UNIT:
                dx = UNIT

        self.canvas.move(self.rectangle, dx, dy)
        # Keep the agent drawn on top of whatever sprite it lands on.
        self.canvas.tag_raise(self.rectangle)

        # Compare grid cells rather than raw float pixel coordinates.
        next_state = self.coords_to_state(self.canvas.coords(self.rectangle))
        goal = self.coords_to_state(self.canvas.coords(self.circle))
        traps = [self.coords_to_state(self.canvas.coords(self.triangle1)),
                 self.coords_to_state(self.canvas.coords(self.triangle2))]

        if next_state == goal:
            reward, done = 100, True
        elif next_state in traps:
            reward, done = -100, True
        else:
            reward, done = 0, False
        return next_state, reward, done

    def render(self):
        """Pause briefly, then process pending Tk events to redraw."""
        time.sleep(0.03)
        self.update()