diff --git a/src/doom/actions.py b/src/doom/actions.py index 835b4dd..63146e7 100644 --- a/src/doom/actions.py +++ b/src/doom/actions.py @@ -88,8 +88,9 @@ def get_action(self, action): for k in self.available_buttons] return doom_action else: - assert type(action) is int - return self.doom_actions[action] + a = action if type(action) == int else action.item() + assert type(a) is int + return self.doom_actions[a] action_categories_discrete = { diff --git a/src/doom/scenarios/deathmatch.py b/src/doom/scenarios/deathmatch.py index e78fcb5..db3d48a 100644 --- a/src/doom/scenarios/deathmatch.py +++ b/src/doom/scenarios/deathmatch.py @@ -173,7 +173,7 @@ def evaluate_deathmatch(game, network, params, n_train_iter=None): # observe the game state / select the next action game.observe_state(params, last_states) - action = network.next_action(last_states) + action = network.next_action(last_states).tolist() pred_features = network.pred_features # game features diff --git a/src/model/bucketed_embedding.py b/src/model/bucketed_embedding.py index f88a4d8..5145105 100644 --- a/src/model/bucketed_embedding.py +++ b/src/model/bucketed_embedding.py @@ -1,7 +1,6 @@ -import torch.nn as nn +import torch - -class BucketedEmbedding(nn.Embedding): +class BucketedEmbedding(torch.nn.Embedding): def __init__(self, bucket_size, num_embeddings, *args, **kwargs): self.bucket_size = bucket_size @@ -9,4 +8,4 @@ def __init__(self, bucket_size, num_embeddings, *args, **kwargs): super(BucketedEmbedding, self).__init__(real_num_embeddings, *args, **kwargs) def forward(self, indices): - return super(BucketedEmbedding, self).forward(indices.div(self.bucket_size)) + return super(BucketedEmbedding, self).forward(indices.div(self.bucket_size).type(torch.LongTensor)) diff --git a/src/model/dqn/base.py b/src/model/dqn/base.py index 6e1c61c..53d540c 100644 --- a/src/model/dqn/base.py +++ b/src/model/dqn/base.py @@ -78,7 +78,7 @@ def base_forward(self, x_screens, x_variables): # create state input if self.n_variables: - output = torch.cat([conv_output] + embeddings, 1) + output = torch.cat([conv_output] + embeddings, dim=1) else: output = conv_output @@ -185,8 +185,8 @@ def prepare_f_train_args(self, screens, variables, features, return screens, variables, features, actions, rewards, isfinal def register_loss(self, loss_history, loss_sc, loss_gf): - loss_history['dqn_loss'].append(loss_sc.data[0]) - loss_history['gf_loss'].append(loss_gf.data[0] + loss_history['dqn_loss'].append(loss_sc.data) + loss_history['gf_loss'].append(loss_gf.data if self.n_features else 0) def next_action(self, last_states, save_graph=False): @@ -205,7 +205,7 @@ def next_action(self, last_states, save_graph=False): if pred_features is not None: assert pred_features.size() == (1, seq_len, self.module.n_features) pred_features = pred_features[0, -1] - action_id = scores.data.max(0)[1][0] + action_id = scores.data.max(0)[1] self.pred_features = pred_features return action_id diff --git a/src/model/dqn/feedforward.py b/src/model/dqn/feedforward.py index e2af5a6..4294430 100644 --- a/src/model/dqn/feedforward.py +++ b/src/model/dqn/feedforward.py @@ -21,10 +21,15 @@ def forward(self, x_screens, x_variables): """ batch_size = x_screens.size(0) + + for x in x_variables: + x.unsqueeze_(0) + assert x_screens.ndimension() == 4 assert len(x_variables) == self.n_variables - assert all(x.ndimension() == 1 and x.size(0) == batch_size - for x in x_variables) + + #assert all(x.ndimension() == 0 and len(list(x.size())) == batch_size + # for x in x_variables) # state input (screen / depth / labels buffer + variables) state_input, output_gf = self.base_forward(x_screens, x_variables) @@ -45,7 +50,6 @@ class DQNFeedforward(DQN): def f_eval(self, last_states): screens, variables = self.prepare_f_eval_args(last_states) - return self.module( screens.view(1, -1, *self.screen_shape[1:]), [variables[-1, i] for i in range(self.params.n_variables)]