Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quality of life and helper callback functions #237

Merged
merged 32 commits into from
Jul 1, 2024

Conversation

laserkelvin
Copy link
Collaborator

@laserkelvin laserkelvin commented Jun 7, 2024

This PR introduces and adds a bunch of changes pertaining to informing the user of things happening under the hood, particularly during training.

One of the big philosophical changes is also focusing more on enabling logging to be done with TensorBoardLogger and WandbLoggers by writing functions more tailored to them, rather than before where loggers were treated in the abstract entirely.

Summary

  • Changed the use of coordinates in periodic boundary utilities to use cartesian coordinates, not fractional coordinates. Also included a warning message that looks at the coordinates as part of diagnostics.
  • For model training, task modules now include log_embeddings and log_embeddings_every_n_steps arguments that are saved to hparams, which as the pair suggests, allow you to regularly log embedding vectors for analysis. This will let you ensure oversmoothing doesn't occur, where all of the embedding features become identical.
  • Introduced a TrainingHelperCallback, which is intended to help diagnose some common issues with training, such as unused parameters, missing gradients, tiny gradients, etc. Complimentary to the change above, there is an option to inject a forward hook to any encoder (assuming it produces an Embeddings structure), and uses it to calculate the variance in embeddings.
  • Introduced a ModelAutocorrelation callback, which will perform an autocorrelation analysis on model parameters and gradients over the course of training. Basically this gives you some insight into how the training dynamics appear, i.e. too much correlation = probably not good.

My intention for the TrainingHelperCallback is to be like a guide for best practices: we can refine this as we go and discover new things, and hopefully will be useful for everyone including new users.

@laserkelvin laserkelvin added ux User experience, quality of life changes training Issues related to model training labels Jun 7, 2024
@laserkelvin
Copy link
Collaborator Author

I have somehow broken SAM and need to fix it first before review

@laserkelvin
Copy link
Collaborator Author

I think I have a lead on what the issue is: because of how SAM works, and because of the modifications to "stashing" embeddings in the batch structure, we now end up with two disjoint computational graphs that causes backward to break.

This needs a bit of thought to fix...

@laserkelvin
Copy link
Collaborator Author

laserkelvin commented Jun 10, 2024

Confirming this by changing out the BaseTaskModule.forward:

        if "embeddings" in batch:
            embeddings = batch.get("embeddings")
        else:
            embeddings = self.encoder(batch)
            batch["embeddings"] = embeddings
        outputs = self.process_embedding(embeddings)
        return outputs

Removing the branch, and just running the encoder + processing embeddings works (i.e. don't try and grab cached embeddings).

Ideally there would be a way to check if embeddings originated from the same computational graph, but that take a lot more surgery than this PR warrants. I'll think of an alternative to this.

The reason we are stashing the embeddings is to benefit the multitask case, where we would want to not have to run the encoder X times for X tasks and datasets.

Copy link
Collaborator

@melo-gonzo melo-gonzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some great features here, thanks for doing all of this! I threw in a few comments, I know you're still working on things.

matsciml/lightning/callbacks.py Show resolved Hide resolved
matsciml/lightning/callbacks.py Show resolved Hide resolved
examples/callbacks/autocorrelation.py Show resolved Hide resolved
matsciml/lightning/callbacks.py Show resolved Hide resolved
matsciml/lightning/callbacks.py Outdated Show resolved Hide resolved
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
That way we don't do a double log as forward might be called multiple times
Copy link
Collaborator

@melo-gonzo melo-gonzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will bring some great utilities and helpful debugging tools! Looks good to merge.

@laserkelvin laserkelvin merged commit 0e3a640 into IntelLabs:main Jul 1, 2024
4 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
training Issues related to model training ux User experience, quality of life changes
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants