Commit

wip
sven1977 committed Jun 7, 2023
1 parent 44d9839 commit eb5194d
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -43,5 +43,5 @@ vim [some path P ...]/supersuit/lambda_wrappers/observation_lambda.py
# Save your changes and exit vim

# Run the Atari example.
-python run_experoment.py -c examples/atari_100k.yaml --env ALE/Pong-v5
+python run_experiment.py -c examples/atari_100k.yaml --env ALE/Pong-v5
```
6 changes: 4 additions & 2 deletions run_experiment.py
@@ -132,7 +132,8 @@
batch_length_T = config.get("batch_length_T", 64)
# The number of timesteps we use to "initialize" (burn-in) a dream_trajectory run.
# For this many timesteps, the posterior (actual observation data) will be used
-# to compute z, after that, only the prior (dynamics network) will be used.
+# to compute z, after that, only the prior (dynamics network) will be used (to compute
+# z^).
burn_in_T = config.get("burn_in_T", 5)
horizon_H = config.get("horizon_H", 15)
assert burn_in_T + horizon_H <= batch_length_T, (
@@ -180,6 +181,7 @@
print(f"Creating initial DreamerModel ...")
model_dimension = config["model_dimension"]
gray_scaled = is_img_space and len(env_runner.env.single_observation_space.shape) == 2
+
world_model = WorldModel(
model_dimension=model_dimension,
action_space=action_space,
@@ -367,7 +369,7 @@

# Add ongoing and finished episodes into buffer. The buffer will automatically
# take care of properly concatenating (by episode IDs) the different chunks of
-# the same episodes, even if they come in in separate `add()` calls.
+# the same episodes, even if they come in via separate `add()` calls.
buffer.add(episodes=done_episodes + ongoing_episodes)

ts_in_buffer = buffer.get_num_timesteps()
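
The comment in the last hunk says the buffer concatenates chunks of the same episode by episode ID, even when they arrive in separate `add()` calls. The sketch below illustrates that idea only. `ToyEpisodeBuffer` and the `(episode_id, timesteps)` chunk format are hypothetical stand-ins, not the repository's actual buffer class or episode type.

```python
from collections import defaultdict


class ToyEpisodeBuffer:
    """Hypothetical stand-in for the replay buffer; not the repository's class."""

    def __init__(self):
        # Maps episode ID -> list of timesteps collected so far.
        self._episodes = defaultdict(list)

    def add(self, episodes):
        # Assumption: each chunk is an (episode_id, timesteps) pair. Chunks that
        # share an ID are appended in arrival order, so a later chunk continues
        # the episode started by an earlier one.
        for episode_id, timesteps in episodes:
            self._episodes[episode_id].extend(timesteps)

    def get_num_timesteps(self):
        # Total number of timesteps stored across all episodes.
        return sum(len(ts) for ts in self._episodes.values())


buffer = ToyEpisodeBuffer()
# First call: "ep0" is finished, "ep1" is still ongoing.
buffer.add(episodes=[("ep0", [0, 1, 2]), ("ep1", [0, 1])])
# Second call: the next chunk of "ep1" is merged into the existing entry.
buffer.add(episodes=[("ep1", [2, 3])])
ts_in_buffer = buffer.get_num_timesteps()
```

Keying on the episode ID means the caller can pass finished and ongoing episodes together, as the `done_episodes + ongoing_episodes` call does, without tracking which chunks belong together.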
