-
Notifications
You must be signed in to change notification settings - Fork 86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Expose weight initialization bounds for LayerNorm, Projection, Conv2D, Conv3D #926
base: main
Are you sure you want to change the base?
Conversation
src/fairseq2/models/jepa/factory.py
Outdated
init_module(proj, std=init_std) | ||
|
||
with torch.no_grad(): | ||
proj.weight.div_(math.sqrt(2.0 * layer_idx)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this accurate? In the reference implementation, I see that the scaling is done with layer_idx + 1
instead of layer_idx
. https://github.com/facebookresearch/jepa/blob/main/src/models/vision_transformer.py#L150
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch , yes this was the mistake, thanks @cbalioglu
src/fairseq2/nn/utils/module.py
Outdated
@@ -570,3 +571,44 @@ def get_module_size(module: Module) -> ModuleSizeInfo: | |||
info.total_size_bytes += size_bytes | |||
|
|||
return info | |||
|
|||
|
|||
def normalize_truncate( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can use PyTorch's trunc_normal_ instead of this function. This is also noted in the reference implementation here: https://github.com/facebookresearch/jepa/blob/main/src/utils/tensors.py#L18-L19
src/fairseq2/nn/utils/module.py
Outdated
tensor.clamp_(min=a, max=b) | ||
|
||
|
||
def init_truncated_uniforma_weights_and_bias( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer this function to be within JEPA's factory.py instead of this file which is meant typically for much more generic module helper functions.
src/fairseq2/models/jepa/factory.py
Outdated
init_module(proj, std=init_std) | ||
|
||
with torch.no_grad(): | ||
proj.weight.div_(math.sqrt(2.0 * layer_idx)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that I performed this change to make sure that secondary reset_parameters()
calls result in identical weight initialization.
@@ -373,6 +370,13 @@ def init_projection(proj: Linear) -> None: | |||
dtype=self._dtype, | |||
) | |||
|
|||
# rescale the last layer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remnant from old commits?
What does this PR do? Please describe:
Most of fairseq2 Modules runs the standard (Xavier) initialization function, or uniform / constant weights initialization.
Sometimes users want to experiment with different algorithms too, for example when they want to manually set the boundaries of the weights. This was the case for the JEPA model.
This PR add parameters
init_fn
to the common Module (Projection, TransfomerEncoderLayer, LayerNorm)Fixes #{issue number}
Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.
Check list: