Update to support Gymnasium #277
@@ -71,7 +71,7 @@ def thunk():
     if capture_video:
         if idx == 0:
             env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
-    env.seed(seed)
     env.action_space.seed(seed)
     env.observation_space.seed(seed)

Review comment: I think

     return env
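For context, a minimal sketch (mine, not a line from this PR) of how the removed call maps onto Gymnasium's API: Env.seed() was dropped upstream, so the seed is passed to the first reset() instead, while the space-seeding calls stay the same.

# Sketch only: Gymnasium seeds through reset(); the space seeding below is unchanged.
obs, info = env.reset(seed=seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)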
@@ -160,12 +160,12 @@ def forward(self, x):
         envs.single_observation_space,
         envs.single_action_space,
         device,
-        handle_timeout_termination=True,
+        handle_timeout_termination=False,
     )
     start_time = time.time()

     # TRY NOT TO MODIFY: start the game
-    obs = envs.reset()
+    obs, _ = envs.reset(seed=args.seed)
     for global_step in range(args.total_timesteps):
         # ALGO LOGIC: put action logic here
         if global_step < args.learning_starts:
@@ -177,7 +177,7 @@ def forward(self, x):
             actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)

         # TRY NOT TO MODIFY: execute the game and log data.
-        next_obs, rewards, dones, infos = envs.step(actions)
+        next_obs, rewards, terminateds, _, infos = envs.step(actions)

         # TRY NOT TO MODIFY: record rewards for plotting purposes
         for info in infos:
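As a side note (my assumption about the new API, not part of the diff): Gymnasium's vectorized step() returns five arrays, so the discarded underscore above holds the truncation flags, roughly:

# Sketch: the value bound to `_` in the diff is the per-env truncation array.
next_obs, rewards, terminateds, truncations, infos = envs.step(actions)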
@@ -187,12 +187,12 @@ def forward(self, x):
                 writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                 break

Review comment: Making the assumption that there will be no parallel env, this could work:

if "final_info" in infos:
    info = infos["final_info"][0]
    print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
    writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
    writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

But I have seen that there is a different solution in the DQN file.
-        # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
+        # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
         real_next_obs = next_obs.copy()
-        for idx, d in enumerate(dones):
+        for idx, d in enumerate(terminateds):

Review comment: Should it use … With …

Reply: Yes, this was a mistake, it should be …

             if d:
-                real_next_obs[idx] = infos[idx]["terminal_observation"]
-        rb.add(obs, real_next_obs, actions, rewards, dones, infos)
+                real_next_obs[idx] = infos[idx]["final_observation"]
+        rb.add(obs, real_next_obs, actions, rewards, terminateds, _, infos)

         # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
         obs = next_obs
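The reply above says the extra argument was a mistake. Assuming the call targets stable-baselines3's ReplayBuffer.add(obs, next_obs, action, reward, done, infos) signature, the corrected line would presumably drop the stray underscore:

# Hedged guess at the fixed call; not copied from a later commit.
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)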
@@ -204,7 +204,7 @@ def forward(self, x):
                 next_state_actions = target_actor(data.next_observations)
                 qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                 next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (qf1_next_target).view(-1)
+                # TODO: to be updated to data.terminateds once SB3 is updated

             qf1_a_values = qf1(data.observations, data.actions).view(-1)
             qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
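A hedged reading of that TODO: once terminations and truncations are stored separately, only true terminations should cut off bootstrapping, roughly as below (data.terminateds is the hypothetical field the TODO names, not a current SB3 attribute):

# Sketch only: keep bootstrapping through time-limit truncations, stop at real terminations.
next_q_value = data.rewards.flatten() + (1 - data.terminateds.flatten()) * args.gamma * qf1_next_target.view(-1)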
Review comment: Perfect!