"""Train (or evaluate) a DQN agent on Atari Pong and plot its learning curve.

Set ``load_checkpoint = True`` to load saved weights from ``tmp/`` and watch
the greedy agent play (rendering on, no learning, no checkpoint writes).
Otherwise a fresh agent is trained for ``n_games`` episodes and checkpoints
are written periodically and at the end.
"""
from DeepQNetwork import Agent
from GymWrapper import make_env
import numpy as np
import matplotlib.pyplot as plt

# Hyperparameters shared by the training and the evaluation agent.
_COMMON_AGENT_KWARGS = dict(
    lr=0.0001,
    gamma=0.99,
    replace=1000,
    minibatch_size=32,
    checkpoint_dir='tmp/',
    algo='DQN',
)


def _build_agent(env, env_name, load_checkpoint):
    """Construct the Agent for evaluation (greedy) or training (epsilon-greedy).

    Args:
        env: environment returned by ``make_env``; its observation/action
            spaces size the network.
        env_name: environment id, forwarded to the Agent.
        load_checkpoint: True for evaluation settings, False for training.

    Returns:
        A new, untrained ``Agent`` (weights are loaded by the caller).
    """
    if load_checkpoint:
        # Evaluation: epsilon pinned at 0 so the agent always acts greedily.
        exploration = dict(eps=0.0, eps_dec=0.0, eps_min=0.0, mem_size=20000)
    else:
        # Training: epsilon decays linearly from 1 toward 0.1.
        exploration = dict(eps=1, eps_dec=1e-5, eps_min=0.1, mem_size=30000)
    return Agent(
        input_dims=env.observation_space.shape,
        n_actions=env.action_space.n,
        env_name=env_name,
        **_COMMON_AGENT_KWARGS,
        **exploration,
    )


def _plot_learning_curve(scores, avg_scores, epsilon_history):
    """Plot per-episode score and running average (left axis) and epsilon (right axis)."""
    x = list(range(1, len(scores) + 1))

    fig, ax1 = plt.subplots()
    ax1.set_xlabel('episode')
    color = 'tab:red'
    ax1.set_ylabel('score', color=color)
    ax1.plot(x, scores, color=color, label='score')
    ax1.plot(x, avg_scores, color='tab:green', label='average score (last 100)')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')

    ax2 = ax1.twinx()  # second y-axis sharing the episode x-axis
    color = 'tab:blue'
    ax2.set_ylabel('epsilon', color=color)
    ax2.plot(x, epsilon_history, color=color)
    ax2.tick_params(axis='y', labelcolor=color)

    fig.tight_layout()  # otherwise the right y-label is slightly clipped
    plt.show()


if __name__ == '__main__':
    name = 'PongNoFrameskip-v4'
    env = make_env(name)
    n_games = 500
    save_every = 100  # write a checkpoint every this many training episodes
    load_checkpoint = False

    agent = _build_agent(env, name, load_checkpoint)
    print("agent initialized")
    if load_checkpoint:
        agent.load_models()

    scores = []
    avg_scores = []
    epsilon_history = []

    try:
        for episode in range(1, n_games + 1):
            score = 0
            state = env.reset()
            done = False
            while not done:
                if load_checkpoint:
                    env.render()
                action = agent.choose_action(state)
                next_state, reward, done, info = env.step(action)
                # Accumulate the score in both modes so evaluation reports it too.
                score += reward
                if not load_checkpoint:
                    agent.store_transition(state, action, reward, next_state, int(done))
                    agent.learn()
                state = next_state

            scores.append(score)
            epsilon_history.append(agent.eps)
            avg_score = np.mean(scores[-100:])  # running mean over the last 100 episodes
            avg_scores.append(avg_score)
            print('episode %d' % episode, 'score: %.2f' % score,
                  'average score: %.2f' % avg_score, 'epsilon %.2f' % agent.eps)

            # Periodic checkpoint; skipped in evaluation mode where weights never change.
            if not load_checkpoint and episode % save_every == 0:
                agent.save_models()
    finally:
        env.close()

    if not load_checkpoint:
        agent.save_models()

    print("trying to display graphs")
    _plot_learning_curve(scores, avg_scores, epsilon_history)
    print("graphs displayed")