-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add domain modules trained End-to-End #149
Conversation
Once this is merged all we have to do on our personal projects is add that raw input parameter in the function calls and also the first arg for loss.step() ? Shouldn't be too hard |
I don't think loss.step will be a problem, but yes for the domain modules loss function definitions. |
Should be all good now. |
86549bb
to
7f66426
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, yes.
12c1a58
to
9976918
Compare
… during GW training
a411a87
to
3de6cfb
Compare
After some internal discussion, we removed the End2EndDomainModule in favor of the
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
goood
Allow for some domain modules to be trained end-to-end with the global workspace.
This brings some breaking changes:
DomainModule.compute_loss
andDomainModule.compute_*_loss
now require an 3rdparameter
raw_target: Any
that stores the raw domain input (before being encoded).This is usefull for unimodal losses that require the actual inputs to compute the loss.
GWLossesBase.step
requires a new first argumentraw_data: RawDomainGroupsT
topass the
raw_targets
to the domain modules.DomainModule
(every project).