⚠️ [CLAP] Fix dtype of logit scales in init #25682
Conversation
    text_config = config.text_config
    audio_config = config.audio_config

    self.logit_scale_a = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
This behaviour is a result of the np.log operation defaulting to float64.
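A minimal sketch of the issue (the init value 14.29 is illustrative here; the real default lives in `ClapConfig.logit_scale_init_value`): `np.log` returns a numpy float64 scalar, and `torch.tensor` preserves that dtype, so the parameter ends up float64 even though torch's default dtype is float32.

```python
import numpy as np
import torch
import torch.nn as nn

logit_scale_init_value = 14.29  # illustrative value, not necessarily the ClapConfig default

# np.log yields a numpy float64 scalar, and torch.tensor keeps that dtype
param = nn.Parameter(torch.tensor(np.log(logit_scale_init_value)))
print(param.dtype)  # torch.float64
```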
Given the original code, we might need to init in float64 and then cast to float if it makes a difference. No idea if the actual saved value is in float64!
The parameters are initialised in float64 but are stored in float32 in the state dict.
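Note that the float32 value stored in the checkpoint does not repair the dtype at load time, because `load_state_dict` copies values into the existing parameter storage. A self-contained sketch (the toy module is a hypothetical stand-in for `ClapModel`; only the dtype behaviour matters):

```python
import numpy as np
import torch
import torch.nn as nn

class ToyClap(nn.Module):  # hypothetical stand-in for ClapModel
    def __init__(self):
        super().__init__()
        # float64 init, as in the pre-fix code
        self.logit_scale_a = nn.Parameter(torch.tensor(np.log(14.29)))

model = ToyClap()
# float32 value, as stored in the saved state dict
checkpoint = {"logit_scale_a": torch.tensor(2.66, dtype=torch.float32)}
model.load_state_dict(checkpoint)
print(model.logit_scale_a.dtype)  # torch.float64 -- the init dtype is kept
```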
younesbelkada left a comment:
Thanks!
ArthurZucker left a comment:
As mentioned offline, this was never used in the original repo. It is a bit breaking, but it is a bug fix. Let's just add one.
The documentation is not available anymore as the PR was closed or merged.
Note that in the original repo, the model is always cast to float16 for all training / inference. Thus, they likely never used the model in its default dtype, and always relied on explicitly casting to float16.
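A short sketch of why that explicit cast hides the issue: `.half()` converts every floating-point parameter, including a float64 one, to float16 (the init value is again illustrative).

```python
import numpy as np
import torch
import torch.nn as nn

module = nn.Module()
module.logit_scale_a = nn.Parameter(torch.tensor(np.log(14.29)))  # float64 init
print(module.logit_scale_a.dtype)  # torch.float64

module.half()  # the explicit cast the original repo always applies
print(module.logit_scale_a.dtype)  # torch.float16
```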
[CLAP] Fix dtype of logit scales
What does this PR do?
The dtype of the CLAP logit scale parameters was always float64 by default (even if the rest of the model was initialised in float32). This PR fixes the logit scales, such that they respect the default dtype of the model.
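A hedged sketch of the kind of init that respects the model's default dtype (the exact change is in the PR diff; using `math.log` here is one way to read "respect the default dtype"): computing the log as a plain Python float lets `torch.tensor` fall back to torch's default dtype instead of inheriting float64 from numpy.

```python
import math
import torch
import torch.nn as nn

logit_scale_init_value = 14.29  # illustrative; the real default comes from ClapConfig

# math.log returns a Python float, so torch.tensor uses torch's default dtype
logit_scale_a = nn.Parameter(torch.tensor(math.log(logit_scale_init_value)))
logit_scale_t = nn.Parameter(torch.tensor(math.log(logit_scale_init_value)))
print(logit_scale_a.dtype)  # torch.float32 under the default settings
```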