
Conversation

@Comet0322 (Contributor) commented Aug 9, 2025

Summary

Add GPT-OSS model support, addressing #848
Completed patching for RoPE, RMSNorm, cross_entropy, and fused_linear_cross_entropy.

Known Issues

  • Gated SwiGLU Patching Support: The current Hugging Face implementation of the gated SwiGLU in GptOssExperts makes patching difficult (a rough sketch of the activation is shown after this list). This will be addressed in a future update.
  • GptOssExperts MXFP4 Format Support: MXFP4 tests are pending due to ongoing changes in the Hugging Face Transformers interface.
  • BF16 Convergence Issue: The BF16 convergence test is failing, while FP32 passes. This issue is under investigation.
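
For context on the first item, below is a rough sketch of the gated activation used in GptOssExperts, written from the transformers 4.55 modeling code; the default `alpha` and `limit` values here are assumptions to verify against the HF source:

```python
import torch

def gpt_oss_glu(gate: torch.Tensor, up: torch.Tensor,
                alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    """Rough sketch of the gated activation in GptOssExperts.

    Unlike the plain SwiGLU (silu(gate) * up) that the existing LigerSwiGLUMLP
    fuses, the gate/up projections here are clamped and the up branch is
    shifted by 1, so the current fused kernel cannot be swapped in directly.
    """
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)
    return (up + 1) * glu
```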

Testing Done

FP32 Log
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED                        [100%]
test/convergence/fp32/test_mini_models_with_logits.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED            [100%]
BF16 Log
pytest --disable-warnings test/convergence/bf16/test_mini_models.py 
====================================================================== test session starts =======================================================================
platform linux -- Python 3.10.18, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/admin/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.8.0, rerunfailures-15.1
collected 1 item                                                                                                                                                 

test/convergence/bf16/test_mini_models.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-0.01-0.05-0.1-0.01-0.01-0.01] FAILED                               [100%]

============================================================================ FAILURES ============================================================================
___________________________________________ test_mini_model[mini_gpt_oss-32-1e-05-dtype0-0.01-0.05-0.1-0.01-0.01-0.01] ___________________________________________

model_name = 'mini_gpt_oss', num_steps = 32, lr = 1e-05, dtype = torch.bfloat16, loss_atol = 0.01, loss_rtol = 0.05, logprobs_atol = 0.1, logprobs_rtol = 0.01
param_atol = 0.01, param_rtol = 0.01

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
        [
            pytest.param(
                "mini_gpt_oss",
                32,
                1e-5,
                torch.bfloat16,
                1e-2,
                5e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not GPT_OSS_AVAILABLE,
                        reason="GPT OSS not available in this version of transformers",
                    ),
                ],
            ),
        ],
    )
    def test_mini_model(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logprobs_atol,
        logprobs_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
    
        expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
        actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
    
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
            extra_info="[Loss]",
        )

test/convergence/bf16/test_mini_models.py:1395: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[10.4809, 10.2822, 10.0886,  9.8527,  9.6104,  9.4217,  9.1856,  8.9703,
          8.7122,  8.5283,  8.2974,  ...  6.1561,  6.0330,
          5.9142,  5.7746,  5.6058,  5.5196,  5.4399,  5.1645,  5.2462,  4.9314,
          4.8588]])
tensor2 = tensor([[10.4806, 10.2860, 10.0742,  9.8525,  9.6143,  9.4222,  9.1932,  8.9775,
          8.7126,  8.5138,  8.2971,  ...  8.8191,  6.0038,
          5.8853,  5.7446,  5.5767,  5.5193,  5.4383,  5.1643,  5.2460,  4.9023,
          4.8585]])
rtol = 0.05, atol = 0.01, max_print = 5, extra_info = '[Loss]'

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5, extra_info=""):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
        extra_info (str): Extra information to show at the start of the error message.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
            if num_mismatched > max_print:
                mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
    
>           raise AssertionError(extra_info + "\n".join(mismatch_details))
E           AssertionError: [Loss]Number of mismatched elements: 2
E           Mismatch at index (0, 21): tensor1[(0, 21)] = 8.848180770874023, tensor2[(0, 21)] = 6.204418182373047
E           Mismatch at index (0, 22): tensor1[(0, 22)] = 6.156144142150879, tensor2[(0, 22)] = 8.819061279296875

test/utils.py:131: AssertionError
---------------------------------------------------------------------- Captured stdout call ----------------------------------------------------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 10.480887413024902
Step 1, Loss: 10.282181739807129
Step 2, Loss: 10.088647842407227
Step 3, Loss: 9.852700233459473
Step 4, Loss: 9.610376358032227
Step 5, Loss: 9.421696662902832
Step 6, Loss: 9.185647010803223
Step 7, Loss: 8.970256805419922
Step 8, Loss: 8.712227821350098
Step 9, Loss: 8.528286933898926
Step 10, Loss: 8.297395706176758
Step 11, Loss: 8.022294998168945
Step 12, Loss: 7.8798322677612305
Step 13, Loss: 7.648087501525879
Step 14, Loss: 7.404232501983643
Step 15, Loss: 7.2665910720825195
Step 16, Loss: 9.697592735290527
Step 17, Loss: 9.604401588439941
Step 18, Loss: 9.46231746673584
Step 19, Loss: 6.633795738220215
Step 20, Loss: 9.081354141235352
Step 21, Loss: 8.848180770874023
Step 22, Loss: 6.156144142150879
Step 23, Loss: 6.032957077026367
Step 24, Loss: 5.914245128631592
Step 25, Loss: 5.774555206298828
Step 26, Loss: 5.605828285217285
Step 27, Loss: 5.519618988037109
Step 28, Loss: 5.439865589141846
Step 29, Loss: 5.164504528045654
Step 30, Loss: 5.246169567108154
Step 31, Loss: 4.93139123916626
Eval Loss: 4.858840465545654
Liger kernel patches have been reverted.
Step 0, Loss: 10.480640411376953
Step 1, Loss: 10.286001205444336
Step 2, Loss: 10.074193954467773
Step 3, Loss: 9.852497100830078
Step 4, Loss: 9.614309310913086
Step 5, Loss: 9.422234535217285
Step 6, Loss: 9.19322395324707
Step 7, Loss: 8.977506637573242
Step 8, Loss: 8.712628364562988
Step 9, Loss: 8.51380729675293
Step 10, Loss: 8.297117233276367
Step 11, Loss: 8.0232572555542
Step 12, Loss: 7.879528522491455
Step 13, Loss: 7.649257659912109
Step 14, Loss: 7.418290138244629
Step 15, Loss: 7.2662553787231445
Step 16, Loss: 9.697296142578125
Step 17, Loss: 9.61251449584961
Step 18, Loss: 9.455589294433594
Step 19, Loss: 6.634244918823242
Step 20, Loss: 9.089345932006836
Step 21, Loss: 6.204418182373047
Step 22, Loss: 8.819061279296875
Step 23, Loss: 6.003849029541016
Step 24, Loss: 5.8852691650390625
Step 25, Loss: 5.744577884674072
Step 26, Loss: 5.576700210571289
Step 27, Loss: 5.519281387329102
Step 28, Loss: 5.438257217407227
Step 29, Loss: 5.164332389831543
Step 30, Loss: 5.246014595031738
Step 31, Loss: 4.902283668518066
Eval Loss: 4.858528137207031
Liger kernel patches have been reverted.
==================================================================== short test summary info =====================================================================
FAILED test/convergence/bf16/test_mini_models.py::test_mini_model[mini_gpt_oss-32-1e-05-dtype0-0.01-0.05-0.1-0.01-0.01-0.01] - AssertionError: [Loss]Number of mismatched elements: 2
================================================================= 1 failed, 1 warning in 16.33s ==================================================================

Env: torch 2.8.0, triton 3.4.0, transformers 4.55.0

  • Hardware Type: H200
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@Comet0322 (Contributor, Author)

@Tcc0403

@lancerts (Collaborator)

@Tcc0403

Many thanks for the contribution.

@PKUWZP (Collaborator) commented Aug 11, 2025

Thanks for the contribution! Any benchmarking results for the GPT-OSS models?

@shimizust (Collaborator) left a comment

Thank you for the contribution! What is the discrepancy you found in the bf16 convergence tests?

@@ -1673,3 +1682,41 @@ def test_apply_liger_kernel_to_instance_for_smollm3():
print(dummy_model_instance)
except Exception as e:
pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")


@pytest.mark.skipif(not is_qwen3_available(), reason="gpt oss module not available")

Collaborator review comment:

Should be is_gpt_oss_available?
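
For reference, a minimal sketch of what such a guard typically looks like; the helper body and test name below are illustrative, not the PR's actual code:

```python
import pytest


def is_gpt_oss_available() -> bool:
    # Illustrative helper: True only if the installed transformers build ships GPT-OSS.
    try:
        from transformers.models.gpt_oss import modeling_gpt_oss  # noqa: F401

        return True
    except ImportError:
        return False


@pytest.mark.skipif(not is_gpt_oss_available(), reason="gpt oss module not available")
def test_apply_liger_kernel_to_instance_for_gpt_oss():  # illustrative test name
    ...
```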

cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,

Collaborator review comment:

If swiglu can't be implemented now, let's set it to False by default and raise NotImplementedError if it is set to True (a sketch of this guard follows).
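
A minimal sketch of the suggested guard, assuming the GPT-OSS patch function keeps the signature shown in the diff above (the function name and the remaining body are illustrative):

```python
def apply_liger_kernel_to_gpt_oss(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = False,  # default off until gated SwiGLU patching is supported
) -> None:
    if swiglu:
        # Fail loudly rather than silently skipping the unsupported patch.
        raise NotImplementedError(
            "Gated SwiGLU patching for GptOssExperts is not supported yet."
        )
    # ... existing patching for RoPE, RMSNorm, and the (fused linear) cross entropy losses ...
```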

@Comet0322 (Contributor, Author)

@PKUWZP I'll add the benchmark results as soon as the swiglu implementation is complete.

@Comet0322 (Contributor, Author)

@shimizust During the convergence test, the loss values for the two models running in bf16 diverged significantly at certain steps. This is likely related to the issue discussed here: #742.

[Image: loss_curve_final]

shimizust pushed a commit that referenced this pull request Aug 12, 2025
## Summary
This PR is a follow-up to #830, exposing `accum_dtype` option for monkey
patch functions.
All bf16 convergence tests related to fused linear cross entropy are
also enforced to run with `accum_dtype=torch.float32` for numerical
stability.

Related: #512, #742, #827, #850 

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Signed-off-by: Tcc0403 <[email protected]>
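
As a usage sketch of the `accum_dtype` option described in that commit (the kwarg pass-through is assumed from the summary above; confirm against the current patch function signatures):

```python
import torch
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Force fp32 accumulation in the fused linear cross entropy loss for bf16 runs,
# per the numerical-stability rationale in the commit message above.
apply_liger_kernel_to_llama(
    fused_linear_cross_entropy=True,
    accum_dtype=torch.float32,
)
```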
@Tcc0403 (Collaborator) commented Aug 14, 2025

For future contributors:

When patching RMSNorm, there are 4 init args that are easily overlooked.

  1. casting_mode (str): "gemma" or "llama" (more detail)
    • llama: downcast back to the original precision "before" multiplying the weight
    • gemma: downcast back to the original precision "after" multiplying the weight
  2. init_fn (str): "zeros" or "ones"; pretty much the llama vs gemma impl as well
  3. bias (float): defaults to 0.0
    • 0.0 (llama): no op applied to the weight before multiplication
    • 1.0 (gemma): adds 1.0 to the weight before multiplication
  4. in_place (bool): True or False
    • True: reuses the tensor dY (grad_output) to save some memory, but it doesn't work if dY is required elsewhere (e.g. adding a residual after rmsnorm)
    • False: addresses the above issue and always works ("Support out-of-place RMSNorm to fix gemma2" #376)

Take GptOss for example:
https://github.com/huggingface/transformers/blob/v4.55.2/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L42

It requires

  1. casting_mode="gemma" (casting back after multiplying the weight)
  2. init_fn="ones" (self.weight = nn.Parameter(torch.ones(hidden_size)))
  3. bias=0.0 (no shift applied to the weight)
  4. in_place=True (no dY usage elsewhere)

Create a LigerRMSNormForXXXModel under liger_kernel/transformers/rms_norm.py to apply these init params.
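
A minimal sketch of what such a preset could look like, e.g. via functools.partial; the argument names follow the list above, the name LigerRMSNormForGptOss is illustrative, and the exact constructor signature should be checked against rms_norm.py:

```python
from functools import partial

from liger_kernel.transformers.rms_norm import LigerRMSNorm

# Preset matching GptOssRMSNorm semantics; argument names follow the list above
# and should be checked against the actual LigerRMSNorm signature.
LigerRMSNormForGptOss = partial(
    LigerRMSNorm,
    casting_mode="gemma",  # downcast back to the original dtype after multiplying the weight
    init_fn="ones",        # weight initialized as torch.ones(hidden_size)
    bias=0.0,              # no shift applied to the weight
    in_place=True,         # dY is not needed elsewhere, so it can be reused
)
```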

@shimizust (Collaborator)

@Comet0322 Thanks for updating. Do you think we can check in what you have and figure out swiglu after? After @Tcc0403's accum_dtype changes to the bf16 tests, do they pass now?

@Tcc0403 (Collaborator) commented Sep 3, 2025

@Comet0322 and I have discussed this elsewhere. It seems the router's top-k choice of experts is quite sensitive in bf16, leading to the discrepancy in the final losses. I suggest just skipping the bf16 convergence test for now (a small illustration is below).

cc @shimizust
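
A tiny illustration of that sensitivity (not from the PR; purely to show how near-tied router logits behave under bf16 rounding):

```python
import torch

# Router logits that are distinct in fp32 can round to the same bf16 value,
# so the selected top-k experts (and the whole forward pass) may differ.
logits_fp32 = torch.tensor([1.0000, 1.0001, 0.5, 0.2])
logits_bf16 = logits_fp32.to(torch.bfloat16)

print(torch.topk(logits_fp32, k=1).indices)  # tensor([1]): index 1 wins in fp32
print(logits_bf16[0] == logits_bf16[1])      # tensor(True): the two logits are now tied
```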
