Description
🚀 The feature
Some models already accept normalization strategies as callables (mobilenet_backbone accepts a norm_layer argument, for example), but loss functions are hardcoded (F.cross_entropy for fastercnn.roi_heads, for example).
Following what has been done for normalization strategies, loss functions could be passed as callables to the module constructors. With the current losses kept as defaults, this shouldn't break backward compatibility. Reduction strategies would still need to be handled properly.
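A minimal sketch of the idea, kept framework-free so it runs on its own. `ClassificationHead`, `classification_loss`, and `summed_cross_entropy` are hypothetical names for illustration, not torchvision API; in the real models the default would be the existing hardcoded call (F.cross_entropy in the roi_heads case).

```python
import math
from typing import Callable, Optional, Sequence

# Assumed signature for illustration: (per-sample logits, integer labels) -> scalar loss.
LossFn = Callable[[Sequence[Sequence[float]], Sequence[int]], float]


def cross_entropy(logits, labels):
    """Mean-reduced softmax cross-entropy; stands in for the hardcoded default."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


class ClassificationHead:
    """Hypothetical head, not a torchvision class."""

    def __init__(self, classification_loss: Optional[LossFn] = None):
        # None falls back to the previous behaviour, so existing callers are unaffected.
        self.classification_loss = classification_loss or cross_entropy

    def compute_loss(self, logits, labels):
        return self.classification_loss(logits, labels)


def summed_cross_entropy(logits, labels):
    """Example custom loss: same criterion with sum reduction instead of mean."""
    return cross_entropy(logits, labels) * len(labels)


logits = [[2.0, 0.0], [0.0, 2.0]]
labels = [0, 1]
default_loss = ClassificationHead().compute_loss(logits, labels)
custom_loss = ClassificationHead(summed_cross_entropy).compute_loss(logits, labels)
```

The custom callable here changes only the reduction, which is the part that would need care: downstream code that combines several losses may assume a mean-reduced scalar.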
Motivation, pitch
Currently, trying a different loss function requires patching the model in ad hoc ways. Accepting the losses in the model constructors would provide a much cleaner way to experiment with the models.
If there is interest, I can propose a first PR modifying the Faster R-CNN models.
Alternatives
No response
Additional context
No response
cc @datumbox