Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add domain modules trained End-to-End #149

Merged
merged 9 commits into from
Oct 4, 2024
Merged

Conversation

bdvllrs
Copy link
Collaborator

@bdvllrs bdvllrs commented Sep 19, 2024

Allow for some domain modules to be trained end-to-end with the global workspace.

This brings some breaking changes:

  1. DomainModule.compute_loss and DomainModule.compute_*_loss now require an 3rd
    parameter raw_target: Any that stores the raw domain input (before being encoded).
    This is usefull for unimodal losses that require the actual inputs to compute the loss.
  2. GWLossesBase.step requires a new first argument raw_data: RawDomainGroupsT to
    pass the raw_targets to the domain modules.
  1. needs to be changed in all projects that implement a DomainModule (every project).
  2. has probably less impact as most project won't redefine a Loss module.

@RolandBERTINJOHANNET
Copy link
Collaborator

Once this is merged all we have to do on our personal projects is add that raw input parameter in the function calls and also the first arg for loss.step() ?

Shouldn't be too hard

@bdvllrs
Copy link
Collaborator Author

bdvllrs commented Sep 19, 2024

I don't think loss.step will be a problem, but yes for the domain modules loss function definitions.
I'm still missing to update the shimmer tutorial to reflect this change.

@bdvllrs
Copy link
Collaborator Author

bdvllrs commented Sep 19, 2024

Should be all good now.

Copy link
Collaborator

@RolandBERTINJOHANNET RolandBERTINJOHANNET left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, yes.

@bdvllrs bdvllrs force-pushed the end_to_end_domain_modules branch from 12c1a58 to 9976918 Compare September 19, 2024 10:13
@bdvllrs bdvllrs force-pushed the end_to_end_domain_modules branch from a411a87 to 3de6cfb Compare October 3, 2024 15:54
@bdvllrs
Copy link
Collaborator Author

bdvllrs commented Oct 3, 2024

After some internal discussion, we removed the End2EndDomainModule in favor of the freeze and unfreeze methods.

DomainModule stay frozen by default.

Copy link
Collaborator

@RolandBERTINJOHANNET RolandBERTINJOHANNET left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

goood

@bdvllrs bdvllrs merged commit 4e011e7 into main Oct 4, 2024
2 checks passed
@bdvllrs bdvllrs deleted the end_to_end_domain_modules branch October 4, 2024 13:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants