
[WIP] Fix confusion on Gemma #121


Merged: 3 commits into main on Aug 27, 2024

Conversation

@yundai424 (Collaborator) commented Aug 27, 2024

Summary

Why

  1. From the config point of view, Gemma 1 uses exact GeLU [ref], while Gemma 1.1 and Gemma 2 use approximate (tanh) GeLU (gemma2, gemma1)
  2. Also, Gemma reads the hidden_activation config field, and hidden_act is ignored (gemma1 code, gemma2 code); see the config-inspection sketch after this list
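
For reference, here is a minimal way to see which activation a checkpoint's config actually carries. This is an illustrative sketch only, not part of this PR; the model ids are examples and may require Hugging Face Hub access:

```python
from transformers import AutoConfig

# Example Gemma checkpoints; substitute whichever ones you have access to.
for name in ["google/gemma-2b", "google/gemma-1.1-2b-it", "google/gemma-2-9b"]:
    cfg = AutoConfig.from_pretrained(name)
    # The Gemma modeling code reads `hidden_activation`; `hidden_act` is ignored.
    print(name, getattr(cfg, "hidden_activation", None), getattr(cfg, "hidden_act", None))
```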

That said, we should be fine to claim that all of Gemma 1, 1.1, and 2 are supported. But to be safe, I think we can start with 1.1 and 2 first.

What

  1. Adjust the monkey patch so it also works properly with the Gemma 2 code base
  2. Checkstyle
  3. [TODO] Add final logit softcapping for Gemma 2 (see the sketch below): https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py#L1054
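
For item 3, the final logit softcapping in the linked Gemma2 modeling code amounts to a tanh squashing of the logits. A minimal sketch of that step (illustrative only; `softcap` corresponds to `config.final_logit_softcapping`):

```python
import torch

def softcap_logits(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # Same ordering as the linked HF code: divide by the cap, apply tanh,
    # then scale back up, which bounds the logits to (-softcap, softcap).
    return softcap * torch.tanh(logits / softcap)
```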

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@yundai424 yundai424 marked this pull request as ready for review August 27, 2024 17:39
@yundai424 yundai424 changed the title Fix confusion on Gemma [WIP] Fix confusion on Gemma Aug 27, 2024
@qingquansong (Collaborator) left a comment


LGTM! Thanks for the fix! Wondering if we could relax the check here so we can also use Gemma 1?

@lancerts lancerts merged commit 3d3b604 into main Aug 27, 2024
1 check passed
@lancerts lancerts deleted the yudai/gemma branch August 27, 2024 18:04
if fused_linear_cross_entropy:
    modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
    modeling_gemma2.Gemma2ForCausalLM.forward = gemma_lce_forward


It looks like the same recently added lce_forward is being used for both Gemma and Gemma2 (#111).

It appears that there is a slight difference in the forward between Gemma and Gemma2 in the modeling code in transformers, specifically regarding logit softcapping (https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py#L1054-L1057), which doesn't seem to be accounted for in the new lce code as far as I can tell. Would this potentially lead to incompatibility?
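
For context, here is a rough sketch of the unfused ordering in Gemma2's forward (all names below are illustrative, not this repo's or transformers' API): the lm_head projection is followed by the softcap before the loss, so a fused linear + cross-entropy path that replaces that whole span in one kernel would need to reproduce the softcap itself.

```python
from typing import Optional

import torch
import torch.nn.functional as F

def unfused_gemma2_loss(hidden: torch.Tensor, lm_head_weight: torch.Tensor,
                        labels: torch.Tensor, softcap: Optional[float]) -> torch.Tensor:
    # Reference (unfused) ordering: projection -> optional softcap -> CE loss.
    # Label shifting is omitted for brevity.
    logits = hidden @ lm_head_weight.T
    if softcap is not None:
        logits = softcap * torch.tanh(logits / softcap)
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
```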

@yundai424 (Collaborator, Author) replied:

Yeah, exactly. Just figured out we need to add softcapping :/ I'll have a follow-up PR to clarify that right now this only works for Gemma 1 and 1.1, and we'll work on changing the kernel correspondingly ASAP.

yundai424 added a commit that referenced this pull request Aug 27, 2024
yundai424 added a commit that referenced this pull request Aug 27, 2024
This reverts commit 3d3b604.
