from ExperienceReplay import ReplayMemory
import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os


# create the deep Q-network
class DeepQNetwork(nn.Module):
    def __init__(self, input_dims, n_actions, lr, name, checkpoint_dir):
        super(DeepQNetwork, self).__init__()
        self.input_dims = input_dims
        self.n_actions = n_actions
        self.lr = lr
        # the 3 convolutional layers
        self.Conv1 = nn.Conv2d(input_dims[0], 32, 8, stride=4)
        self.Conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.Conv3 = nn.Conv2d(64, 64, 3, stride=1)
        length = self.calculate_conv_output_dims(self.input_dims)
        # the 2 linear layers
        self.fc1 = nn.Linear(length, 512)
        self.fc2 = nn.Linear(512, self.n_actions)
        self.loss = nn.MSELoss()
        self.optimizer = optim.RMSprop(self.parameters(), lr=lr)
        # fall back to CPU when CUDA is unavailable
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name)

    def calculate_conv_output_dims(self, input_dims):
        # pass a dummy observation through the conv stack to get the flattened size
        state = T.zeros(1, *input_dims)
        dims = self.Conv1(state)
        dims = self.Conv2(dims)
        dims = self.Conv3(dims)
        return int(np.prod(dims.size()))

    def forward(self, state):
        # takes a batch of stacked frames as a torch tensor, applies
        # 3 ReLU-activated conv layers followed by 2 fully connected layers
        conv1 = F.relu(self.Conv1(state))
        conv2 = F.relu(self.Conv2(conv1))
        conv3 = F.relu(self.Conv3(conv2))
        # flatten before the fully connected layers
        flat = conv3.view(conv3.size()[0], -1)
        layer1 = F.relu(self.fc1(flat))
        actions = self.fc2(layer1)
        return actions

    def save_checkpoint(self):
        print('****saving****')
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        print('****loading****')
        self.load_state_dict(T.load(self.checkpoint_file))


# Agent class
class Agent():
    def __init__(self, input_dims, n_actions, lr, gamma, eps, eps_dec,
                 eps_min, replace, mem_size, minibatch_size, checkpoint_dir,
                 algo, env_name):
        self.input_dims = input_dims
        self.n_actions = n_actions
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_dec = eps_dec
        self.eps_min = eps_min
        self.replace = replace
        self.minibatch_size = minibatch_size
        self.action_space = [i for i in range(self.n_actions)]
        self.checkpoint_dir = checkpoint_dir
        self.env_name = env_name
        self.algo = algo
        self.memory = ReplayMemory(input_shape=input_dims, n_actions=n_actions,
                                   max_mem_len=mem_size)
        self.learn_step_counter = 0
        # online network (Q_eval) and target network (Q_next)
        self.Q_eval = DeepQNetwork(input_dims=input_dims, n_actions=n_actions, lr=lr,
                                   name=self.env_name+'_'+self.algo+'_q_eval',
                                   checkpoint_dir=self.checkpoint_dir)
        self.Q_next = DeepQNetwork(input_dims=input_dims, n_actions=n_actions, lr=lr,
                                   name=self.env_name+'_'+self.algo+'_q_next',
                                   checkpoint_dir=self.checkpoint_dir)

    # learn function
    def learn(self):
        # wait until enough transitions are stored for a full minibatch
        if self.memory.next_available_memory < self.minibatch_size:
            return
        # zero the gradients
        self.Q_eval.optimizer.zero_grad()
        # see if the target network needs to be replaced
        self.replace_target_network()
        # sample our minibatch
        states, actions, rewards, _states, dones = self.get_minibatch()
        # q_eval: Q(s, a) from the online network for the actions actually taken
        dim_fixer = np.arange(self.minibatch_size)
        q_eval = self.Q_eval.forward(states)[dim_fixer, actions]
        # q_next: max_a' Q(s', a') from the target network
        q_next = self.Q_next.forward(_states).max(dim=1)[0]
        q_next[dones] = 0.0  # if the state is terminal, the next value doesn't matter
        # q_target: r + gamma * max_a' Q(s', a')
        q_target = rewards + self.gamma * q_next
        # loss: (q_target - q_eval)^2
        loss = self.Q_eval.loss(q_target, q_eval).to(self.Q_eval.device)
        loss.backward()
        # step the optimizer
        self.Q_eval.optimizer.step()
        # increment our step counter and decay epsilon
        self.learn_step_counter += 1
        self.decrement_epsilon()

    def store_transition(self, state, action, reward, _state, done):
        self.memory.store_transition(state, action, reward, _state, done)

    def get_minibatch(self):
        state, action, reward, _state, done = self.memory.minibatch(self.minibatch_size)
        # turn the sampled numpy arrays into torch tensors on the network's device
        states = T.tensor(state).to(self.Q_eval.device)
        actions = T.tensor(action).to(self.Q_eval.device)
        rewards = T.tensor(reward).to(self.Q_eval.device)
        _states = T.tensor(_state).to(self.Q_eval.device)
        dones = T.tensor(done).to(self.Q_eval.device)
        return states, actions, rewards, _states, dones

    # choose an action using an epsilon-greedy policy
    def choose_action(self, state):
        if np.random.random() > self.eps:
            state = T.tensor([state], dtype=T.float32).to(self.Q_eval.device)
            actions = self.Q_eval.forward(state)
            action = T.argmax(actions).item()
        else:
            action = np.random.choice(self.action_space)
        return action

    # decrement epsilon down to its minimum
    def decrement_epsilon(self):
        if self.eps > self.eps_min + self.eps_dec:
            self.eps -= self.eps_dec
        else:
            self.eps = self.eps_min

    # replace the target network's weights with the online network's
    def replace_target_network(self):
        if self.learn_step_counter % self.replace == 0:
            print('replaced')
            self.Q_next.load_state_dict(self.Q_eval.state_dict())

    def load_models(self):
        self.Q_eval.load_checkpoint()
        self.Q_next.load_checkpoint()

    def save_models(self):
        self.Q_eval.save_checkpoint()
        self.Q_next.save_checkpoint()
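
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal training loop
# showing how the Agent is intended to be driven. `make_env` is a hypothetical
# helper assumed to return a Gym-style environment whose observations are
# already preprocessed and frame-stacked to match input_dims (e.g. (4, 84, 84)
# for Atari); such wrappers are not defined in this file. The classic Gym
# reset/step API (4-tuple step return) is also assumed.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    env = make_env('PongNoFrameskip-v4')  # hypothetical helper, not defined here
    agent = Agent(input_dims=(4, 84, 84), n_actions=env.action_space.n,
                  lr=0.0001, gamma=0.99, eps=1.0, eps_dec=1e-5, eps_min=0.1,
                  replace=1000, mem_size=50000, minibatch_size=32,
                  checkpoint_dir='models/', algo='DQNAgent',
                  env_name='PongNoFrameskip-v4')

    for episode in range(500):
        state = env.reset()
        done = False
        score = 0
        while not done:
            action = agent.choose_action(state)
            _state, reward, done, info = env.step(action)
            agent.store_transition(state, action, reward, _state, done)
            agent.learn()
            state = _state
            score += reward
        print('episode', episode, 'score', score, 'epsilon', agent.eps)

    agent.save_models()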