Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686678443
Change-Id: I9e8300bf13004823f9985fe9390056a02c5bee5a
  • Loading branch information
Brax Team authored and btaba committed Oct 16, 2024
1 parent 6a62109 commit c87dcfc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
14 changes: 7 additions & 7 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,13 +385,6 @@ def training_epoch_with_timing(
specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32'))),
env_steps=0)

if num_timesteps == 0:
return (
make_policy,
(training_state.normalizer_params, training_state.params),
{},
)

if (
restore_checkpoint_path is not None
and epath.Path(restore_checkpoint_path).exists()
Expand All @@ -406,6 +399,13 @@ def training_epoch_with_timing(
normalizer_params=normalizer_params, params=init_params
)

if num_timesteps == 0:
return (
make_policy,
(training_state.normalizer_params, training_state.params),
{},
)

training_state = jax.device_put_replicated(
training_state,
jax.local_devices()[:local_devices_to_use])
Expand Down
3 changes: 2 additions & 1 deletion docs/release-notes/next-release.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Brax Release Notes

* Add boolean `wrap_env` to all brax `train` functions, which optionally wraps the env for training, or uses the env as is.
* Add boolean `wrap_env` to all brax `train` functions, which optionally wraps the env for training, or uses the env as is.
* Fix bug in PPO train to return loaded checkpoint when `num_timesteps` is 0.

0 comments on commit c87dcfc

Please sign in to comment.