[BugFix] Dedicated tests for on policy losses reduction parameter #1974

Merged · 21 commits · Feb 27, 2024
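This PR replaces the `reduction` parametrization that was bolted onto the generic PPO/A2C/Reinforce tests with dedicated `test_*_reduction` tests, and fixes the reduced outputs of the on-policy losses with a trailing `.squeeze(-1)`. The contract the new tests pin down: with `reduction="none"` every `loss_*` entry of the output tensordict keeps the input batch shape, while `"mean"`, `"sum"`, and the default `None` reduce each entry to a scalar. As a rough sketch of those semantics (this mirrors, but is not, the actual `_reduce` helper in `torchrl/objectives/utils.py`; the None-behaves-like-"mean" default is an assumption):

from typing import Optional

import torch

def reduce_sketch(tensor: torch.Tensor, reduction: Optional[str]) -> torch.Tensor:
    # Assumed default: None behaves like "mean".
    if reduction in (None, "mean"):
        return tensor.mean()
    if reduction == "sum":
        return tensor.sum()
    if reduction == "none":
        return tensor  # keep per-sample values
    raise NotImplementedError(f"Unknown reduction: {reduction}")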
124 changes: 91 additions & 33 deletions test/test_cost.py
@@ -6201,7 +6201,6 @@ def _create_seq_mock_data_ppo(
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
     @pytest.mark.parametrize("functional", [True, False])
-    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_ppo(
         self,
         loss_class,
@@ -6210,7 +6209,6 @@ def test_ppo(
         advantage,
         td_est,
         functional,
-        reduction,
     ):
         torch.manual_seed(self.seed)
         td = self._create_seq_mock_data_ppo(device=device)
@@ -6246,7 +6244,6 @@ def test_ppo(
             value,
             loss_critic_type="l2",
             functional=functional,
-            reduction=reduction,
         )
         if advantage is not None:
             advantage(td)
@@ -6259,15 +6256,6 @@
             kl = loss.pop("kl")
             assert (kl != 0).any()

-        if reduction == "none":
-
-            def func(x):
-                if x.dtype != torch.float:
-                    return
-                return x.mean()
-
-            loss = loss.apply(func, batch_size=[])
-
         loss_critic = loss["loss_critic"]
         loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
         loss_critic.backward(retain_graph=True)
@@ -6804,6 +6792,41 @@ def test_ppo_notensordict(
         assert loss_obj == loss_val_td.get("loss_objective")
         assert loss_crit == loss_val_td.get("loss_critic")

+    @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
+    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
+    def test_ppo_reduction(self, reduction, loss_class):
+        torch.manual_seed(self.seed)
+        device = (
+            torch.device("cpu")
+            if torch.cuda.device_count() == 0
+            else torch.device("cuda")
+        )
+        td = self._create_seq_mock_data_ppo(device=device)
+        actor = self._create_mock_actor(device=device)
+        value = self._create_mock_value(device=device)
+        advantage = GAE(
+            gamma=0.9,
+            lmbda=0.9,
+            value_network=value,
+        )
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            reduction=reduction,
+        )
+        advantage(td)
+        loss = loss_fn(td)
+        if reduction == "none":
+            for key in loss.keys():
+                if key.startswith("loss_"):
+                    assert loss[key].shape == td.shape
+        else:
+            for key in loss.keys():
+                if not key.startswith("loss_"):
+                    continue
+                assert loss[key].shape == torch.Size([])
+

 class TestA2C(LossModuleTestBase):
     seed = 0
@@ -6969,8 +6992,7 @@ def _create_seq_mock_data_a2c(
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
     @pytest.mark.parametrize("functional", (True, False))
-    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
-    def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reduction):
+    def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
         torch.manual_seed(self.seed)
         td = self._create_seq_mock_data_a2c(device=device)

@@ -7005,7 +7027,6 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reduction):
             value,
             loss_critic_type="l2",
             functional=functional,
-            reduction=reduction,
         )

         # Check error is raised when actions require grads
@@ -7023,14 +7044,7 @@
         elif td_est is not None:
             loss_fn.make_value_estimator(td_est)
         loss = loss_fn(td)
-        if reduction == "none":
-
-            def func(x):
-                if x.dtype != torch.float:
-                    return
-                return x.mean()
-
-            loss = loss.apply(func, batch_size=[])
         loss_critic = loss["loss_critic"]
         loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
         loss_critic.backward(retain_graph=True)
@@ -7413,6 +7427,40 @@ def test_a2c_notensordict(
         assert loss_objective == loss_val_td["loss_objective"]
         assert loss_critic == loss_val_td["loss_critic"]

+    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
+    def test_a2c_reduction(self, reduction):
+        torch.manual_seed(self.seed)
+        device = (
+            torch.device("cpu")
+            if torch.cuda.device_count() == 0
+            else torch.device("cuda")
+        )
+        td = self._create_seq_mock_data_a2c(device=device)
+        actor = self._create_mock_actor(device=device)
+        value = self._create_mock_value(device=device)
+        advantage = GAE(
+            gamma=0.9,
+            lmbda=0.9,
+            value_network=value,
+        )
+        loss_fn = A2CLoss(
+            actor,
+            value,
+            loss_critic_type="l2",
+            reduction=reduction,
+        )
+        advantage(td)
+        loss = loss_fn(td)
+        if reduction == "none":
+            for key in loss.keys():
+                if key.startswith("loss_"):
+                    assert loss[key].shape == td.shape
+        else:
+            for key in loss.keys():
+                if not key.startswith("loss_"):
+                    continue
+                assert loss[key].shape == torch.Size([])
+

 class TestReinforce(LossModuleTestBase):
     seed = 0
@@ -7659,26 +7707,16 @@ def _create_mock_common_layer_setup(
         return actor, critic, common, td

     @pytest.mark.parametrize("separate_losses", [False, True])
-    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
-    def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction):
+    def test_reinforce_tensordict_separate_losses(self, separate_losses):
         torch.manual_seed(self.seed)
         actor, critic, common, td = self._create_mock_common_layer_setup()
         loss_fn = ReinforceLoss(
             actor_network=actor,
             critic_network=critic,
             separate_losses=separate_losses,
-            reduction=reduction,
         )

         loss = loss_fn(td)
-        if reduction == "none":
-
-            def func(x):
-                if x.dtype != torch.float:
-                    return
-                return x.mean()
-
-            loss = loss.apply(func, batch_size=[])
-
         assert all(
             (p.grad is None) or (p.grad == 0).all()
@@ -7807,6 +7845,26 @@ def test_reinforce_notensordict(
             return
         assert loss_actor == loss_val_td["loss_actor"]

+    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
+    def test_reinforce_reduction(self, reduction):
+        torch.manual_seed(self.seed)
+        actor, critic, common, td = self._create_mock_common_layer_setup()
+        loss_fn = ReinforceLoss(
+            actor_network=actor,
+            critic_network=critic,
+            reduction=reduction,
+        )
+        loss = loss_fn(td)
+        if reduction == "none":
+            for key in loss.keys():
+                if key.startswith("loss_"):
+                    assert loss[key].shape == td.shape
+        else:
+            for key in loss.keys():
+                if not key.startswith("loss_"):
+                    continue
+                assert loss[key].shape == torch.Size([])
+

 @pytest.mark.parametrize("device", get_default_devices())
 class TestDreamer(LossModuleTestBase):
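All three new tests assert the same shape contract, which is what makes `reduction="none"` useful downstream: callers get per-sample losses they can re-weight before reducing themselves. A hedged usage sketch (the `"weights"` key and the surrounding setup are hypothetical, not part of this PR):

# loss_fn built with reduction="none"; each "loss_*" entry is shaped like td.batch_size
loss_td = loss_fn(td)
weights = td.get("weights")  # hypothetical per-sample weights, e.g. replay priorities
loss = (loss_td["loss_objective"] * weights).mean() + loss_td["loss_critic"].mean()
loss.backward()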
2 changes: 1 addition & 1 deletion torchrl/objectives/a2c.py
@@ -477,7 +477,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             loss_critic = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],
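This one-line change (mirrored in `ppo.py` and `reinforce.py` below) is the actual bug fix: with `reduction="none"` the unreduced `loss_*` entries evidently carry a trailing singleton dimension, so the `loss[key].shape == td.shape` assertions above only pass once it is squeezed away; for `"mean"`/`"sum"` the reduced value is 0-d and `squeeze(-1)` is a no-op. A small shape demo (the `[4, 5, 1]` shape is illustrative, not taken from the tests):

import torch

per_sample = torch.randn(4, 5, 1)    # unreduced loss with a trailing singleton dim
print(per_sample.squeeze(-1).shape)  # torch.Size([4, 5]) -- now matches the batch shape
scalar = per_sample.mean()
print(scalar.squeeze(-1).shape)      # torch.Size([]) -- squeeze(-1) is a no-op on 0-d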
13 changes: 7 additions & 6 deletions torchrl/objectives/ppo.py
@@ -5,7 +5,6 @@
 from __future__ import annotations

 import contextlib
-import functools

 import math
 import warnings
@@ -560,10 +559,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
         )
-
         return td_out

     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
@@ -807,7 +808,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

         td_out.set("ESS", _reduce(ess, self.reduction) / batch)
         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],
@@ -1070,7 +1071,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             loss_critic = self.loss_critic(tensordict_copy)
             td_out.set("loss_critic", loss_critic)
         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],
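Note the shared pattern in these `forward` changes: `named_apply` (which in `PPOLoss` replaces the old `apply` + `functools.partial`, hence the dropped import) passes the key alongside the value, so the reduction and squeeze touch only `loss_*` entries while diagnostics such as `entropy`, `kl`, or `ESS` pass through untouched. A self-contained sketch of that selective reduction, with toy keys and shapes:

import torch
from tensordict import TensorDict

td_out = TensorDict(
    {"loss_objective": torch.randn(4, 1), "entropy": torch.randn(4, 1)},
    batch_size=[4],
)
td_out = td_out.named_apply(
    # mean-reduce (and squeeze) the losses only; leave diagnostics as-is
    lambda name, value: value.mean().squeeze(-1)
    if name.startswith("loss_")
    else value,
    batch_size=[],
)
print(td_out["loss_objective"].shape)  # torch.Size([])
print(td_out["entropy"].shape)         # torch.Size([4, 1])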
2 changes: 1 addition & 1 deletion torchrl/objectives/reinforce.py
@@ -402,7 +402,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

         td_out.set("loss_value", self.loss_critic(tensordict))
         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],