Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

conditional gfn #188

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open

conditional gfn #188

wants to merge 21 commits into from

Conversation

josephdviviano
Copy link
Collaborator

@josephdviviano josephdviviano commented Sep 25, 2024

Supports conditioning on a tensor of shape=[n_trajectories, n_cond_dims]. This is passed by the user during a call to the sampler.

Implemented for all GFlowNets. Note that the current version expects a particular kind of estimator. I can imagine this will lead to future changes - e.g., we should have some Estimators which expect huggingface models, so we can use them to produce conditioning vectors / to initialize the policy (this will obviously be a future PR).

Note that the conditioning is useless in my example, we should have a better use-case envisioned for the demo. The demo currently is not complete for all GFlowNet types.

@josephdviviano josephdviviano added the enhancement New feature or request label Sep 25, 2024
@josephdviviano josephdviviano self-assigned this Sep 25, 2024
@josephdviviano
Copy link
Collaborator Author

Don't worry about the tests - they should be easy to fix.

I can make the chances for DB, Sub-TB, and FM pretty easily if we agree this is a good approach, before a proper review.


or

$s \mapsto (P_B(s' \mid s, c))_{s' \in Parents(s)}$.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be worth mentioning that this is a s very specific conditioning use-case, where the condition is encoded separately, and embeddings are concatenated.

I don't think we can do a generic one, but this should be enough as an example !

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What other conditioning approaches would be worth including? Cross attention?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general I would think the conditioning should be embedded / encoded separately --- or would the conditioning just need to be concatenated to the state before input? I could add support for that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there is an exhaustive list of ways we can process the condition. What you have is great as an example. I suggest you just add a comment or doc that the user might want to write their own module

@@ -68,7 +67,28 @@ def sample_actions(
the sampled actions under the probability distribution of the given
states.
"""
estimator_output = self.estimator(states)
# TODO: Should estimators instead ignore None for the conditioning vector?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't it be cleaner with fewer if else blocks ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes there's a bit of cruft with all the if-else blocks, but as it stands an estimator can either accept one or two arguments and I think it's good if it fails noisily... what do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok ! makes sense.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added these exception_handlers to reduce the cruft.

@saleml
Copy link
Collaborator

saleml commented Sep 25, 2024

LGTM! Looking forward to test this feature

@josephdviviano josephdviviano marked this pull request as ready for review October 1, 2024 16:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants