[Feature] Non-functional objectives (PPO, A2C, Reinforce) #1804

Merged (7 commits) on Jan 23, 2024
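At a glance, this PR renames the loss constructors' keyword arguments from actor/critic to actor_network/critic_network and threads a new functional switch through A2CLoss, ClipPPOLoss and ReinforceLoss. The sketch below illustrates the updated call sites; the toy actor and critic (MLP sizes, keys) are illustrative placeholders rather than modules from this repository, and the reading of functional=False (networks kept as plain nn.Modules instead of being converted to functional parameters) is inferred from the PR title and the tests further down.

import torch
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import A2CLoss, ClipPPOLoss
from torchrl.objectives.value import GAE

n_obs, n_act = 4, 2

# Toy Gaussian policy: observation -> (loc, scale) -> TanhNormal.
actor = ProbabilisticActor(
    TensorDictModule(
        torch.nn.Sequential(
            MLP(in_features=n_obs, out_features=2 * n_act, num_cells=[32]),
            NormalParamExtractor(),
        ),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)

# Toy critic writing "state_value".
critic = ValueOperator(
    MLP(in_features=n_obs, out_features=1, num_cells=[32]),
    in_keys=["observation"],
)

# Renamed keywords (previously actor= / critic=).
ppo_loss = ClipPPOLoss(actor_network=actor, critic_network=critic)

# New functional flag: with False the networks are used as regular nn.Modules.
a2c_loss = A2CLoss(actor_network=actor, critic_network=critic, functional=False)

# Advantage estimation is unchanged.
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)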
6 changes: 3 additions & 3 deletions benchmarks/test_objectives_benchmarks.py
@@ -548,7 +548,7 @@ def test_a2c_speed(
     actor(td.clone())
     critic(td.clone())

-    loss = A2CLoss(actor=actor, critic=critic)
+    loss = A2CLoss(actor_network=actor, critic_network=critic)
     advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
     advantage(td)
     loss(td)
@@ -605,7 +605,7 @@ def test_ppo_speed(
     actor(td.clone())
     critic(td.clone())

-    loss = ClipPPOLoss(actor=actor, critic=critic)
+    loss = ClipPPOLoss(actor_network=actor, critic_network=critic)
     advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
     advantage(td)
     loss(td)
@@ -662,7 +662,7 @@ def test_reinforce_speed(
     actor(td.clone())
     critic(td.clone())

-    loss = ReinforceLoss(actor=actor, critic=critic)
+    loss = ReinforceLoss(actor_network=actor, critic_network=critic)
     advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
     advantage(td)
     loss(td)
4 changes: 2 additions & 2 deletions examples/a2c/a2c_atari.py
@@ -69,8 +69,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_gae=True,
     )
     loss_module = A2CLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
         critic_coef=cfg.loss.critic_coef,
4 changes: 2 additions & 2 deletions examples/a2c/a2c_mujoco.py
@@ -63,8 +63,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_gae=False,
     )
     loss_module = A2CLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
         critic_coef=cfg.loss.critic_coef,
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/ray_train.py
@@ -145,7 +145,7 @@
 )
 loss_module = ClipPPOLoss(
     actor=policy_module,
-    critic=value_module,
+    critic_network=value_module,
     advantage_key="advantage",
     clip_epsilon=clip_epsilon,
     entropy_bonus=bool(entropy_eps),
4 changes: 2 additions & 2 deletions examples/impala/impala_multi_node_ray.py
@@ -114,8 +114,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_adv=False,
     )
     loss_module = A2CLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
         critic_coef=cfg.loss.critic_coef,
4 changes: 2 additions & 2 deletions examples/impala/impala_multi_node_submitit.py
@@ -106,8 +106,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_adv=False,
     )
     loss_module = A2CLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
         critic_coef=cfg.loss.critic_coef,
4 changes: 2 additions & 2 deletions examples/impala/impala_single_node.py
@@ -84,8 +84,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_adv=False,
     )
     loss_module = A2CLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
         critic_coef=cfg.loss.critic_coef,
6 changes: 3 additions & 3 deletions examples/multiagent/mappo_ippo.py
@@ -137,8 +137,8 @@ def train(cfg: "DictConfig"): # noqa: F821

     # Loss
     loss_module = ClipPPOLoss(
-        actor=policy,
-        critic=value_module,
+        actor_network=policy,
+        critic_network=value_module,
         clip_epsilon=cfg.loss.clip_epsilon,
         entropy_coef=cfg.loss.entropy_eps,
         normalize_advantage=False,
@@ -174,7 +174,7 @@ def train(cfg: "DictConfig"): # noqa: F821
         with torch.no_grad():
             loss_module.value_estimator(
                 tensordict_data,
-                params=loss_module.critic_params,
+                params=loss_module.critic_network_params,
                 target_params=loss_module.target_critic_params,
             )
         current_frames = tensordict_data.numel()
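The rename also reaches the loss module's exposed parameters, as the second hunk above shows: the functional critic parameters are now read from critic_network_params instead of critic_params (the test changes further down rename the target-side attribute to target_critic_network_params in the same way). A minimal follow-on to the sketch near the top of the page, reusing its ppo_loss, just to show the attribute access:

# Functional parameter TensorDicts under the new names (assumes ppo_loss from
# the earlier sketch, built in the default functional mode).
critic_params = ppo_loss.critic_network_params   # previously: critic_params
actor_params = ppo_loss.actor_network_params     # name unchanged by this PR
# These are the objects passed as params=/target_params= to a value estimator,
# as in the mappo_ippo.py hunk above.
print(critic_params)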
4 changes: 2 additions & 2 deletions examples/ppo/ppo_atari.py
@@ -70,8 +70,8 @@ def main(cfg: "DictConfig"): # noqa: F821
         average_gae=False,
     )
     loss_module = ClipPPOLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         clip_epsilon=cfg.loss.clip_epsilon,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
4 changes: 2 additions & 2 deletions examples/ppo/ppo_mujoco.py
@@ -70,8 +70,8 @@ def main(cfg: "DictConfig"): # noqa: F821
     )

     loss_module = ClipPPOLoss(
-        actor=actor,
-        critic=critic,
+        actor_network=actor,
+        critic_network=critic,
         clip_epsilon=cfg.loss.clip_epsilon,
         loss_critic_type=cfg.loss.loss_critic_type,
         entropy_coef=cfg.loss.entropy_coef,
49 changes: 30 additions & 19 deletions test/test_cost.py
@@ -5820,7 +5820,10 @@ def _create_seq_mock_data_ppo(
     @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
-    def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est):
+    @pytest.mark.parametrize("functional", [True, False])
+    def test_ppo(
+        self, loss_class, device, gradient_mode, advantage, td_est, functional
+    ):
         torch.manual_seed(self.seed)
         td = self._create_seq_mock_data_ppo(device=device)

@@ -5850,7 +5853,7 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est):
         else:
             raise NotImplementedError

-        loss_fn = loss_class(actor, value, loss_critic_type="l2")
+        loss_fn = loss_class(actor, value, loss_critic_type="l2", functional=functional)
         if advantage is not None:
             advantage(td)
         else:
@@ -6328,7 +6331,7 @@ def test_ppo_notensordict(
         )
         value = self._create_mock_value(observation_key=observation_key)

-        loss = loss_class(actor=actor, critic=value)
+        loss = loss_class(actor_network=actor, critic_network=value)
         loss.set_keys(
             action=action_key,
             reward=reward_key,
@@ -6537,7 +6540,8 @@ def _create_seq_mock_data_a2c(
     @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
-    def test_a2c(self, device, gradient_mode, advantage, td_est):
+    @pytest.mark.parametrize("functional", (True, False))
+    def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
         torch.manual_seed(self.seed)
         td = self._create_seq_mock_data_a2c(device=device)

@@ -6567,7 +6571,7 @@ def test_a2c(self, device, gradient_mode, advantage, td_est):
         else:
             raise NotImplementedError

-        loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
+        loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional)

         # Check error is raised when actions require grads
         td["action"].requires_grad = True
@@ -6629,7 +6633,9 @@ def test_a2c_state_dict(self, device, gradient_mode):
     def test_a2c_separate_losses(self, separate_losses):
         torch.manual_seed(self.seed)
         actor, critic, common, td = self._create_mock_common_layer_setup()
-        loss_fn = A2CLoss(actor=actor, critic=critic, separate_losses=separate_losses)
+        loss_fn = A2CLoss(
+            actor_network=actor, critic_network=critic, separate_losses=separate_losses
+        )

         # Check error is raised when actions require grads
         td["action"].requires_grad = True
@@ -6966,7 +6972,6 @@ def test_a2c_notensordict(
 class TestReinforce(LossModuleTestBase):
     seed = 0

-    @pytest.mark.parametrize("delay_value", [True, False])
     @pytest.mark.parametrize("gradient_mode", [True, False])
     @pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None])
     @pytest.mark.parametrize(
@@ -6979,7 +6984,12 @@ class TestReinforce(LossModuleTestBase):
             None,
         ],
     )
-    def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est):
+    @pytest.mark.parametrize(
+        "delay_value,functional", [[False, True], [False, False], [True, True]]
+    )
+    def test_reinforce_value_net(
+        self, advantage, gradient_mode, delay_value, td_est, functional
+    ):
         n_obs = 3
         n_act = 5
         batch = 4
@@ -7023,8 +7033,9 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est

         loss_fn = ReinforceLoss(
             actor_net,
-            critic=value_net,
+            critic_network=value_net,
             delay_value=delay_value,
+            functional=functional,
         )

         td = TensorDict(
@@ -7049,7 +7060,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est
         if advantage is not None:
             params = TensorDict.from_module(value_net)
             if delay_value:
-                target_params = loss_fn.target_critic_params
+                target_params = loss_fn.target_critic_network_params
             else:
                 target_params = None
             advantage(td, params=params, target_params=target_params)
@@ -7108,7 +7119,7 @@ def test_reinforce_tensordict_keys(self, td_est):

         loss_fn = ReinforceLoss(
             actor_net,
-            critic=value_net,
+            critic_network=value_net,
         )

         default_keys = {
@@ -7133,7 +7144,7 @@ def test_reinforce_tensordict_keys(self, td_est):

         loss_fn = ReinforceLoss(
             actor_net,
-            critic=value_net,
+            critic_network=value_net,
         )

         key_mapping = {
@@ -7207,14 +7218,14 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses):
         torch.manual_seed(self.seed)
         actor, critic, common, td = self._create_mock_common_layer_setup()
         loss_fn = ReinforceLoss(
-            actor=actor, critic=critic, separate_losses=separate_losses
+            actor_network=actor, critic_network=critic, separate_losses=separate_losses
         )

         loss = loss_fn(td)

         assert all(
             (p.grad is None) or (p.grad == 0).all()
-            for p in loss_fn.critic_params.values(True, True)
+            for p in loss_fn.critic_network_params.values(True, True)
         )
         assert all(
             (p.grad is None) or (p.grad == 0).all()
@@ -7234,14 +7245,14 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses):
                 for p in loss_fn.actor_network_params.values(True, True)
             )
             common_layers = itertools.islice(
-                loss_fn.critic_params.values(True, True),
+                loss_fn.critic_network_params.values(True, True),
                 common_layers_no,
             )
             assert all(
                 (p.grad is None) or (p.grad == 0).all() for p in common_layers
             )
             critic_layers = itertools.islice(
-                loss_fn.critic_params.values(True, True),
+                loss_fn.critic_network_params.values(True, True),
                 common_layers_no,
                 None,
             )
@@ -7250,7 +7261,7 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses):
             )
         else:
             common_layers = itertools.islice(
-                loss_fn.critic_params.values(True, True),
+                loss_fn.critic_network_params.values(True, True),
                 common_layers_no,
             )
             assert not any(
@@ -7266,7 +7277,7 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses):
             )
             assert not any(
                 (p.grad is None) or (p.grad == 0).all()
-                for p in loss_fn.critic_params.values(True, True)
+                for p in loss_fn.critic_network_params.values(True, True)
             )

         else:
@@ -7297,7 +7308,7 @@ def test_reinforce_notensordict(
             in_keys=["loc", "scale"],
             spec=UnboundedContinuousTensorSpec(n_act),
         )
-        loss = ReinforceLoss(actor=actor_net, critic=value_net)
+        loss = ReinforceLoss(actor_network=actor_net, critic_network=value_net)
         loss.set_keys(
             reward=reward_key,
             done=done_key,
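On the Reinforce side, the tests above exercise the same renames plus the delay_value/functional combinations; note that delay_value=True is only parametrised together with functional=True. A short follow-on sketch, again reusing the toy actor and critic from the first snippet; the attribute names come from the test changes, everything else is illustrative:

from torchrl.objectives import ReinforceLoss

# Renamed keyword critic_network (previously critic), with a delayed target
# critic in the (assumed default) functional mode.
reinforce_loss = ReinforceLoss(
    actor_network=actor,
    critic_network=critic,
    delay_value=True,
)

# Parameter attributes follow the naming used in test_cost.py:
#   critic_params        -> critic_network_params
#   target_critic_params -> target_critic_network_params
print(reinforce_loss.critic_network_params)
print(reinforce_loss.target_critic_network_params)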