
Alteration to default summation used to define TD loss of ensemble Q functions #421

Open
joshuaspear opened this issue Sep 24, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@joshuaspear
Contributor

Is your feature request related to a problem? Please describe.
I have a feeling that the reduction used for computing the TD error over ensembles of Q functions should be a mean rather than a sum.
(https://github.com/takuseno/d3rlpy/blob/4b54bdde93d19f3915f3236367a5ec253ef99cee/d3rlpy/models/torch/q_functions/ensemble_q_function.py#L105C9-L105C30)

My understanding of the ensemble Q-function is that the critic loss is backpropagated to all the constituent Q-functions equally. As a result, I feel that a mean reduction would help control divergent networks and enable larger learning rates. Experimentally, I have found the need to use smaller learning rates with n_critics>1; I believe that when the TD loss is dominated by a divergent network, the high-magnitude updates caused by the summed TD loss can cause the other networks to also become divergent.

I'd happily be pushed back on this - FYI I couldn't find any literature on the use of ensemble Q-networks in this manner.

Describe the solution you'd like

# Proposed variant of the loop in ensemble_q_function.py:
# accumulate each critic's mean TD loss, then divide by the
# ensemble size so the overall reduction is a mean, not a sum.
assert target.ndim == 2
td_sum = torch.tensor(
    0.0,
    dtype=torch.float32,
    device=get_device(observations),
)
for forwarder in forwarders:
    loss = forwarder.compute_error(
        observations=observations,
        actions=actions,
        rewards=rewards,
        target=target,
        terminals=terminals,
        gamma=gamma,
        reduction="none",
    )
    td_sum += loss.mean()
# Proposed change: average over critics instead of summing.
return td_sum / len(forwarders)
@takuseno
Owner

Thanks for the issue. First of all, I intentionally use the sum of the TD loss over the ensemble networks. This is because addition/subtraction doesn't mix gradients between networks: by definition, each critic only receives the gradient of its own loss term. The advantage of sum over mean is that the gradient scale is not affected by the number of critics. With a mean, the gradient propagated to each individual network shrinks proportionally to the number of critics, which requires additional learning rate tuning.
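The gradient-scaling argument above can be checked with a small standalone PyTorch sketch (a toy setup, not d3rlpy code: the single-weight "critics", the data, and the loss are all made up for illustration). Under a sum reduction, each critic's gradient equals the gradient of its own loss alone; under a mean reduction, each gradient is divided by the number of critics.

```python
import torch

# Toy ensemble: each "critic" is a single scalar weight, all seeing
# the same data. Hypothetical minimal setup, not d3rlpy internals.
n_critics = 4
params = [torch.ones(1, requires_grad=True) for _ in range(n_critics)]
target = torch.tensor(3.0)

def td_losses():
    # Per-critic squared TD error against a shared target.
    return [(p * 2.0 - target) ** 2 for p in params]

# Sum reduction: the gradient on critic i is d(loss_i)/d(p_i),
# unaffected by how many other critics are in the ensemble.
loss_sum = torch.stack(td_losses()).sum()
grads_sum = torch.autograd.grad(loss_sum, params)

# Mean reduction: each per-critic gradient shrinks by 1/n_critics.
loss_mean = torch.stack(td_losses()).mean()
grads_mean = torch.autograd.grad(loss_mean, params)

print(grads_sum[0].item(), grads_mean[0].item())  # -4.0 vs -1.0
```

With the sum, the per-critic gradient stays at -4.0 regardless of ensemble size; with the mean it is -4.0 / n_critics, which is why switching reductions effectively rescales the learning rate.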
