diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py
index 9120bb94..3dfcdadd 100644
--- a/rl_games/algos_torch/players.py
+++ b/rl_games/algos_torch/players.py
@@ -16,6 +16,7 @@ def rescale_actions(low, high, action):
 
 
 class PpoPlayerContinuous(BasePlayer):
+
     def __init__(self, params):
         BasePlayer.__init__(self, params)
         self.network = self.config['network']
@@ -81,7 +82,9 @@ def restore(self, fn):
     def reset(self):
         self.init_rnn()
 
+
 class PpoPlayerDiscrete(BasePlayer):
+
     def __init__(self, params):
         BasePlayer.__init__(self, params)
 
@@ -185,6 +188,7 @@ def reset(self):
 
 
 class SACPlayer(BasePlayer):
+
     def __init__(self, params):
         BasePlayer.__init__(self, params)
         self.network = self.config['network']
diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py
index f796a79f..90f878d0 100644
--- a/rl_games/algos_torch/sac_agent.py
+++ b/rl_games/algos_torch/sac_agent.py
@@ -445,6 +445,11 @@ def play_steps(self, random_exploration = False):
         critic2_losses = []
 
         obs = self.obs
+        if isinstance(obs, dict):
+            obs = self.obs['obs']
+
+        next_obs_processed = obs.clone()
+
         for s in range(self.num_steps_per_episode):
             self.set_eval()
             if random_exploration:
@@ -480,16 +485,17 @@ def play_steps(self, random_exploration = False):
             self.current_rewards = self.current_rewards * not_dones
             self.current_lengths = self.current_lengths * not_dones
 
-            if isinstance(obs, dict):
-                obs = obs['obs']
             if isinstance(next_obs, dict):
-                next_obs = next_obs['obs']
+                next_obs_processed = next_obs['obs']
+
+            self.obs = next_obs.clone()
 
             rewards = self.rewards_shaper(rewards)
 
-            self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs, torch.unsqueeze(dones, 1))
+            self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs_processed, torch.unsqueeze(dones, 1))
 
-            self.obs = obs = next_obs.clone()
+            if isinstance(obs, dict):
+                obs = self.obs['obs']
 
             if not random_exploration:
                 self.set_train()
diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py
index e8ec70ee..4c4de4dd 100644
--- a/rl_games/torch_runner.py
+++ b/rl_games/torch_runner.py
@@ -53,6 +53,8 @@ def __init__(self, algo_observer=None):
         #torch.backends.cudnn.deterministic = True
         #torch.use_deterministic_algorithms(True)
 
+        #breakpoint()
+
     def reset(self):
         pass
 
@@ -131,12 +133,12 @@ def reset(self):
         pass
 
     def run(self, args):
-        load_path = None
-
         if args['train']:
+            print('Started to train')
             self.run_train(args)
-
         elif args['play']:
+            print('Started to play')
             self.run_play(args)
         else:
+            print('Started to train2')
             self.run_train(args)
\ No newline at end of file
diff --git a/runner.py b/runner.py
index 25f79af4..c1188751 100644
--- a/runner.py
+++ b/runner.py
@@ -50,11 +50,9 @@
         except yaml.YAMLError as exc:
             print(exc)
 
-    #rank = int(os.getenv("LOCAL_RANK", "0"))
     global_rank = int(os.getenv("RANK", "0"))
     if args["track"] and global_rank == 0:
         import wandb
-
         wandb.init(
             project=args["wandb_project_name"],
             entity=args["wandb_entity"],
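
Note on the sac_agent.py hunks: play_steps now keeps the raw environment observation (which may be a dict) in self.obs for the next rollout step, while only the unwrapped 'obs' tensor is pushed into the replay buffer. The snippet below is a minimal standalone sketch of that pattern, not rl_games code; DummyBuffer and store_transition are hypothetical stand-ins introduced only for illustration.

# Standalone sketch of the dict-observation handling introduced above.
# DummyBuffer and store_transition are hypothetical; only the unwrap/store
# logic mirrors the diff.
import torch


class DummyBuffer:
    """Minimal stand-in for the SAC replay buffer used in the diff."""

    def __init__(self):
        self.transitions = []

    def add(self, obs, action, reward, next_obs, done):
        # Only plain tensors are stored here, never raw observation dicts.
        self.transitions.append((obs, action, reward, next_obs, done))


def store_transition(buffer, obs, action, rewards, next_obs, dones):
    # Unwrap the 'obs' tensor before it goes into the buffer, but return the
    # raw next_obs so the caller can keep it (as self.obs) for the next step.
    obs_processed = obs['obs'] if isinstance(obs, dict) else obs
    next_obs_processed = next_obs['obs'] if isinstance(next_obs, dict) else next_obs
    buffer.add(obs_processed, action,
               torch.unsqueeze(rewards, 1),
               next_obs_processed,
               torch.unsqueeze(dones, 1))
    return next_obs


if __name__ == '__main__':
    buf = DummyBuffer()
    obs = {'obs': torch.zeros(4, 3)}
    next_obs = {'obs': torch.ones(4, 3)}
    action = torch.zeros(4, 1)
    rewards = torch.ones(4)
    dones = torch.zeros(4)
    store_transition(buf, obs, action, rewards, next_obs, dones)
    print(len(buf.transitions))  # 1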