Skip to content

Commit

Permalink
Fixed SAC weight loading crash.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Aug 29, 2023
1 parent e7e4251 commit efb2035
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 10 deletions.
4 changes: 4 additions & 0 deletions rl_games/algos_torch/players.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def rescale_actions(low, high, action):


class PpoPlayerContinuous(BasePlayer):

def __init__(self, params):
BasePlayer.__init__(self, params)
self.network = self.config['network']
Expand Down Expand Up @@ -81,7 +82,9 @@ def restore(self, fn):
def reset(self):
    """Reset the player between episodes.

    Delegates to ``init_rnn()`` — presumably provided by ``BasePlayer``
    (not visible in this view) — to re-initialize the recurrent state.
    """
    self.init_rnn()


class PpoPlayerDiscrete(BasePlayer):

def __init__(self, params):
BasePlayer.__init__(self, params)

Expand Down Expand Up @@ -185,6 +188,7 @@ def reset(self):


class SACPlayer(BasePlayer):

def __init__(self, params):
BasePlayer.__init__(self, params)
self.network = self.config['network']
Expand Down
16 changes: 11 additions & 5 deletions rl_games/algos_torch/sac_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,11 @@ def play_steps(self, random_exploration = False):
critic2_losses = []

obs = self.obs
if isinstance(obs, dict):
obs = self.obs['obs']

next_obs_processed = obs.clone()

for s in range(self.num_steps_per_episode):
self.set_eval()
if random_exploration:
Expand Down Expand Up @@ -480,16 +485,17 @@ def play_steps(self, random_exploration = False):
self.current_rewards = self.current_rewards * not_dones
self.current_lengths = self.current_lengths * not_dones

if isinstance(obs, dict):
obs = obs['obs']
if isinstance(next_obs, dict):
next_obs = next_obs['obs']
next_obs_processed = next_obs['obs']

self.obs = next_obs.clone()

rewards = self.rewards_shaper(rewards)

self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs, torch.unsqueeze(dones, 1))
self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs_processed, torch.unsqueeze(dones, 1))

self.obs = obs = next_obs.clone()
if isinstance(obs, dict):
obs = self.obs['obs']

if not random_exploration:
self.set_train()
Expand Down
8 changes: 5 additions & 3 deletions rl_games/torch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __init__(self, algo_observer=None):
#torch.backends.cudnn.deterministic = True
#torch.use_deterministic_algorithms(True)

#breakpoint()

def reset(self):
    """Intentionally a no-op; returns ``None``."""

Expand Down Expand Up @@ -131,12 +133,12 @@ def reset(self):
pass

def run(self, args):
    """Dispatch to training or playing based on the flags in *args*.

    Args:
        args: mapping with boolean ``'train'`` and ``'play'`` entries.

    Side effects:
        Prints which mode was started and invokes ``self.run_train`` or
        ``self.run_play`` accordingly.
    """
    # NOTE(review): the original assigned ``load_path = None`` here but
    # never read it inside this method — removed as dead code.
    if args['train']:
        print('Started to train')
        self.run_train(args)
    elif args['play']:
        print('Started to play')
        self.run_play(args)
    else:
        # Neither flag set: fall back to training (message kept verbatim).
        print('Started to train2')
        self.run_train(args)
2 changes: 0 additions & 2 deletions runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,9 @@
except yaml.YAMLError as exc:
print(exc)

#rank = int(os.getenv("LOCAL_RANK", "0"))
global_rank = int(os.getenv("RANK", "0"))
if args["track"] and global_rank == 0:
import wandb

wandb.init(
project=args["wandb_project_name"],
entity=args["wandb_entity"],
Expand Down

0 comments on commit efb2035

Please sign in to comment.