Skip to content

Conversation

@yeandy
Copy link

@yeandy yeandy commented Jun 26, 2023

What does this PR do?

Mimics #9006, but for Flax.

We want to match how PyTorch's logic accounts for group_size and num_beam_groups here and here

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yeandy yeandy changed the title Add flax diverse group search Add Flax diverse group search Jun 26, 2023
@sgugger
Copy link
Collaborator

sgugger commented Jun 27, 2023

cc @sanchit-gandhi

Copy link
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This look promising already @yeandy! Left some comments regarding the design below. In addition, could we add a few tests to confirm that:

  1. Group beam search runs when we call model.generate
  2. That group beam search is jit'able
  3. And that we get equivalence with PyTorch

trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
num_return_sequences: Optional[int] = None,
num_beam_groups: Optional[int] = 1,
Copy link
Contributor

@sanchit-gandhi sanchit-gandhi Jun 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In PyTorch we define a separate beam search method for group beam search:

def group_beam_search(

We only trigger this method if num_beam_groups>1:

is_group_beam_gen_mode = (

My opinion is that we should have a separate group beam search method in Flax as well, rather than adding to the existing one. IMO this is cleaner for the reader and more compartmentalised for building on

cc @gante as well for Flax generate design decision

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @sanchit-gandhi!

My first commit was to get a prototype working for num_beam_groups=1. I intend to refactor the beam search logic to make sure it works for other num_beam_groups sizes.

  1. Will do.
  2. My current logic is jittable, as I've been doing some testing from this example. Are there test in the HF repo that explicitly test whether a function is jittable? Or is sufficient to have an E2E test jits the function?
  3. Will do.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My opinion is that we should have a separate group beam search method in Flax as well, rather than adding to the existing one.

+1 :)

(btw, there was a recent bugfix on the PT side, might be relevant here)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, sounds good @yeandy! Excited to see how this pans out!

add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
topk_log_probs += add_penalty * np.array(-1.0e7)

# Add additional logic for diverse beam search
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! My only nit is that we try and avoid lambda functions in transformers - would you be able to re-write these as standard function definitions please?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lovely thanks

@github-actions
Copy link
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@sanchit-gandhi
Copy link
Contributor

Hey @yeandy! This PR is looking in good shape - thanks for your efforts so far! Would you like to go all the way and see it to completion? Happy to help with the remainder of the integration!

@yeandy
Copy link
Author

yeandy commented Aug 3, 2023

Hey @sanchit-gandhi. Due to other commitments, I currently don't have bandwidth to continue this. And the timeline for me to get to this unknown right now. If someone else wants to work on this, I'm ok with that.

@sanchit-gandhi
Copy link
Contributor

Thanks for letting me know @yeandy! Best of luck with your other commitments, I hope they go well 🤗 Opening this one up to the community to complete!

@sanchit-gandhi sanchit-gandhi linked an issue Aug 7, 2023 that may be closed by this pull request
@huggingface huggingface deleted a comment from github-actions bot Sep 1, 2023
@sanchit-gandhi sanchit-gandhi added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Sep 1, 2023
@sanchit-gandhi sanchit-gandhi changed the title Add Flax diverse group search [WIP] Add Flax diverse group search Sep 1, 2023
@yipkingster
Copy link

For those who wonder what the status is for this PR, it seems all TF/Flax support has been deprecated. So this PR is no longer in scope.

@Rocketknight1
Copy link
Member

Yes, this should have been closed long ago!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add Flax diverse group search

6 participants