Is your feature request related to a problem? Please describe.
I have a feeling that the reduction used for computing the TD error over ensembles of Q-functions should be a mean rather than a sum.
(https://github.com/takuseno/d3rlpy/blob/4b54bdde93d19f3915f3236367a5ec253ef99cee/d3rlpy/models/torch/q_functions/ensemble_q_function.py#L105C9-L105C30)
My understanding of the ensemble Q-function is that the critic loss is backpropagated to all the constituent Q-functions equally. As a result, I feel that a mean reduction would help control divergent networks and enable larger learning rates. Experimentally, I have found that I need to use smaller learning rates with n_critics>1; I believe that when the TD loss is dominated by a divergent network, the high-magnitude updates caused by summing the TD losses can cause the other networks to diverge as well.
I'd be more than happy to be pushed back on this. FYI, I couldn't find any literature on the use of ensemble Q-networks in this manner.
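To make the two reductions concrete, here is a minimal PyTorch sketch. This is not d3rlpy's actual implementation; the function name, signature, and shapes are hypothetical, chosen only to show how per-critic TD losses would be aggregated with sum versus mean:

```python
import torch

def ensemble_td_loss(q_values: torch.Tensor,
                     targets: torch.Tensor,
                     reduction: str = "sum") -> torch.Tensor:
    # q_values: (n_critics, batch_size) predictions from the ensemble
    # targets:  (batch_size,) shared TD targets
    # Per-critic MSE TD loss, averaged over the batch.
    per_critic = ((q_values - targets.unsqueeze(0)) ** 2).mean(dim=1)
    if reduction == "sum":
        # Sum over critics: each network's gradient is unaffected by
        # n_critics, but one divergent critic contributes its full
        # loss magnitude to the update.
        return per_critic.sum()
    # Mean over critics: each network's gradient is scaled by 1/n_critics.
    return per_critic.mean()
```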
Describe the solution you'd like
Use a mean reduction instead of a sum when aggregating the TD losses over the ensemble of Q-functions.

Thanks for the issue. First of all, I intentionally use the sum of the TD losses over the ensemble networks. This is because, by definition, addition/subtraction doesn't make the gradients of the individual networks interfere with each other. The advantage of sum over mean is that the gradient scale is not affected by the number of critics: with mean, the gradient propagated to each individual network shrinks in proportion to the number of critics, which would require additional learning-rate tuning.
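The gradient-scale point can be checked directly. A toy sketch (illustrative only, not from d3rlpy; it reuses the hypothetical shapes from the sketch above) showing that the per-critic gradient under a mean reduction is exactly the sum-reduction gradient divided by n_critics:

```python
import torch

n_critics, batch_size = 4, 8
q_values = torch.randn(n_critics, batch_size, requires_grad=True)
targets = torch.randn(batch_size)

# Per-critic MSE TD loss, averaged over the batch.
per_critic = ((q_values - targets) ** 2).mean(dim=1)

grad_sum = torch.autograd.grad(per_critic.sum(), q_values, retain_graph=True)[0]
grad_mean = torch.autograd.grad(per_critic.mean(), q_values)[0]

# Mean reduction scales every critic's gradient by 1/n_critics, so switching
# from sum to mean effectively divides the learning rate by n_critics.
print(torch.allclose(grad_mean, grad_sum / n_critics))  # True
```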