Skip to content

Commit

Permalink
Merge pull request #148 from VEXLife/master
Browse files Browse the repository at this point in the history
Minor changes
  • Loading branch information
ShangtongZhang authored Oct 28, 2021
2 parents ac4fbce + dedf66e commit fbf020d
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 24 deletions.
34 changes: 19 additions & 15 deletions chapter06/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def temporal_difference(values, alpha=0.1, batch=False):
# @batch: whether to update @values
def monte_carlo(values, alpha=0.1, batch=False):
state = 3
trajectory = [3]
trajectory = [state]

# if end up with left terminal state, all returns are 0
# if end up with right terminal state, all returns are 1
Expand Down Expand Up @@ -89,11 +89,11 @@ def compute_state_value():
plt.figure(1)
for i in range(episodes[-1] + 1):
if i in episodes:
plt.plot(current_values, label=str(i) + ' episodes')
plt.plot(("A", "B", "C", "D", "E"), current_values[1:6], label=str(i) + ' episodes')
temporal_difference(current_values)
plt.plot(TRUE_VALUE, label='true values')
plt.xlabel('state')
plt.ylabel('estimated value')
plt.plot(("A", "B", "C", "D", "E"), TRUE_VALUE[1:6], label='true values')
plt.xlabel('State')
plt.ylabel('Estimated Value')
plt.legend()

# Example 6.2 right
Expand Down Expand Up @@ -122,9 +122,9 @@ def rms_error():
monte_carlo(current_values, alpha=alpha)
total_errors += np.asarray(errors)
total_errors /= runs
plt.plot(total_errors, linestyle=linestyle, label=method + ', alpha = %.02f' % (alpha))
plt.xlabel('episodes')
plt.ylabel('RMS')
plt.plot(total_errors, linestyle=linestyle, label=method + ', $\\alpha$ = %.02f' % (alpha))
plt.xlabel('Walks/Episodes')
plt.ylabel('Empirical RMS error, averaged over states')
plt.legend()

# Figure 6.2
Expand All @@ -135,6 +135,7 @@ def batch_updating(method, episodes, alpha=0.001):
total_errors = np.zeros(episodes)
for r in tqdm(range(0, runs)):
current_values = np.copy(VALUES)
current_values[1:6] = -1
errors = []
# track shown trajectories and reward/return sequences
trajectories = []
Expand Down Expand Up @@ -180,13 +181,16 @@ def example_6_2():

def figure_6_2():
episodes = 100 + 1
td_erros = batch_updating('TD', episodes)
mc_erros = batch_updating('MC', episodes)

plt.plot(td_erros, label='TD')
plt.plot(mc_erros, label='MC')
plt.xlabel('episodes')
plt.ylabel('RMS error')
td_errors = batch_updating('TD', episodes)
mc_errors = batch_updating('MC', episodes)

plt.plot(td_errors, label='TD')
plt.plot(mc_errors, label='MC')
plt.title("Batch Training")
plt.xlabel('Walks/Episodes')
plt.ylabel('RMS error, averaged over states')
plt.xlim(0, 100)
plt.ylim(0, 0.25)
plt.legend()

plt.savefig('../images/figure_6_2.png')
Expand Down
14 changes: 5 additions & 9 deletions chapter10/mountain_car.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,17 @@ def semi_gradient_n_step_sarsa(value_function, n=1):

if time < T:
# take current action and go to the new state
new_postion, new_velocity, reward = step(current_position, current_velocity, current_action)
new_position, new_velocity, reward = step(current_position, current_velocity, current_action)
# choose new action
new_action = get_action(new_postion, new_velocity, value_function)
new_action = get_action(new_position, new_velocity, value_function)

# track new state and action
positions.append(new_postion)
positions.append(new_position)
velocities.append(new_velocity)
actions.append(new_action)
rewards.append(reward)

if new_postion == POSITION_MAX:
if new_position == POSITION_MAX:
T = time

# get the time of the state to update
Expand All @@ -224,7 +224,7 @@ def semi_gradient_n_step_sarsa(value_function, n=1):
value_function.learn(positions[update_time], velocities[update_time], actions[update_time], returns)
if update_time == T - 1:
break
current_position = new_postion
current_position = new_position
current_velocity = new_velocity
current_action = new_action

Expand Down Expand Up @@ -366,7 +366,3 @@ def figure_10_4():
figure_10_2()
figure_10_3()
figure_10_4()




Binary file modified images/example_6_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/figure_6_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit fbf020d

Please sign in to comment.