Add support for GPT-OSS models #850
base: main
Conversation
Many thanks for the contribution.
Thanks for the contribution! Any benchmarking results for the GPT-OSS models?
Thank you for the contribution! What discrepancy did you find in the bf16 convergence tests?
```
@@ -1673,3 +1682,41 @@ def test_apply_liger_kernel_to_instance_for_smollm3():
        print(dummy_model_instance)
    except Exception as e:
        pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")


@pytest.mark.skipif(not is_qwen3_available(), reason="gpt oss module not available")
```
Should this be `is_gpt_oss_available`?
```python
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
```
If swiglu can't be implemented now, let's set it to False by default and raise `NotImplementedError` if it is set to True.
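A minimal sketch of the suggested guard, assuming a patch-function signature like the one quoted above (the function name here is hypothetical, not the PR's actual code):

```python
def apply_liger_kernel_to_gpt_oss(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = False,  # default off until a GPT-OSS swiglu/MoE kernel exists
) -> None:
    # Fail loudly instead of silently leaving the module unpatched.
    if swiglu:
        raise NotImplementedError("swiglu patching is not yet supported for GPT-OSS models")
    # ... apply the remaining patches (RoPE, RMSNorm, CE / fused linear CE) here ...
```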
@PKUWZP I'll add the benchmark results as soon as the swiglu implementation is complete.
@shimizust During the convergence test, the loss values for the two models running in bf16 diverged significantly at certain steps. This is likely related to the issue discussed here: #742.
## Summary

This PR is a follow-up to #830, exposing the `accum_dtype` option for the monkey patch functions. All bf16 convergence tests related to fused linear cross entropy are also enforced to run with `accum_dtype=torch.float32` for numerical stability.

Related: #512, #742, #827, #850

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Signed-off-by: Tcc0403 <[email protected]>
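As a hedged illustration of the option described above (using `apply_liger_kernel_to_llama` only as a stand-in example):

```python
import torch

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Per that PR's description: accumulate the fused linear cross entropy in fp32
# even when the model itself trains in bf16, for numerical stability.
apply_liger_kernel_to_llama(
    fused_linear_cross_entropy=True,
    accum_dtype=torch.float32,
)
```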
For future contributors: when patching RMSNorm, there are 4 init args that are easily overlooked. Take GptOss for example: it requires specific values for these args when you create the patched RMSNorm.
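A hedged sketch of what such a patched RMSNorm could look like; the argument names and values below are assumptions about `LigerRMSNorm`'s signature, not the PR's actual code:

```python
from functools import partial

from liger_kernel.transformers.rms_norm import LigerRMSNorm

# Hypothetical example: the easily-overlooked init args are assumed to be
# eps, offset, casting_mode, and in_place; the values are placeholders.
LigerRMSNormForGptOss = partial(
    LigerRMSNorm,
    eps=1e-6,              # should match the model config's rms_norm_eps
    offset=0.0,            # some models (e.g. Gemma) add 1.0 to the weight
    casting_mode="llama",  # controls where the fp32 upcast/downcast happens
    in_place=False,        # whether backward reuses the input buffer
)
```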
@Comet0322 Thanks for updating. Do you think we can check in what you have and figure out swiglu after? After @Tcc0403's accum_dtype changes to the bf16 tests, do they pass now?
@Comet0322 and I have discussed elsewhere. It seems the router top-k choice for experts is quite sensitive in bf16, leading to the discrepancy in final losses. I suggest just skipping the bf16 convergence test for now. cc @shimizust
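A toy illustration (not from the PR) of why the router's top-k choice is fragile in bf16:

```python
import torch

# Two router logits differ by less than bf16's resolution near 1.0, so the
# fp32 and bf16 rankings of experts can disagree, changing the MoE path.
logits = torch.tensor([1.0000, 1.0001, 0.5000, 0.2000])
print(torch.topk(logits, k=2).indices)                     # fp32: expert 1 ranks first
print(torch.topk(logits.to(torch.bfloat16), k=2).indices)  # bf16: both round to 1.0; the tie can flip the order
```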
## Summary
Add GPT-OSS model support, addressing #848
Completed patching for RoPE, RMSNorm, cross_entropy, and fused_linear_cross_entropy.
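For reference, a usage sketch assuming the PR follows the repo's existing `apply_liger_kernel_to_<model>` pattern (the function name is an assumption):

```python
import torch
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss  # name assumed

# Patch RoPE, RMSNorm, and fused linear cross entropy before loading the model;
# swiglu stays off until its kernel is implemented (see Known Issues below).
apply_liger_kernel_to_gpt_oss(
    rope=True,
    rms_norm=True,
    fused_linear_cross_entropy=True,
    swiglu=False,
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b", torch_dtype=torch.bfloat16
)
```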
## Known Issues

- swiglu patching for the MoE experts is not implemented yet; `swiglu` defaults to False.
- bf16 convergence tests diverge at certain steps, likely because the router's top-k expert selection is sensitive in bf16 (see #742).
## Testing Done
FP32 Log
BF16 Log
Env: torch 2.8.0, triton 3.4.0, transformers 4.55.0
- run `make test` to ensure correctness
- run `make checkstyle` to ensure code style
- run `make test-convergence` to ensure convergence