Skip to content

Commit

Permalink
[Performance] Accelerate GAE (#1142)
Browse files Browse the repository at this point in the history
Co-authored-by: vmoens <[email protected]>
  • Loading branch information
Blonck and vmoens committed May 10, 2023
1 parent 5776258 commit 9402664
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 180 deletions.
37 changes: 37 additions & 0 deletions benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,43 @@ def test_values(benchmark, val_fn, has_lmbda, has_state_value):
)


@pytest.mark.parametrize(
    "gae_fn,gamma_tensor,batches,timesteps",
    [
        [generalized_advantage_estimate, False, 1, 512],
        [vec_generalized_advantage_estimate, True, 1, 512],
        [vec_generalized_advantage_estimate, False, 1, 512],
        [vec_generalized_advantage_estimate, True, 32, 512],
        [vec_generalized_advantage_estimate, False, 32, 512],
    ],
)
def test_gae_speed(benchmark, gae_fn, gamma_tensor, batches, timesteps):
    """Benchmark a GAE implementation on randomly generated inputs.

    Args:
        benchmark: benchmark fixture; called with ``gae_fn`` and its
            keyword arguments.
        gae_fn: the advantage-estimation function being timed.
        gamma_tensor: if true, ``gamma`` is passed as a tensor of shape
            ``(batches, timesteps, 1)`` filled with 0.99; otherwise it is
            passed as the Python float 0.99.
        batches: size of the leading dimension of every input tensor.
        timesteps: size of the second dimension of every input tensor.
    """
    size = (batches, timesteps, 1)
    print(size)

    # Fixed seed so every parametrization is timed on the same data.
    torch.manual_seed(0)
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    values = torch.randn(*size, device=device)
    next_values = torch.randn(*size, device=device)
    reward = torch.randn(*size, device=device)
    done = torch.zeros(*size, dtype=torch.bool, device=device).bernoulli_(0.1)

    gamma = 0.99
    if gamma_tensor:
        # Allocate gamma on the same device as the other inputs: a CPU
        # tensor mixed with CUDA inputs would either fail or add a
        # host-to-device transfer to the timed call.
        gamma = torch.full(size, gamma, device=device)
    lmbda = 0.95

    benchmark(
        gae_fn,
        gamma=gamma,
        lmbda=lmbda,
        state_value=values,
        next_state_value=next_values,
        reward=reward,
        done=done,
    )


if __name__ == "__main__":
    # Run this file through pytest, forwarding any command-line flags that
    # the (empty) argument parser does not recognize.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst", *unknown]
    pytest.main(pytest_argv)
50 changes: 49 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -4076,7 +4076,7 @@ def test_td1_multi(
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1])
@pytest.mark.parametrize("lmbda", [0.99, 0.5, 0.1])
@pytest.mark.parametrize("N", [(3,), (7, 3)])
@pytest.mark.parametrize("N", [(1,), (3,), (7, 3)])
@pytest.mark.parametrize("T", [200, 5, 3])
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("has_done", [True, False])
Expand All @@ -4098,6 +4098,54 @@ def test_gae(self, device, gamma, lmbda, N, T, dtype, has_done):
)
torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("N", [(1,), (8,), (7, 3)])
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("has_done", [True, False])
@pytest.mark.parametrize(
    "gamma_tensor", ["scalar", "tensor", "tensor_single_element"]
)
@pytest.mark.parametrize(
    "lmbda_tensor", ["scalar", "tensor", "tensor_single_element"]
)
def test_gae_param_as_tensor(
    self, device, N, dtype, has_done, gamma_tensor, lmbda_tensor
):
    """Compare the vectorized GAE with tensor-valued gamma/lmbda to the reference.

    ``vec_generalized_advantage_estimate`` is called with ``gamma`` and
    ``lmbda`` given, independently, as a Python float (``"scalar"``), a
    tensor shaped like ``reward`` (``"tensor"``), or a one-element tensor
    (``"tensor_single_element"``). Its output must match
    ``generalized_advantage_estimate`` called with the plain float values.
    """
    torch.manual_seed(0)

    gamma = 0.95
    lmbda = 0.90
    T = 200

    done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool)
    if has_done:
        done = done.bernoulli_(0.1)
    reward = torch.randn(*N, T, 1, device=device, dtype=dtype)
    state_value = torch.randn(*N, T, 1, device=device, dtype=dtype)
    next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype)

    if gamma_tensor == "tensor":
        gamma_vec = torch.full_like(reward, gamma)
    elif gamma_tensor == "tensor_single_element":
        gamma_vec = torch.as_tensor([gamma], device=device)
    else:
        gamma_vec = gamma

    # The representation of lmbda is selected by lmbda_tensor alone, so that
    # all nine gamma/lmbda combinations are actually exercised.
    if lmbda_tensor == "tensor":
        lmbda_vec = torch.full_like(reward, lmbda)
    elif lmbda_tensor == "tensor_single_element":
        lmbda_vec = torch.as_tensor([lmbda], device=device)
    else:
        lmbda_vec = lmbda

    r1 = vec_generalized_advantage_estimate(
        gamma_vec, lmbda_vec, state_value, next_state_value, reward, done
    )
    r2 = generalized_advantage_estimate(
        gamma, lmbda, state_value, next_state_value, reward, done
    )
    torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1])
@pytest.mark.parametrize("lmbda", [0.99, 0.5, 0.1])
Expand Down
Loading

0 comments on commit 9402664

Please sign in to comment.