3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -31,6 +31,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
@@ -89,6 +90,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_gpt_oss",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_llama",
"apply_liger_kernel_to_llava",
@@ -148,6 +150,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_gpt_oss",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_llama",
"apply_liger_kernel_to_llava",
113 changes: 113 additions & 0 deletions src/liger_kernel/transformers/model/gpt_oss.py
@@ -0,0 +1,113 @@
from typing import Optional
from typing import Union

import torch

from transformers.cache_utils import Cache
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.gpt_oss.modeling_gpt_oss import load_balancing_loss_func
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

Example:

```python
>>> from transformers import AutoTokenizer, GptOssForCausalLM

>>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""

output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None

# if in training mode, do not materialize logits
if self.training and (labels is not None or shift_labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode, materialize logits
logits = self.lm_head(kept_hidden_states)
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure it resides on the same device

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
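A minimal sketch of how this new forward behaves once the patch is applied. Everything here is illustrative: the tiny config values are arbitrary and a Triton-capable device (CUDA or XPU) is assumed, since the Liger kernels run in Triton. In training mode with labels the loss comes from `LigerForCausalLMLoss` and the full-vocabulary logits are never materialized; in inference mode logits are produced as usual.

```python
# Illustrative sketch only: tiny arbitrary config, Triton-capable device assumed.
import torch

from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss
from liger_kernel.utils import infer_device

apply_liger_kernel_to_gpt_oss()  # patches GptOssForCausalLM.forward with lce_forward

config = GptOssConfig(
    num_hidden_layers=2,
    num_local_experts=4,
    num_experts_per_tok=2,
    vocab_size=128,
    hidden_size=64,
    intermediate_size=64,
    head_dim=16,
    num_attention_heads=4,
    num_key_value_heads=2,
)
device = infer_device()
model = GptOssForCausalLM(config).to(device=device, dtype=torch.bfloat16)
input_ids = torch.randint(0, config.vocab_size, (1, 16), device=device)

model.train()
out = model(input_ids=input_ids, labels=input_ids)
# Training path: loss is computed by LigerForCausalLMLoss and the
# (batch, seq_len, vocab_size) logits tensor is never materialized.
assert out.loss is not None and out.logits is None

model.eval()
with torch.no_grad():
    out = model(input_ids=input_ids)
# Inference path: logits are materialized by lm_head as usual.
assert out.logits is not None
```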
56 changes: 56 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
100755 → 100644
@@ -1839,13 +1839,69 @@ def apply_liger_kernel_to_glm4(
_patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)


def apply_liger_kernel_to_gpt_oss(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = False,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace GPT OSS models.
"""
if swiglu:
raise NotImplementedError("LigerSwiGLUMLP is not available for GPT-OSS models.")
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)

from transformers.models.gpt_oss import modeling_gpt_oss
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel

from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNormForGptOss

if rope:
modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb

if rms_norm:
modeling_gpt_oss.GptOssRMSNorm = LigerRMSNormForGptOss

if cross_entropy:
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy

if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(gpt_oss_lce_forward, model)
else:
modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward

if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules

# get the base model from the model instance
base_model: GptOssModel = getattr(model, model.base_model_prefix, model)

if rms_norm:
# match LigerRMSNormForGptOss (Gemma-style casting: weight applied in float32)
_patch_rms_norm_module(base_model.norm, casting_mode="gemma")
for decoder_layer in base_model.layers:
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm, casting_mode="gemma")
_patch_rms_norm_module(decoder_layer.post_attention_layernorm, casting_mode="gemma")


# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
"gemma": apply_liger_kernel_to_gemma,
"gemma2": apply_liger_kernel_to_gemma2,
"gemma3_text": apply_liger_kernel_to_gemma3_text,
"gemma3": apply_liger_kernel_to_gemma3,
"glm4": apply_liger_kernel_to_glm4,
"gpt_oss": apply_liger_kernel_to_gpt_oss,
"llama": apply_liger_kernel_to_llama,
"llama4_text": apply_liger_kernel_to_llama4,
"llama4": apply_liger_kernel_to_llama4,
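Registering `"gpt_oss"` in `MODEL_TYPE_TO_APPLY_LIGER_FN` lets the patch be resolved from `config.model_type` as well as called directly. A hedged sketch of the two usual entry points, assuming the existing `AutoLigerKernelForCausalLM` wrapper dispatches through this mapping for GPT-OSS as it does for other registered model types; the checkpoint id is purely illustrative:

```python
# Sketch; "openai/gpt-oss-20b" is used only as an illustrative checkpoint id.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import AutoLigerKernelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

# Path 1: AutoLigerKernelForCausalLM looks up config.model_type ("gpt_oss") in
# MODEL_TYPE_TO_APPLY_LIGER_FN and applies the patch before the weights are loaded.
model = AutoLigerKernelForCausalLM.from_pretrained("openai/gpt-oss-20b")

# Path 2: patch an already-instantiated model; forward is rebound on the instance
# and the existing RMSNorm modules are patched in place.
model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
apply_liger_kernel_to_gpt_oss(model=model)
```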
6 changes: 6 additions & 0 deletions src/liger_kernel/transformers/rms_norm.py
@@ -77,3 +77,9 @@ def __init__(
self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
):
super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)

class LigerRMSNormForGptOss(LigerRMSNorm):
def __init__(
self, hidden_size, eps=1e-6, offset=0.0, casting_mode="gemma", init_fn="ones", in_place=True, row_mode=None
):
super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
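The new subclass only changes defaults: `offset=0.0`, `casting_mode="gemma"`, and `in_place=True`. The Gemma casting mode keeps the whole computation, including the weight multiplication, in float32 and only casts the final result back to the input dtype; the assumption here, based on the upstream `GptOssRMSNorm`, is that this matches how GPT-OSS applies its norm weight (a plain weight, hence offset 0.0, rather than Gemma's `1 + weight`). A plain-PyTorch reference of that ordering, for comparison only:

```python
# Reference-only sketch of the float32 ordering that casting_mode="gemma" mirrors;
# this is not the Triton kernel itself.
import torch


def rms_norm_gemma_style(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    # The weight is applied while still in float32 (offset=0.0, so no 1 + weight term),
    # and only the final product is cast back to the input dtype.
    return (weight.to(torch.float32) * hidden_states).to(input_dtype)
```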
73 changes: 70 additions & 3 deletions test/convergence/bf16/test_mini_models.py
@@ -22,6 +22,7 @@
from liger_kernel.transformers import apply_liger_kernel_to_gemma2
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text
from liger_kernel.transformers import apply_liger_kernel_to_glm4
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss
from liger_kernel.transformers import apply_liger_kernel_to_granite
from liger_kernel.transformers import apply_liger_kernel_to_llama
from liger_kernel.transformers import apply_liger_kernel_to_llama4
@@ -46,6 +47,7 @@
from test.utils import revert_liger_kernel_to_gemma2
from test.utils import revert_liger_kernel_to_gemma3_text
from test.utils import revert_liger_kernel_to_glm4
from test.utils import revert_liger_kernel_to_gpt_oss
from test.utils import revert_liger_kernel_to_granite
from test.utils import revert_liger_kernel_to_llama
from test.utils import revert_liger_kernel_to_llama4
@@ -168,6 +170,15 @@
except ImportError:
SMOLLM3_AVAILABLE = False

try:
# GPT OSS is only available in transformers>=4.55.0
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

GPT_OSS_AVAILABLE = True
except ImportError:
GPT_OSS_AVAILABLE = False

from liger_kernel.utils import infer_device

device = infer_device()
@@ -856,6 +867,43 @@
),
)

if GPT_OSS_AVAILABLE:
MINI_MODEL_SETUPS["mini_gpt_oss"] = MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_gpt_oss,
liger_kernel_patch_revert_func=revert_liger_kernel_to_gpt_oss,
model_class=GptOssForCausalLM,
mini_model_config=GptOssConfig(
num_hidden_layers=4,
num_local_experts=32, # 128,
vocab_size=32000, # 201088,
hidden_size=896, # 2880,
intermediate_size=896, # 2880,
head_dim=64,
num_attention_heads=8, # 16,
num_key_value_heads=2, # 4,
sliding_window=128,
rope_theta=150000.0,
tie_word_embeddings=False,
hidden_act="silu",
initializer_range=0.02,
max_position_embeddings=32768, # 131072,
rms_norm_eps=1e-5,
rope_scaling=dict(
factor=32.0,
beta_fast=32.0,
beta_slow=1.0,
truncate=False,
rope_type="yarn",
),
attention_dropout=0.0,
num_experts_per_tok=2,
router_aux_loss_coef=0.9,
output_router_logits=False,
use_cache=True,
layer_types=None,
),
)


def create_model(model_name="mini_llama4"):
"""
@@ -900,6 +948,8 @@ def run_mini_model(

if "gemma" in model_name:
kwargs["geglu"] = True
elif "gpt_oss" in model_name:
kwargs["swiglu"] = False
else:
kwargs["swiglu"] = True

@@ -1281,9 +1331,7 @@ def run_mini_model(
# 1e-2,
# 1e-2,
# 1e-2,
# marks=pytest.mark.skipif(
# not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
# ),
# marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
# ),
pytest.param(
"mini_gemma3_text",
@@ -1304,6 +1352,25 @@
),
],
),
pytest.param(
"mini_gpt_oss",
32,
1e-5,
torch.bfloat16,
1e-2,
5e-2,
1e-1,
1e-2,
1e-2,
1e-2,
marks=[
pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
pytest.mark.skipif(
not GPT_OSS_AVAILABLE,
reason="GPT OSS not available in this version of transformers",
),
],
),
],
)
def test_mini_model(