Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Jul 15, 2024
2 parents e11cf12 + 83f10f0 commit d2f5ba8
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion lzero/entry/train_unizero.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def train_unizero(
game_buffer_classes[create_cfg.policy.type])

# Set device based on CUDA availability
cfg.policy.device = cfg.policy.model.world_model.device if torch.cuda.is_available() else 'cpu'
cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
logging.info(f'cfg.policy.device: {cfg.policy.device}')

# Compile the configuration
Expand Down
2 changes: 1 addition & 1 deletion lzero/model/unizero_world_models/world_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: tor
if current_obs_embeddings is not None:
if max(buffer_action) == -1:
# First step in an episode
self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n,
self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0],
max_tokens=self.context_length)
# print(f"current_obs_embeddings.device: {current_obs_embeddings.device}")
outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings},
Expand Down
4 changes: 2 additions & 2 deletions zoo/atari/config/atari_unizero_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
max_blocks=num_unroll_steps,
max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action
context_length=2 * infer_context_length,
# device='cuda',
device='cpu',
device='cuda',
# device='cpu',
action_space_size=action_space_size,
num_layers=4,
num_heads=8,
Expand Down

0 comments on commit d2f5ba8

Please sign in to comment.