Update to support Gymnasium #277

Closed

wants to merge 19 commits
Changes from 4 commits
17 changes: 9 additions & 8 deletions cleanrl/c51.py
@@ -77,7 +77,7 @@ def thunk():
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)

env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
@@ -159,12 +159,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
handle_timeout_termination=False,
Owner commented:

Perfect!

)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -175,7 +175,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = actions.cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, _, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
@@ -186,12 +186,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(terminateds):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos[idx]["final_observation"]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -202,6 +202,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
with torch.no_grad():
_, next_pmfs = target_network.get_action(data.next_observations)
next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones)
# TODO: to be updated to data.terminateds once SB3 is updated
# projection
delta_z = target_network.atoms[1] - target_network.atoms[0]
tz = next_atoms.clamp(args.v_min, args.v_max)
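The seeding change in this file follows the Gymnasium API: env.seed(seed) no longer exists, and the seed is passed to the first reset call instead, which now returns an (observation, info) pair. Below is a minimal sketch of that pattern, assuming gymnasium is installed; the env id and wrapper list are illustrative and only modeled on CleanRL's make_env thunk.

import gymnasium as gym

def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video and idx == 0:
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        # env.seed(seed) is gone; the environment is seeded on its first reset instead.
        env.action_space.seed(seed)
        return env
    return thunk

envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1", 1, 0, False, "demo")])
obs, info = envs.reset(seed=1)  # reset now returns (observation, info)
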
16 changes: 8 additions & 8 deletions cleanrl/c51_atari.py
@@ -93,7 +93,6 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
@@ -181,12 +180,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
envs.single_action_space,
device,
optimize_memory_usage=True,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -197,7 +196,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = actions.cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, _, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
@@ -208,12 +207,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(terminateds):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos[idx]["final_observation"]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -224,6 +223,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
with torch.no_grad():
_, next_pmfs = target_network.get_action(data.next_observations)
next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones)
# TODO: to be updated to data.terminateds once SB3 is updated
# projection
delta_z = target_network.atoms[1] - target_network.atoms[0]
tz = next_atoms.clamp(args.v_min, args.v_max)
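Gymnasium's step returns five values instead of four: the old done flag is split into terminated (the MDP reached a terminal state) and truncated (a time-limit cutoff). The hunks above keep only the termination flags and drop the truncations into an underscore, which the review comments further down revisit. A small self-contained sketch of the new signature, using an illustrative CartPole vector env:

import gymnasium as gym
import numpy as np

envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
obs, infos = envs.reset(seed=0)
actions = envs.action_space.sample()

# 5-tuple step API; each array has shape (num_envs,)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# terminations: the MDP reached a terminal state, so the bootstrap value should be zero.
# truncations: the episode was cut off (e.g. by a TimeLimit), so bootstrapping should continue.
episode_over = np.logical_or(terminations, truncations)  # a single flag, if one is needed
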
18 changes: 9 additions & 9 deletions cleanrl/ddpg_continuous_action.py
@@ -71,7 +71,7 @@ def thunk():
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)

env.action_space.seed(seed)
env.observation_space.seed(seed)
Contributor commented:

I think env.observation_space.seed(seed) can be removed.

return env
@@ -160,12 +160,12 @@ def forward(self, x):
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
@@ -177,7 +177,7 @@ def forward(self, x):
actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, _, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
Expand All @@ -187,12 +187,12 @@ def forward(self, x):
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break
Contributor @vcharraut commented (Mar 29, 2023):

Making the assumption that there will be no parallel envs, this could work:

        if "final_info" in infos:
            info = infos["final_info"][0]
            print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
            writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
            writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

But I have seen that there is a different solution in the DQN file.


# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(terminateds):
Contributor @vcharraut commented (Mar 29, 2023):

Should it use truncated instead of terminated here?

With truncated, the results are identical with the same seeding between the old and new implementations.

Contributor Author commented:

Yes, this was a mistake; it should be truncated.

if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos[idx]["final_observation"]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -204,7 +204,7 @@ def forward(self, x):
next_state_actions = target_actor(data.next_observations)
qf1_next_target = qf1_target(data.next_observations, next_state_actions)
next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (qf1_next_target).view(-1)

# TODO: to be updated to data.terminateds once SB3 is updated
qf1_a_values = qf1(data.observations, data.actions).view(-1)
qf1_loss = F.mse_loss(qf1_a_values, next_q_value)

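Following the terminated-vs-truncated exchange above, the insertion loop this file is converging on recovers the real last observation only for truncated episodes, while the termination flags alone decide when bootstrapping stops (which is also why handle_timeout_termination is now False). A hedged sketch of that loop, assuming gymnasium's dict-style vector-env infos with a final_observation entry; the diff at this commit still indexes infos per environment.

# Continuing from the Gymnasium 5-tuple step sketched earlier:
# next_obs, rewards, terminations, truncations, infos = envs.step(actions)
real_next_obs = next_obs.copy()
for idx, trunc in enumerate(truncations):
    if trunc:
        # The autoreset vector env has already swapped next_obs[idx] for the first
        # observation of the next episode; the true last observation lives in infos.
        real_next_obs[idx] = infos["final_observation"][idx]
# rb is the SB3 ReplayBuffer created with handle_timeout_termination=False.
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
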
22 changes: 11 additions & 11 deletions cleanrl/ddpg_continuous_action_jax.py
@@ -70,7 +70,7 @@ def thunk():
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)

env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
@@ -150,11 +150,11 @@ class TrainState(TrainState):
envs.single_observation_space,
envs.single_action_space,
device="cpu",
handle_timeout_termination=True,
handle_timeout_termination=False,
)

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
action_scale = np.array((envs.action_space.high - envs.action_space.low) / 2.0)
action_bias = np.array((envs.action_space.high + envs.action_space.low) / 2.0)
actor = Actor(
@@ -186,11 +186,11 @@ def update_critic(
actions: np.ndarray,
next_observations: np.ndarray,
rewards: np.ndarray,
dones: np.ndarray,
terminateds: np.ndarray,
):
next_state_actions = (actor.apply(actor_state.target_params, next_observations)).clip(-1, 1) # TODO: proper clip
qf1_next_target = qf1.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1)
next_q_value = (rewards + (1 - dones) * args.gamma * (qf1_next_target)).reshape(-1)
next_q_value = (rewards + (1 - terminateds) * args.gamma * (qf1_next_target)).reshape(-1)

def mse_loss(params):
qf1_a_values = qf1.apply(params, observations, actions).squeeze()
@@ -235,7 +235,7 @@ def actor_loss(params):
)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, _, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
@@ -245,12 +245,12 @@ def actor_loss(params):
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(terminateds):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos[idx]["final_observation"]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -265,7 +265,7 @@ def actor_loss(params):
data.actions.numpy(),
data.next_observations.numpy(),
data.rewards.flatten().numpy(),
data.dones.flatten().numpy(),
data.dones.flatten().numpy(), # TODO: to be updated to data.terminateds once SB3 is updated
)
if global_step % args.policy_frequency == 0:
actor_state, qf1_state, actor_loss_value = update_actor(
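The functional critic update above now takes the termination flags, and the TODO notes that SB3's replay buffer still exposes them as data.dones. The point of the split is that only true terminations should zero the bootstrap term; a minimal sketch with illustrative NumPy values:

import numpy as np

def td_target(rewards, q_next, terminations, gamma=0.99):
    # One-step TD target: the bootstrap term survives truncated (time-limited)
    # transitions and is dropped only when the episode actually terminated.
    return rewards + gamma * q_next * (1.0 - terminations)

print(td_target(np.array([1.0, 1.0]), np.array([5.0, 5.0]), np.array([0.0, 1.0])))
# -> [5.95 1.  ]
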
16 changes: 8 additions & 8 deletions cleanrl/dqn.py
@@ -72,7 +72,6 @@ def thunk():
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
@@ -144,12 +143,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -160,7 +159,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, _, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
@@ -171,12 +170,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(terminateds):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos[idx]["final_observation"]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -187,6 +186,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
with torch.no_grad():
target_max, _ = target_network(data.next_observations).max(dim=1)
td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
# TODO: to be updated to data.terminateds once SB3 is updated
old_val = q_network(data.observations).gather(1, data.actions).squeeze()
loss = F.mse_loss(td_target, old_val)

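The review thread on the DDPG file notes that its single-env episodic-return logging differs from the approach taken in this DQN file. For reference, a hedged sketch of the vectorized pattern, assuming gymnasium vector envs that expose a final_info list in infos whenever a sub-environment finishes; writer and global_step are the objects already used in these scripts.

if "final_info" in infos:
    for info in infos["final_info"]:
        # Entries are None for sub-environments that did not finish on this step.
        if info is None or "episode" not in info:
            continue
        print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
        writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
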
17 changes: 9 additions & 8 deletions cleanrl/dqn_atari.py
@@ -88,7 +88,7 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.seed(seed)

env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
@@ -166,12 +166,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
envs.single_action_space,
device,
optimize_memory_usage=True,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -182,7 +182,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, _, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
@@ -193,12 +193,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(terminateds):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos[idx]["final_observation"]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -209,6 +209,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
with torch.no_grad():
target_max, _ = target_network(data.next_observations).max(dim=1)
td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
# TODO: to be updated to data.terminateds once SB3 is updated
old_val = q_network(data.observations).gather(1, data.actions).squeeze()
loss = F.mse_loss(td_target, old_val)
