Expose weight initialization bounds for LayerNorm, Projection, Conv2D, Conv3D #926

Open
wants to merge 10 commits into base: main

Conversation

antoine-tran (Contributor) commented Dec 19, 2024

What does this PR do? Please describe:

Most fairseq2 modules run the standard (Xavier) initialization function, or uniform / constant weight initialization.

Sometimes users want to experiment with different algorithms as well, for example when they want to set the weight bounds manually. This was the case for the JEPA model.

This PR adds an init_fn parameter to the common modules (Projection, TransformerEncoderLayer, LayerNorm).
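
For illustration only, a minimal sketch of the kind of custom initializer this change is meant to enable; the function, the bounds, and the use of a plain torch.nn.Linear in place of a fairseq2 Projection are assumptions for this example, not the final fairseq2 API:

import torch
from torch.nn import Linear

def bounded_init(proj: Linear) -> None:
    # Truncated-normal weights restricted to explicit bounds; zero bias.
    torch.nn.init.trunc_normal_(proj.weight, std=0.02, a=-0.04, b=0.04)
    if proj.bias is not None:
        torch.nn.init.zeros_(proj.bias)

proj = Linear(768, 768)
bounded_init(proj)  # such a callable would be passed as the module's init_fn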

Fixes #{issue number}

Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.

Check list:

  • Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements)
  • Did you read the contributor guideline?
  • Did you make sure that your PR does only one thing instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes)

facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Dec 19, 2024
antoine-tran changed the title from "add init function to the builders" to "Add init function to the builders" Dec 19, 2024
init_module(proj, std=init_std)

with torch.no_grad():
    proj.weight.div_(math.sqrt(2.0 * layer_idx))

Contributor
Is this accurate? In the reference implementation, I see that the scaling is done with layer_idx + 1 instead of layer_idx. https://github.com/facebookresearch/jepa/blob/main/src/models/vision_transformer.py#L150


Contributor Author

Good catch, yes this was the mistake, thanks @cbalioglu.
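
For context, a minimal sketch of the corrected depth-dependent rescaling with a zero-based layer_idx, matching the linked JEPA reference (the helper name here is illustrative, not part of this PR):

import math
import torch
from torch.nn import Linear

def rescale_projection(proj: Linear, layer_idx: int) -> None:
    # layer_idx is zero-based, hence the +1 as in the JEPA reference.
    with torch.no_grad():
        proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1)))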

@@ -570,3 +571,44 @@ def get_module_size(module: Module) -> ModuleSizeInfo:
        info.total_size_bytes += size_bytes

    return info


def normalize_truncate(

Contributor
I think we can use PyTorch's trunc_normal_ instead of this function. This is also noted in the reference implementation here: https://github.com/facebookresearch/jepa/blob/main/src/utils/tensors.py#L18-L19
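
For reference, a short sketch of the suggested replacement using PyTorch's built-in initializer (the std and bounds below are placeholders, not values from this PR):

import torch
from torch.nn.init import trunc_normal_

weight = torch.empty(512, 512)
# Fills the tensor with values from N(mean, std^2) truncated to [a, b],
# so no separate clamp_ step is needed afterwards.
trunc_normal_(weight, mean=0.0, std=0.02, a=-0.04, b=0.04)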

    tensor.clamp_(min=a, max=b)


def init_truncated_uniforma_weights_and_bias(

Contributor
I would prefer this function to be in JEPA's factory.py instead of this file, which is meant for much more generic module helper functions.

init_module(proj, std=init_std)

with torch.no_grad():
    proj.weight.div_(math.sqrt(2.0 * layer_idx))

Contributor
Note that I performed this change to make sure that secondary reset_parameters() calls result in identical weight initialization.
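
As a self-contained sketch of that idea (the class and names here are hypothetical, not fairseq2 API): re-drawing the weights inside reset_parameters() before applying the depth scaling means a second call re-initializes from the same distribution instead of dividing already-scaled weights again:

import math
import torch
from torch import nn

class ScaledProjection(nn.Linear):
    def __init__(self, dim: int, layer_idx: int) -> None:
        # Store the index first; nn.Linear.__init__ calls reset_parameters().
        self.layer_idx = layer_idx
        super().__init__(dim, dim)

    def reset_parameters(self) -> None:
        # Re-initialize, then scale, so the division is never compounded
        # across repeated reset_parameters() calls.
        nn.init.trunc_normal_(self.weight, std=0.02)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        with torch.no_grad():
            self.weight.div_(math.sqrt(2.0 * (self.layer_idx + 1)))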

@@ -373,6 +370,13 @@ def init_projection(proj: Linear) -> None:
dtype=self._dtype,
)

# rescale the last layer

Contributor
Remnant from old commits?

antoine-tran changed the title from "Add init function to the builders" to "Expose weight initialization bounds for LayerNorm, Projection, Conv2D, Conv3D" Dec 19, 2024