diff --git a/chapter06/random_walk.py b/chapter06/random_walk.py index 317d572e..e6422c97 100644 --- a/chapter06/random_walk.py +++ b/chapter06/random_walk.py @@ -59,7 +59,7 @@ def temporal_difference(values, alpha=0.1, batch=False): # @batch: whether to update @values def monte_carlo(values, alpha=0.1, batch=False): state = 3 - trajectory = [3] + trajectory = [state] # if end up with left terminal state, all returns are 0 # if end up with right terminal state, all returns are 1 @@ -89,11 +89,11 @@ def compute_state_value(): plt.figure(1) for i in range(episodes[-1] + 1): if i in episodes: - plt.plot(current_values, label=str(i) + ' episodes') + plt.plot(("A", "B", "C", "D", "E"), current_values[1:6], label=str(i) + ' episodes') temporal_difference(current_values) - plt.plot(TRUE_VALUE, label='true values') - plt.xlabel('state') - plt.ylabel('estimated value') + plt.plot(("A", "B", "C", "D", "E"), TRUE_VALUE[1:6], label='true values') + plt.xlabel('State') + plt.ylabel('Estimated Value') plt.legend() # Example 6.2 right @@ -122,9 +122,9 @@ def rms_error(): monte_carlo(current_values, alpha=alpha) total_errors += np.asarray(errors) total_errors /= runs - plt.plot(total_errors, linestyle=linestyle, label=method + ', alpha = %.02f' % (alpha)) - plt.xlabel('episodes') - plt.ylabel('RMS') + plt.plot(total_errors, linestyle=linestyle, label=method + ', $\\alpha$ = %.02f' % (alpha)) + plt.xlabel('Walks/Episodes') + plt.ylabel('Empirical RMS error, averaged over states') plt.legend() # Figure 6.2 @@ -135,6 +135,7 @@ def batch_updating(method, episodes, alpha=0.001): total_errors = np.zeros(episodes) for r in tqdm(range(0, runs)): current_values = np.copy(VALUES) + current_values[1:6] = -1 errors = [] # track shown trajectories and reward/return sequences trajectories = [] @@ -180,13 +181,16 @@ def example_6_2(): def figure_6_2(): episodes = 100 + 1 - td_erros = batch_updating('TD', episodes) - mc_erros = batch_updating('MC', episodes) - - plt.plot(td_erros, label='TD') - plt.plot(mc_erros, label='MC') - plt.xlabel('episodes') - plt.ylabel('RMS error') + td_errors = batch_updating('TD', episodes) + mc_errors = batch_updating('MC', episodes) + + plt.plot(td_errors, label='TD') + plt.plot(mc_errors, label='MC') + plt.title("Batch Training") + plt.xlabel('Walks/Episodes') + plt.ylabel('RMS error, averaged over states') + plt.xlim(0, 100) + plt.ylim(0, 0.25) plt.legend() plt.savefig('../images/figure_6_2.png') diff --git a/chapter10/mountain_car.py b/chapter10/mountain_car.py index 2fa0bd2a..bf614477 100644 --- a/chapter10/mountain_car.py +++ b/chapter10/mountain_car.py @@ -194,17 +194,17 @@ def semi_gradient_n_step_sarsa(value_function, n=1): if time < T: # take current action and go to the new state - new_postion, new_velocity, reward = step(current_position, current_velocity, current_action) + new_position, new_velocity, reward = step(current_position, current_velocity, current_action) # choose new action - new_action = get_action(new_postion, new_velocity, value_function) + new_action = get_action(new_position, new_velocity, value_function) # track new state and action - positions.append(new_postion) + positions.append(new_position) velocities.append(new_velocity) actions.append(new_action) rewards.append(reward) - if new_postion == POSITION_MAX: + if new_position == POSITION_MAX: T = time # get the time of the state to update @@ -224,7 +224,7 @@ def semi_gradient_n_step_sarsa(value_function, n=1): value_function.learn(positions[update_time], velocities[update_time], actions[update_time], returns) if update_time == T - 1: break - current_position = new_postion + current_position = new_position current_velocity = new_velocity current_action = new_action @@ -366,7 +366,3 @@ def figure_10_4(): figure_10_2() figure_10_3() figure_10_4() - - - - diff --git a/images/example_6_2.png b/images/example_6_2.png index 9eb9e58d..fa972c96 100644 Binary files a/images/example_6_2.png and b/images/example_6_2.png differ diff --git a/images/figure_6_2.png b/images/figure_6_2.png index 4346928c..00bc1755 100644 Binary files a/images/figure_6_2.png and b/images/figure_6_2.png differ