From 4b11c16fd8b353e115fe7cded8db7c71f0d2f0eb Mon Sep 17 00:00:00 2001 From: Rishabh Patra Date: Wed, 21 Oct 2020 21:47:57 +0530 Subject: [PATCH] Adding device --- genrl/agents/modelbased/cem/cem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/genrl/agents/modelbased/cem/cem.py b/genrl/agents/modelbased/cem/cem.py index a90af5e7..5451bb96 100644 --- a/genrl/agents/modelbased/cem/cem.py +++ b/genrl/agents/modelbased/cem/cem.py @@ -60,7 +60,7 @@ def _create_model(self): "V", discrete, action_lim, - ) + ).to(self.device) self.optim = torch.optim.Adam(self.agent.parameters(), lr=self.lr_policy) def plan(self): @@ -136,7 +136,7 @@ def update_params(self): elite_states, elite_actions = self.select_elites( batch_states, batch_actions, batch_rewards ) - action_probs = self.agent.forward(elite_states.float()) + action_probs = self.agent.forward(elite_states.float().to(self.device)) loss = F.cross_entropy( action_probs.view(-1, self.action_dim), elite_actions.long().view(-1),