Skip to content

Commit

Permalink
Add hard check that _fast_vec_gae is faster than original implementation
Browse files Browse the repository at this point in the history
In case gamma and lmbda are scalars, `fast_vec_gae` should be always faster than
`vec_generalized_advantage_estimate` if len(T) is large enough.
  • Loading branch information
Blonck committed May 9, 2023
1 parent 0c6dc55 commit 245e68f
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import timeit
import warnings
from copy import deepcopy

Expand Down Expand Up @@ -97,6 +98,7 @@
)
from torchrl.objectives.value.advantages import GAE, TD1Estimator, TDLambdaEstimator
from torchrl.objectives.value.functional import (
_fast_vec_gae,
generalized_advantage_estimate,
td0_advantage_estimate,
td1_advantage_estimate,
Expand Down Expand Up @@ -4266,6 +4268,66 @@ def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done):

torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4)

def test_runtime_fast_gae():
"""Tests that the runtime performance of different GAE calculations.
If gamma and lmbda are scalars the _fast_vec_gae version should be faster than `vec_generalized_advantage_estimate."""
torch.manual_seed(0)

B = 32
T = 512
D = 1
gamma = 0.95
lmbda = 0.95

device = "cuda:0" if torch.cuda.device_count() else "cpu"
state_value = torch.randn(B, T, D, device=device)
next_state_value = torch.randn(B, T, D, device=device)
reward = torch.randn(B, T, D, device=device)
done = torch.zeros(B, T, D, dtype=torch.bool, device=device).bernoulli_(0.1)
time_dim = -2
number = 500

globals = {
"vec_generalized_advantage_estimate": vec_generalized_advantage_estimate,
"_fast_vec_gae": _fast_vec_gae,
"reward": reward,
"state_value": state_value,
"next_state_value": next_state_value,
"done": done,
"gamma": gamma,
"lmbda": lmbda,
"time_dim": time_dim,
}

time_fast_gae = timeit.timeit(
"""generalized_advantage_estimate(
gamma=gamma,
lmbda=lmbda,
state_value=state_value,
next_state_value=next_state_value,
reward=reward,
done=done,
time_dim=time_dim)""",
globals=globals,
number=number,
)

time_vec_gae = timeit.timeit(
"""vec_generalized_advantage_estimate(
gamma=gamma,
lmbda=lmbda,
state_value=state_value,
next_state_value=next_state_value,
reward=reward,
done=done,
time_dim=time_dim)""",
globals=globals,
number=number,
)

assert time_fast_gae < time_vec_gae

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("gamma", [0.5, 0.99, 0.1])
@pytest.mark.parametrize("N", [(3,), (7, 3)])
Expand Down

0 comments on commit 245e68f

Please sign in to comment.