-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
conditional gfn #188
base: master
Are you sure you want to change the base?
conditional gfn #188
Conversation
…ionally contains a tensor of conditioning vectors (one per trajectory)
…itioning into PB and PF computation
…ule can now accept raw tensors
…bute of the trajectory
Don't worry about the tests - they should be easy to fix. I can make the chances for DB, Sub-TB, and FM pretty easily if we agree this is a good approach, before a proper review. |
|
||
or | ||
|
||
$s \mapsto (P_B(s' \mid s, c))_{s' \in Parents(s)}$. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might be worth mentioning that this is a s very specific conditioning use-case, where the condition is encoded separately, and embeddings are concatenated.
I don't think we can do a generic one, but this should be enough as an example !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What other conditioning approaches would be worth including? Cross attention?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general I would think the conditioning should be embedded / encoded separately --- or would the conditioning just need to be concatenated to the state before input? I could add support for that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think there is an exhaustive list of ways we can process the condition. What you have is great as an example. I suggest you just add a comment or doc that the user might want to write their own module
@@ -68,7 +67,28 @@ def sample_actions( | |||
the sampled actions under the probability distribution of the given | |||
states. | |||
""" | |||
estimator_output = self.estimator(states) | |||
# TODO: Should estimators instead ignore None for the conditioning vector? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wouldn't it be cleaner with fewer if else blocks ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes there's a bit of cruft with all the if-else blocks, but as it stands an estimator can either accept one or two arguments and I think it's good if it fails noisily... what do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok ! makes sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added these exception_handlers
to reduce the cruft.
LGTM! Looking forward to test this feature |
Supports conditioning on a tensor of
shape=[n_trajectories, n_cond_dims]
. This is passed by the user during a call to the sampler.Implemented for all GFlowNets. Note that the current version expects a particular kind of estimator. I can imagine this will lead to future changes - e.g., we should have some
Estimators
which expect huggingface models, so we can use them to produce conditioning vectors / to initialize the policy (this will obviously be a future PR).Note that the conditioning is useless in my example, we should have a better use-case envisioned for the demo. The demo currently is not complete for all GFlowNet types.