diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 6ad75f93b..a7b7f9c0f 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -32,6 +32,8 @@ def liger_cross_entropy_kernel( loss_ptr, z_loss_ptr, loss_stride, + token_accuracy_ptr, + token_accuracy_stride, n_cols, n_non_ignore, sum_non_ignore_weight, @@ -42,6 +44,7 @@ def liger_cross_entropy_kernel( reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time softcap, RETURN_Z_LOSS: tl.constexpr, + RETURN_TOKEN_ACCURACY: tl.constexpr, BLOCK_SIZE: tl.constexpr, HAS_WEIGHT: tl.constexpr, HAS_SOFTCAPPING: tl.constexpr, @@ -60,6 +63,8 @@ def liger_cross_entropy_kernel( loss_ptr: Pointer to tensor to store the loss. z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. loss_stride (int): The stride of the loss tensor. + token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0. + token_accuracy_stride (int): The stride of the token accuracy tensor. n_cols (int): The number of columns in the input tensor. n_non_ignore (float): The number of non-ignored elements in the batch. sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. @@ -69,7 +74,8 @@ def liger_cross_entropy_kernel( lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. reduction (str): The string for the reduction to apply softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). - RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. + RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1. + RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1. BLOCK_SIZE (int): The block size for Triton operations. HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. @@ -92,11 +98,17 @@ def liger_cross_entropy_kernel( for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + # For ignored tokens, set token accuracy to 0 + if RETURN_TOKEN_ACCURACY: + token_accuracy_ptr += program_id * token_accuracy_stride + tl.store(token_accuracy_ptr, 0.0) return loss_ptr += program_id * loss_stride if RETURN_Z_LOSS: z_loss_ptr += program_id * loss_stride + if RETURN_TOKEN_ACCURACY: + token_accuracy_ptr += program_id * token_accuracy_stride if HAS_WEIGHT: weight_y = tl.load(weight_ptr + y).cast(tl.float32) @@ -107,6 +119,7 @@ def liger_cross_entropy_kernel( # 3. [Online softmax] first pass: find max + sum m = float("-inf") # m is the max value. use the notation from the paper d = 0.0 # d is the sum. use the notation from the paper + argmax_idx = 0 # Track the index of the maximum value for token accuracy computation ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation if HAS_SOFTCAPPING: ori_X_y = softcap * tanh(ori_X_y / softcap) @@ -127,6 +140,16 @@ def liger_cross_entropy_kernel( if HAS_SOFTCAPPING: X_block = softcap * tanh(X_block / softcap) block_max = tl.max(X_block) + + # Track argmax for accuracy computation + if RETURN_TOKEN_ACCURACY and block_max > m: + # Find the index of the maximum value in this block + is_max_mask = X_block == block_max + # Mask out invalid indices with a value larger than n_cols + masked_offsets = tl.where(is_max_mask, X_offsets, n_cols) + # Get the first (smallest) index where max occurs + argmax_idx = tl.min(masked_offsets) + if label_smoothing > 0: # scale X beforehand to avoid overflow if HAS_WEIGHT: @@ -256,6 +279,10 @@ def liger_cross_entropy_kernel( tl.store(loss_ptr, loss) if RETURN_Z_LOSS: tl.store(z_loss_ptr, z_loss) + if RETURN_TOKEN_ACCURACY: + # Store 1.0 if prediction is correct, 0.0 otherwise + is_correct = 1.0 if argmax_idx == y else 0.0 + tl.store(token_accuracy_ptr, is_correct) # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 @@ -274,8 +301,12 @@ def cross_entropy_forward( reduction, softcap, return_z_loss, + return_token_accuracy=False, ): assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + assert isinstance(return_token_accuracy, bool), ( + f"return_token_accuracy must be True or False. Got: {return_token_accuracy}" + ) BT, V = _input.shape n_rows = BT @@ -285,6 +316,9 @@ def cross_entropy_forward( # unreduced loss loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None + token_accuracy_1d = ( + torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None + ) target_mask = target != ignore_index n_non_ignore = target_mask.sum().item() @@ -321,6 +355,10 @@ def cross_entropy_forward( loss_ptr=loss_1d, z_loss_ptr=z_loss_1d, loss_stride=loss_1d.stride(-1), # always 1 + token_accuracy_ptr=token_accuracy_1d, + token_accuracy_stride=token_accuracy_1d.stride(-1) + if return_token_accuracy + else 0, # always 1 if accuracy is enabled n_cols=V, n_non_ignore=n_non_ignore, sum_non_ignore_weight=sum_non_ignore_weight, @@ -331,6 +369,7 @@ def cross_entropy_forward( reduction=reduction, softcap=softcap, RETURN_Z_LOSS=return_z_loss, + RETURN_TOKEN_ACCURACY=return_token_accuracy, BLOCK_SIZE=BLOCK_SIZE, HAS_WEIGHT=True if weight is not None else False, HAS_SOFTCAPPING=True if softcap is not None else False, @@ -343,11 +382,14 @@ def cross_entropy_forward( if reduction == "none": loss = loss_1d z_loss = z_loss_1d if return_z_loss else None + token_accuracy = token_accuracy_1d if return_token_accuracy else None else: loss = torch.sum(loss_1d) z_loss = torch.sum(z_loss_1d) if return_z_loss else None + # For accuracy, we compute the mean across all non-ignored tokens + token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None - return loss, z_loss, _input + return loss, z_loss, token_accuracy, _input def cross_entropy_backward(_input, grad_output): @@ -395,6 +437,7 @@ def forward( reduction: str = "mean", softcap: Optional[float] = None, return_z_loss: bool = False, + return_token_accuracy: bool = False, ): """ The forward pass of the Liger Cross Entropy loss. @@ -409,14 +452,15 @@ def forward( label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduction (str): The reduction to apply to the output: "none" | "mean | "sum". softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). - return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False` + return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False` + return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False` Returns: - tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None. + tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested. """ input_requires_grad = _input.requires_grad - loss, z_loss, _input = cross_entropy_forward( + loss, z_loss, token_accuracy, _input = cross_entropy_forward( _input, target, weight, @@ -426,6 +470,7 @@ def forward( reduction, softcap, return_z_loss, + return_token_accuracy, ) # TODO: investigation # If we don't detach the _input tensor, the memory will double @@ -433,23 +478,27 @@ def forward( if input_requires_grad: ctx.save_for_backward(_input.detach()) ctx.return_z_loss = return_z_loss + ctx.return_token_accuracy = return_token_accuracy - return loss, z_loss + return loss, z_loss, token_accuracy @staticmethod - def backward(ctx, grad_output, grad_ouput2): + def backward(ctx, grad_output, grad_output2, grad_output3): """ The backward pass of the Liger Cross Entropy loss. Parameters: ctx : The context object with saved tensors. grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. - grad_output2 (tenosr): No use. + grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging). + grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics). Returns: tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. """ if ctx.return_z_loss: - del grad_ouput2 # z_loss is only for logging + del grad_output2 # z_loss is only for logging + if ctx.return_token_accuracy: + del grad_output3 # token_accuracy is only for metrics (_input,) = ctx.saved_tensors _input = cross_entropy_backward(_input, grad_output) @@ -463,4 +512,5 @@ def backward(ctx, grad_output, grad_ouput2): None, None, None, + None, ) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index bc1ab45f7..b4ac94de4 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -27,8 +27,12 @@ def fused_linear_cross_entropy_forward( return_z_loss=False, accum_dtype=None, use_token_scaling=False, + return_token_accuracy=False, ): assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + assert isinstance(return_token_accuracy, bool), ( + f"return_token_accuracy must be True or False. Got: {return_token_accuracy}" + ) device = _input.device input_requires_grad = _input.requires_grad @@ -64,6 +68,7 @@ def fused_linear_cross_entropy_forward( loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None + token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None # TODO: evaluate how CUDA synchronization caused by .item() affects the speed target_mask = target != ignore_index @@ -129,6 +134,7 @@ def fused_linear_cross_entropy_forward( # unreduced loss loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None + token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None # ensure _input and target are contiguous logits_chunk = logits_chunk.contiguous() @@ -144,6 +150,10 @@ def fused_linear_cross_entropy_forward( loss_ptr=loss_1d_slice, z_loss_ptr=z_loss_1d_slice, loss_stride=loss_1d_slice.stride(-1), # always 1 + token_accuracy_ptr=token_accuracy_1d_slice, + token_accuracy_stride=token_accuracy_1d_slice.stride(-1) + if return_token_accuracy + else 0, # always 1 if accuracy is enabled n_cols=V, n_non_ignore=total_n_non_ignore, sum_non_ignore_weight=total_sum_non_ignore_ce_weight, @@ -154,6 +164,7 @@ def fused_linear_cross_entropy_forward( reduction=reduction, softcap=softcap, RETURN_Z_LOSS=return_z_loss, + RETURN_TOKEN_ACCURACY=return_token_accuracy, HAS_WEIGHT=True if ce_weight is not None else False, HAS_SOFTCAPPING=True if softcap is not None else False, HAS_GRADIENTS=input_requires_grad, @@ -170,6 +181,8 @@ def fused_linear_cross_entropy_forward( loss_1d[start_idx:end_idx] = loss_1d_slice if return_z_loss: z_loss_1d[start_idx:end_idx] = z_loss_1d_slice + if return_token_accuracy: + token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice grad_logits_chunk = logits_chunk # chunk_size x V # Apply token scaling to gradients if requested @@ -201,15 +214,18 @@ def fused_linear_cross_entropy_forward( # Return per-token losses loss = loss_1d z_loss = z_loss_1d if return_z_loss else None + token_accuracy = token_accuracy_1d if return_token_accuracy else None else: loss = torch.sum(loss_1d) z_loss = torch.sum(z_loss_1d) if return_z_loss else None + # For accuracy, we compute the mean across all non-ignored tokens + token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None # Cast back to original dtype grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None - return loss, z_loss, grad_input, grad_weight, grad_bias + return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): @@ -277,6 +293,7 @@ def forward( return_z_loss: bool = False, accum_dtype=None, use_token_scaling: bool = False, + return_token_accuracy: bool = False, ): """ Fusing the last linear layer with cross-entropy loss @@ -300,9 +317,10 @@ def forward( use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached). When True, each token's loss is multiplied by the model's predicted probability for that token's true class. Default: False. + return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False` """ - loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( + loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( _input=_input, weight=weight, target=target, @@ -316,6 +334,7 @@ def forward( return_z_loss=return_z_loss, accum_dtype=accum_dtype, use_token_scaling=use_token_scaling, + return_token_accuracy=return_token_accuracy, ) # downcast to dtype and store for backward ctx.save_for_backward( @@ -324,13 +343,16 @@ def forward( grad_bias.detach() if bias is not None else None, ) ctx.return_z_loss = return_z_loss - return loss, z_loss + ctx.return_token_accuracy = return_token_accuracy + return loss, z_loss, token_accuracy @staticmethod @amp_custom_bwd - def backward(ctx, grad_output, grad_output2): + def backward(ctx, grad_output, grad_output2, grad_output3): if ctx.return_z_loss: del grad_output2 # z_loss is only for logging + if ctx.return_token_accuracy: + del grad_output3 # token_accuracy is only for metrics (grad_input, grad_weight, grad_bias) = ctx.saved_tensors grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( grad_output, grad_input, grad_weight, grad_bias @@ -349,4 +371,5 @@ def backward(ctx, grad_output, grad_output2): None, None, None, # use_token_scaling + None, # return_token_accuracy ) diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index f01b1f57e..f9128a609 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -3,6 +3,7 @@ import torch from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction +from liger_kernel.transformers.functional import CrossEntropyOutput class LigerCrossEntropyLoss(torch.nn.Module): @@ -15,6 +16,7 @@ def __init__( reduction: str = "mean", softcap: Optional[float] = None, return_z_loss: bool = False, + return_token_accuracy: bool = False, ): super().__init__() assert (label_smoothing >= 0) and (label_smoothing <= 1), ( @@ -33,9 +35,10 @@ def __init__( self.reduction = reduction self.softcap = softcap self.return_z_loss = return_z_loss + self.return_token_accuracy = return_token_accuracy def forward(self, _input: torch.Tensor, target: torch.Tensor): - loss, z_loss = LigerCrossEntropyFunction.apply( + loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply( _input, target, self.weight, @@ -45,7 +48,9 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor): self.reduction, self.softcap, self.return_z_loss, + self.return_token_accuracy, ) - if not self.return_z_loss: + if not self.return_z_loss and not self.return_token_accuracy: return loss - return loss, z_loss + + return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy) diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 424537541..39411339c 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -1,5 +1,8 @@ +from dataclasses import dataclass from typing import Optional +import torch + from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction from liger_kernel.ops.dyt import LigerDyTFunction from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction @@ -22,6 +25,13 @@ from liger_kernel.ops.tvd import LigerTVDLossFunction +@dataclass +class CrossEntropyOutput: + loss: torch.Tensor + z_loss: Optional[torch.Tensor] = None + token_accuracy: Optional[torch.Tensor] = None + + # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html # `weight` and `size_average` are placeholders and not implemented yet def liger_cross_entropy( @@ -36,8 +46,9 @@ def liger_cross_entropy( lse_square_scale: float = 0.0, softcap: Optional[float] = None, return_z_loss: bool = False, + return_token_accuracy: bool = False, ): - loss, z_loss = LigerCrossEntropyFunction.apply( + loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply( input, target, weight, @@ -47,10 +58,13 @@ def liger_cross_entropy( reduction, softcap, return_z_loss, + return_token_accuracy, ) - if not return_z_loss: + + if not return_z_loss and not return_token_accuracy: return loss - return loss, z_loss + + return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy) def liger_fused_linear_cross_entropy( @@ -67,8 +81,9 @@ def liger_fused_linear_cross_entropy( return_z_loss: bool = False, accum_dtype=None, use_token_scaling: bool = False, + return_token_accuracy: bool = False, ): - loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply( + loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply( input, weight, target, @@ -82,10 +97,13 @@ def liger_fused_linear_cross_entropy( return_z_loss, accum_dtype, use_token_scaling, + return_token_accuracy, ) - if not return_z_loss: + + if not return_z_loss and not return_token_accuracy: return loss - return loss, z_loss + + return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy) def liger_fused_linear_jsd( diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index ebf19bcfb..317c6ce37 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -3,6 +3,7 @@ import torch from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction +from liger_kernel.transformers.functional import CrossEntropyOutput class LigerFusedLinearCrossEntropyLoss(torch.nn.Module): @@ -17,6 +18,7 @@ def __init__( return_z_loss: bool = False, accum_dtype: Optional[torch.dtype] = None, use_token_scaling: bool = False, + return_token_accuracy: bool = False, ): super().__init__() assert (label_smoothing >= 0) and (label_smoothing <= 1), ( @@ -37,9 +39,10 @@ def __init__( self.return_z_loss = return_z_loss self.accum_dtype = accum_dtype self.use_token_scaling = use_token_scaling + self.return_token_accuracy = return_token_accuracy def forward(self, lin_weight, _input, target, bias=None): - loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply( + loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply( _input, lin_weight, target, @@ -53,7 +56,9 @@ def forward(self, lin_weight, _input, target, bias=None): self.return_z_loss, self.accum_dtype, self.use_token_scaling, + self.return_token_accuracy, ) - if not self.return_z_loss: + if not self.return_z_loss and not self.return_token_accuracy: return loss - return loss, z_loss + + return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy) diff --git a/src/liger_kernel/transformers/model/falcon_h1.py b/src/liger_kernel/transformers/model/falcon_h1.py index c91a136c9..23652b020 100644 --- a/src/liger_kernel/transformers/model/falcon_h1.py +++ b/src/liger_kernel/transformers/model/falcon_h1.py @@ -4,12 +4,12 @@ import torch -from transformers.modeling_outputs import CausalLMOutputWithPast - if TYPE_CHECKING: from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast def lce_forward( @@ -26,8 +26,9 @@ def lce_forward( cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs, -) -> Union[tuple, CausalLMOutputWithPast]: +) -> Union[tuple, LigerCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -54,6 +55,7 @@ def lce_forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -77,6 +79,8 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None + # if in training mode, don't materialize logits if skip_logits and labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -85,8 +89,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and labels is not None + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -94,15 +99,24 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) if labels is not None or shift_labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - return CausalLMOutputWithPast( + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/gemma.py b/src/liger_kernel/transformers/model/gemma.py index 2c1e18433..3cc949181 100644 --- a/src/liger_kernel/transformers/model/gemma.py +++ b/src/liger_kernel/transformers/model/gemma.py @@ -12,6 +12,8 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast def lce_forward_deprecated( @@ -147,7 +149,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -209,6 +211,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -217,8 +220,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -226,6 +230,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) if labels is not None or shift_labels is not None: @@ -238,13 +243,19 @@ def lce_forward( ) if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( + output_tuple = (logits,) + outputs[1:] + if loss is not None: + output_tuple = (loss,) + output_tuple + if token_accuracy is not None: + output_tuple = output_tuple + (token_accuracy,) + return output_tuple + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/gemma2.py b/src/liger_kernel/transformers/model/gemma2.py index 6b1ecc1e1..fc276456d 100644 --- a/src/liger_kernel/transformers/model/gemma2.py +++ b/src/liger_kernel/transformers/model/gemma2.py @@ -13,6 +13,8 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast logger = logging.getLogger(__name__) @@ -158,7 +160,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -225,6 +227,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -233,8 +236,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -243,6 +247,7 @@ def lce_forward( final_logit_softcapping=self.config.final_logit_softcapping, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -262,13 +267,17 @@ def lce_forward( ) if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output_tuple = (logits,) + outputs[1:] + output_tuple = (loss,) + output_tuple if loss is not None else output_tuple + output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple + return output_tuple - return CausalLMOutputWithPast( + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/gemma3.py b/src/liger_kernel/transformers/model/gemma3.py index 88ed0ef96..1a19f561a 100644 --- a/src/liger_kernel/transformers/model/gemma3.py +++ b/src/liger_kernel/transformers/model/gemma3.py @@ -7,12 +7,13 @@ from transformers.cache_utils import Cache from transformers.cache_utils import HybridCache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast from transformers.utils import logging from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast +from liger_kernel.transformers.model.output_classes import LigerGemma3CausalLMOutputWithPast logger = logging.get_logger(__name__) @@ -33,7 +34,7 @@ def causal_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **loss_kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -98,12 +99,14 @@ def causal_forward( shift_labels = loss_kwargs.pop("shift_labels", None) loss = None logits = None + token_accuracy = None if skip_logits is None: skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -112,7 +115,7 @@ def causal_forward( final_logit_softcapping=self.config.final_logit_softcapping, **loss_kwargs, ) - + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) if self.config.final_logit_softcapping is not None: @@ -129,15 +132,19 @@ def causal_forward( ) if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output_tuple = (logits,) + outputs[1:] + output_tuple = (loss,) + output_tuple if loss is not None else output_tuple + output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple + return output_tuple - return CausalLMOutputWithPast( + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) @@ -159,7 +166,7 @@ def multimodal_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **lm_kwargs, -) -> Union[tuple, Gemma3CausalLMOutputWithPast]: +) -> Union[tuple, LigerGemma3CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -235,6 +242,7 @@ def multimodal_forward( loss = None logits = None + token_accuracy = None if skip_logits and labels is None: raise ValueError("skip_logits is True, but labels is None") @@ -261,7 +269,9 @@ def multimodal_forward( shift_labels = shift_labels.view(-1).to(hidden_device) lce = LigerFusedLinearCrossEntropyLoss() - loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + result = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + loss, _, token_accuracy = unpack_cross_entropy_result(result) + else: logits = self.lm_head(kept_hidden_states) if labels is not None: @@ -306,13 +316,16 @@ def multimodal_forward( if not return_dict: output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = (loss,) + output if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return Gemma3CausalLMOutputWithPast( + return LigerGemma3CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=outputs.image_hidden_states, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/glm4.py b/src/liger_kernel/transformers/model/glm4.py index 87994ace9..5ee9a0e3d 100644 --- a/src/liger_kernel/transformers/model/glm4.py +++ b/src/liger_kernel/transformers/model/glm4.py @@ -5,10 +5,11 @@ import torch -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils.deprecation import deprecate_kwarg from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @@ -28,7 +29,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -91,6 +92,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -99,8 +101,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -108,6 +111,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -120,10 +124,18 @@ def lce_forward( **kwargs, ) - return CausalLMOutputWithPast( + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/glm4v.py b/src/liger_kernel/transformers/model/glm4v.py index 1d7d1a97d..369451e03 100644 --- a/src/liger_kernel/transformers/model/glm4v.py +++ b/src/liger_kernel/transformers/model/glm4v.py @@ -5,10 +5,11 @@ import torch -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils.deprecation import deprecate_kwarg from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @@ -28,7 +29,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -113,6 +114,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -121,8 +123,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -130,6 +133,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -142,10 +146,18 @@ def lce_forward( **kwargs, ) - return CausalLMOutputWithPast( + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/glm4v_moe.py b/src/liger_kernel/transformers/model/glm4v_moe.py index 1177f0585..4a0889f5c 100644 --- a/src/liger_kernel/transformers/model/glm4v_moe.py +++ b/src/liger_kernel/transformers/model/glm4v_moe.py @@ -4,10 +4,11 @@ import torch -from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeCausalLMOutputWithPast from transformers.utils.deprecation import deprecate_kwarg from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerGlm4vMoeCausalLMOutputWithPast @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @@ -27,8 +28,9 @@ def lce_forward( cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs, -) -> Union[Tuple, Glm4vMoeCausalLMOutputWithPast]: +) -> Union[Tuple, LigerGlm4vMoeCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -90,6 +92,7 @@ def lce_forward( >>> output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False) ``` """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -114,6 +117,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -122,8 +126,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -131,6 +136,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -143,11 +149,20 @@ def lce_forward( **kwargs, ) - return Glm4vMoeCausalLMOutputWithPast( + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + # Return GLM4V MoE output with accuracy (using dict syntax to add extra field) + return LigerGlm4vMoeCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, rope_deltas=outputs.rope_deltas, + aux_loss=outputs.aux_loss, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/internvl.py b/src/liger_kernel/transformers/model/internvl.py index 6e472fbbe..2ddfd733d 100644 --- a/src/liger_kernel/transformers/model/internvl.py +++ b/src/liger_kernel/transformers/model/internvl.py @@ -5,10 +5,11 @@ import torch -from transformers.models.internvl.modeling_internvl import InternVLCausalLMOutputWithPast from transformers.utils import can_return_tuple from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerInternVLCausalLMOutputWithPast # Copied from https://github.com/huggingface/transformers/blob/d888bd435d0c0eaabaabad5b33d52af518c7187c/src/transformers/models/internvl/modeling_internvl.py#L862 @@ -33,7 +34,7 @@ def lce_forward( image_sizes: Optional[torch.Tensor] = None, skip_logits: Optional[bool] = None, # Added argument for liger-kernel **lm_kwargs, # renamed from kwargs -) -> Union[Tuple, InternVLCausalLMOutputWithPast]: +) -> Union[Tuple, LigerInternVLCausalLMOutputWithPast]: r""" Example: @@ -111,6 +112,7 @@ def lce_forward( shift_labels = lm_kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -120,7 +122,7 @@ def lce_forward( skip_logits = self.training and (labels is not None or shift_labels is not None) if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -128,6 +130,7 @@ def lce_forward( hidden_size=self.config.text_config.hidden_size, **lm_kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -138,13 +141,17 @@ def lce_forward( if not return_dict: output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = (loss,) + output if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return InternVLCausalLMOutputWithPast( + # Return custom output class with token_accuracy field + return LigerInternVLCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=outputs.image_hidden_states, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/llama.py b/src/liger_kernel/transformers/model/llama.py index 43c0078a3..1a6a2ea5a 100644 --- a/src/liger_kernel/transformers/model/llama.py +++ b/src/liger_kernel/transformers/model/llama.py @@ -15,6 +15,8 @@ from liger_kernel.transformers.fsdp import _FSDPForwardRedirection from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast from liger_kernel.utils import PEFT_AVAILABLE if TYPE_CHECKING: @@ -162,7 +164,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -228,6 +230,8 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None + # if in training mode, don't materialize logits if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -236,8 +240,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = lce_maybe_trainable_lm_head( + result = lce_maybe_trainable_lm_head( self, hidden_states=kept_hidden_states, hidden_size=self.config.hidden_size, @@ -245,7 +250,7 @@ def lce_forward( shift_labels=shift_labels, **kwargs, ) - + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) if labels is not None or shift_labels is not None: @@ -259,14 +264,18 @@ def lce_forward( if not return_dict: output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return CausalLMOutputWithPast( + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/llama4.py b/src/liger_kernel/transformers/model/llama4.py index 150f83f5f..2b630ae6e 100644 --- a/src/liger_kernel/transformers/model/llama4.py +++ b/src/liger_kernel/transformers/model/llama4.py @@ -6,9 +6,10 @@ import torch from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast def lce_forward( @@ -26,7 +27,7 @@ def lce_forward( cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -78,9 +79,11 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None + # Compute loss if self.training and (labels is not None or shift_labels is not None): - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -88,6 +91,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: # if in inference mode materialize logits logits = self.lm_head(kept_hidden_states) @@ -100,10 +104,18 @@ def lce_forward( **kwargs, ) - return CausalLMOutputWithPast( + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/llava.py b/src/liger_kernel/transformers/model/llava.py index f477a6243..a4453f3cb 100644 --- a/src/liger_kernel/transformers/model/llava.py +++ b/src/liger_kernel/transformers/model/llava.py @@ -11,6 +11,8 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerLlavaCausalLMOutputWithPast def lce_forward_deprecated( @@ -215,7 +217,7 @@ def lce_forward( image_sizes: torch.Tensor = None, skip_logits: Optional[bool] = None, **lm_kwargs, -) -> Union[Tuple, LlavaCausalLMOutputWithPast]: +) -> Union[Tuple, LigerLlavaCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -293,6 +295,7 @@ def lce_forward( shift_labels = lm_kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -302,7 +305,7 @@ def lce_forward( skip_logits = self.training and (labels is not None or shift_labels is not None) if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -310,6 +313,7 @@ def lce_forward( hidden_size=self.config.text_config.hidden_size, **lm_kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -324,13 +328,17 @@ def lce_forward( if not return_dict: output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = (loss,) + output if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return LlavaCausalLMOutputWithPast( + # Return custom output class with token_accuracy field + return LigerLlavaCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=outputs.image_hidden_states, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/loss_utils.py b/src/liger_kernel/transformers/model/loss_utils.py index 0e112ca72..f21294506 100644 --- a/src/liger_kernel/transformers/model/loss_utils.py +++ b/src/liger_kernel/transformers/model/loss_utils.py @@ -1,10 +1,28 @@ from typing import Optional +from typing import Tuple import torch import torch.nn as nn import liger_kernel.transformers.functional as F +from liger_kernel.transformers.functional import CrossEntropyOutput + + +def unpack_cross_entropy_result( + result, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + if isinstance(result, CrossEntropyOutput): + return result.loss, result.z_loss, result.token_accuracy + + if isinstance(result, tuple): + loss = result[0] + z_loss = result[1] if len(result) > 1 else None + token_accuracy = result[2] if len(result) > 2 else None + return loss, z_loss, token_accuracy + + return result, None, None + def fixed_fused_linear_cross_entropy( hidden_states: torch.Tensor, @@ -14,10 +32,11 @@ def fixed_fused_linear_cross_entropy( ignore_index: int = -100, final_logit_softcapping: Optional[float] = None, accum_dtype: Optional[torch.dtype] = None, + return_token_accuracy: bool = False, **kwargs, ): reduction = "sum" if num_items_in_batch is not None else "mean" - loss = F.liger_fused_linear_cross_entropy( + result = F.liger_fused_linear_cross_entropy( hidden_states, lm_head_weight, target, @@ -25,11 +44,18 @@ def fixed_fused_linear_cross_entropy( ignore_index=ignore_index, softcap=final_logit_softcapping, accum_dtype=accum_dtype, + return_token_accuracy=return_token_accuracy, **kwargs, ) + + loss, _, token_accuracy = unpack_cross_entropy_result(result) + if reduction == "sum": loss = loss / num_items_in_batch + if return_token_accuracy: + return CrossEntropyOutput(loss=loss, token_accuracy=token_accuracy) + return loss @@ -42,6 +68,7 @@ def LigerForCausalLMLoss( ignore_index: int = -100, shift_labels: Optional[torch.Tensor] = None, final_logit_softcapping: Optional[float] = None, + return_token_accuracy: bool = False, **kwargs, ): # Skip upcast since intermediate values for the loss are all fp32 in kernel @@ -55,13 +82,14 @@ def LigerForCausalLMLoss( shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(hidden_states.device) - loss = fixed_fused_linear_cross_entropy( + result = fixed_fused_linear_cross_entropy( hidden_states, lm_head_weight, shift_labels, num_items_in_batch, ignore_index, final_logit_softcapping, + return_token_accuracy=return_token_accuracy, **kwargs, ) - return loss + return result diff --git a/src/liger_kernel/transformers/model/mistral.py b/src/liger_kernel/transformers/model/mistral.py index b4a600dca..a6395da5d 100644 --- a/src/liger_kernel/transformers/model/mistral.py +++ b/src/liger_kernel/transformers/model/mistral.py @@ -6,10 +6,11 @@ import torch from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils.deprecation import deprecate_kwarg from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @@ -29,7 +30,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy @@ -94,6 +95,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) loss = None logits = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -101,8 +103,9 @@ def lce_forward( if skip_logits is None: skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -110,6 +113,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -123,14 +127,19 @@ def lce_forward( vocab_size=self.config.vocab_size, **kwargs, ) + if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output_tuple = (logits,) + outputs[1:] + output = (loss,) + output_tuple if loss is not None else output_tuple + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return CausalLMOutputWithPast( + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/mixtral.py b/src/liger_kernel/transformers/model/mixtral.py index 686fc456f..9240fb36b 100644 --- a/src/liger_kernel/transformers/model/mixtral.py +++ b/src/liger_kernel/transformers/model/mixtral.py @@ -12,6 +12,8 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast def lce_forward_deprecated( @@ -158,7 +160,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, MoeCausalLMOutputWithPast]: +) -> Union[Tuple, LigerMoeCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -226,6 +228,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -234,8 +237,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -243,6 +247,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -268,17 +273,21 @@ def lce_forward( loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device if not return_dict: - output = (logits,) + outputs[1:] + output_tuple = (logits,) + outputs[1:] if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output + output_tuple = (aux_loss,) + output_tuple + if token_accuracy is not None: + output_tuple = output_tuple + (token_accuracy,) + return (loss,) + output_tuple if loss is not None else output_tuple - return MoeCausalLMOutputWithPast( + # Return custom output class with token_accuracy field + return LigerMoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - router_logits=outputs.router_logits, + router_logits=outputs.router_logits if return_dict else outputs[-1], + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/mllama.py b/src/liger_kernel/transformers/model/mllama.py index 3a6b99582..3dd4c2f28 100644 --- a/src/liger_kernel/transformers/model/mllama.py +++ b/src/liger_kernel/transformers/model/mllama.py @@ -12,6 +12,8 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast def lce_forward_deprecated( @@ -149,7 +151,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -219,6 +221,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -228,7 +231,7 @@ def lce_forward( skip_logits = self.training and (labels is not None or shift_labels is not None) if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -236,6 +239,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -250,12 +254,16 @@ def lce_forward( if not return_dict: output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = (loss,) + output if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return CausalLMOutputWithPast( + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/olmo2.py b/src/liger_kernel/transformers/model/olmo2.py index d9705c5f8..fee0d46df 100644 --- a/src/liger_kernel/transformers/model/olmo2.py +++ b/src/liger_kernel/transformers/model/olmo2.py @@ -5,10 +5,11 @@ import torch -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils.deprecation import deprecate_kwarg from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @@ -28,7 +29,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -91,6 +92,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -99,8 +101,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -108,6 +111,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -120,10 +124,18 @@ def lce_forward( **kwargs, ) - return CausalLMOutputWithPast( + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/output_classes.py b/src/liger_kernel/transformers/model/output_classes.py new file mode 100644 index 000000000..d65e8ebe7 --- /dev/null +++ b/src/liger_kernel/transformers/model/output_classes.py @@ -0,0 +1,147 @@ +""" +Custom output classes for Liger-Kernel that extend transformers' ModelOutput classes +with optional token accuracy field. +""" + +from dataclasses import dataclass +from typing import Optional + +import torch + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_outputs import MoeCausalLMOutputWithPast + +# The following model-specific outputs are optional and depend on the installed +# transformers version. Guard their imports so our module remains importable +# even when those models are not available in the environment. +try: + from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast as _Gemma3CausalLMOutputWithPast +except Exception: + _Gemma3CausalLMOutputWithPast = None + +try: + from transformers.models.glm4v_moe.modeling_glm4v_moe import ( + Glm4vMoeCausalLMOutputWithPast as _Glm4vMoeCausalLMOutputWithPast, + ) +except Exception: + _Glm4vMoeCausalLMOutputWithPast = None + +try: + from transformers.models.internvl.modeling_internvl import ( + InternVLCausalLMOutputWithPast as _InternVLCausalLMOutputWithPast, + ) +except Exception: + _InternVLCausalLMOutputWithPast = None + +try: + from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast as _LlavaCausalLMOutputWithPast +except Exception: + _LlavaCausalLMOutputWithPast = None + +try: + from transformers.models.paligemma.modeling_paligemma import ( + PaliGemmaCausalLMOutputWithPast as _PaliGemmaCausalLMOutputWithPast, + ) +except Exception: + _PaliGemmaCausalLMOutputWithPast = None + +try: + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLCausalLMOutputWithPast as _Qwen2_5_VLCausalLMOutputWithPast, + ) +except Exception: + _Qwen2_5_VLCausalLMOutputWithPast = None + +try: + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLCausalLMOutputWithPast as _Qwen2VLCausalLMOutputWithPast, + ) +except Exception: + _Qwen2VLCausalLMOutputWithPast = None + +try: + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLCausalLMOutputWithPast as _Qwen3VLCausalLMOutputWithPast, + ) +except Exception: + _Qwen3VLCausalLMOutputWithPast = None + +try: + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeCausalLMOutputWithPast as _Qwen3VLMoeCausalLMOutputWithPast, + ) +except Exception: + _Qwen3VLMoeCausalLMOutputWithPast = None + + +@dataclass +class LigerCausalLMOutputWithPast(CausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + + +@dataclass +class LigerMoeCausalLMOutputWithPast(MoeCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + + +if _Gemma3CausalLMOutputWithPast is not None: + + @dataclass + class LigerGemma3CausalLMOutputWithPast(_Gemma3CausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + + +if _Glm4vMoeCausalLMOutputWithPast is not None: + + @dataclass + class LigerGlm4vMoeCausalLMOutputWithPast(_Glm4vMoeCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + + +if _LlavaCausalLMOutputWithPast is not None: + + @dataclass + class LigerLlavaCausalLMOutputWithPast(_LlavaCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + + +if _InternVLCausalLMOutputWithPast is not None: + + @dataclass + class LigerInternVLCausalLMOutputWithPast(_InternVLCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + + +if _PaliGemmaCausalLMOutputWithPast is not None: + + @dataclass + class LigerPaliGemmaCausalLMOutputWithPast(_PaliGemmaCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + + +if _Qwen2_5_VLCausalLMOutputWithPast is not None: + + @dataclass + class LigerQwen2_5_VLCausalLMOutputWithPast(_Qwen2_5_VLCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + + +if _Qwen2VLCausalLMOutputWithPast is not None: + + @dataclass + class LigerQwen2VLCausalLMOutputWithPast(_Qwen2VLCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + + +if _Qwen3VLCausalLMOutputWithPast is not None: + + @dataclass + class LigerQwen3VLCausalLMOutputWithPast(_Qwen3VLCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + + +if _Qwen3VLMoeCausalLMOutputWithPast is not None: + + @dataclass + class LigerQwen3VLMoeCausalLMOutputWithPast(_Qwen3VLMoeCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None diff --git a/src/liger_kernel/transformers/model/paligemma.py b/src/liger_kernel/transformers/model/paligemma.py index 12c4a7dfe..4f1afed43 100644 --- a/src/liger_kernel/transformers/model/paligemma.py +++ b/src/liger_kernel/transformers/model/paligemma.py @@ -13,6 +13,9 @@ from transformers.utils.deprecation import deprecate_kwarg from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerPaliGemmaCausalLMOutputWithPast logger = logging.get_logger(__name__) @@ -218,7 +221,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **lm_kwargs, -) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]: +) -> Union[Tuple, LigerPaliGemmaCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -331,6 +334,7 @@ def lce_forward( loss = None logits = None + token_accuracy = None if skip_logits and labels is None: raise ValueError("skip_logits is True, but labels is None") @@ -358,8 +362,16 @@ def lce_forward( shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size) shift_labels = shift_labels.view(-1).to(hidden_device) - lce = LigerFusedLinearCrossEntropyLoss() - loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels) + # Use LigerForCausalLMLoss with accuracy support and pass already shifted labels + result = LigerForCausalLMLoss( + hidden_states=shift_hidden_states, + lm_head_weight=self.language_model.lm_head.weight, + labels=None, + shift_labels=shift_labels, + hidden_size=self.config.text_config.hidden_size, + **lm_kwargs, + ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.language_model.lm_head(hidden_states) if labels is not None: @@ -401,15 +413,20 @@ def lce_forward( flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) flat_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(flat_logits, flat_labels) + if not return_dict: output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = (loss,) + output if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return PaliGemmaCausalLMOutputWithPast( + # Return PaliGemma output with token_accuracy field + return LigerPaliGemmaCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/phi3.py b/src/liger_kernel/transformers/model/phi3.py index c1f54b382..341f45214 100644 --- a/src/liger_kernel/transformers/model/phi3.py +++ b/src/liger_kernel/transformers/model/phi3.py @@ -6,9 +6,10 @@ import torch from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.modeling_outputs import CausalLMOutputWithPast from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast def lce_forward( @@ -27,7 +28,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Example: @@ -71,6 +72,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -79,8 +81,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -88,7 +91,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) - + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) if labels is not None or shift_labels is not None: @@ -101,13 +104,17 @@ def lce_forward( ) if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output_tuple = (logits,) + outputs[1:] + output = (loss,) + output_tuple if loss is not None else output_tuple + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return CausalLMOutputWithPast( + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/qwen2.py b/src/liger_kernel/transformers/model/qwen2.py index f55091632..0bf8d8c29 100644 --- a/src/liger_kernel/transformers/model/qwen2.py +++ b/src/liger_kernel/transformers/model/qwen2.py @@ -11,6 +11,8 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast def lce_forward_deprecated( @@ -145,7 +147,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -208,6 +210,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -216,8 +219,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -225,6 +229,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -237,10 +242,18 @@ def lce_forward( **kwargs, ) - return CausalLMOutputWithPast( + if not return_dict: + output_tuple = (logits,) + outputs[1:] + output = (loss,) + output_tuple if loss is not None else output_tuple + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + # Return custom output class with token accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/qwen2_5_vl.py b/src/liger_kernel/transformers/model/qwen2_5_vl.py index 2c91271b7..b0d816ea9 100644 --- a/src/liger_kernel/transformers/model/qwen2_5_vl.py +++ b/src/liger_kernel/transformers/model/qwen2_5_vl.py @@ -5,10 +5,11 @@ import torch -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast from transformers.utils import can_return_tuple from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerQwen2_5_VLCausalLMOutputWithPast @can_return_tuple @@ -33,7 +34,7 @@ def lce_forward( second_per_grid_ts: Optional[torch.Tensor] = None, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: +) -> Union[Tuple, LigerQwen2_5_VLCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -113,6 +114,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) loss = None logits = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -120,8 +122,9 @@ def lce_forward( if skip_logits is None: skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -129,6 +132,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(hidden_states) @@ -142,14 +146,18 @@ def lce_forward( ) if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output_tuple = (logits,) + outputs[1:] + output = (loss,) + output_tuple if loss is not None else output_tuple + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return Qwen2_5_VLCausalLMOutputWithPast( + # Return Qwen2.5-VL output with token accuracy + return LigerQwen2_5_VLCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, rope_deltas=outputs.rope_deltas, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/qwen2_vl.py b/src/liger_kernel/transformers/model/qwen2_vl.py index b6b6653c5..b290d349a 100644 --- a/src/liger_kernel/transformers/model/qwen2_vl.py +++ b/src/liger_kernel/transformers/model/qwen2_vl.py @@ -5,10 +5,11 @@ import torch -from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast from transformers.utils import can_return_tuple from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerQwen2VLCausalLMOutputWithPast @can_return_tuple @@ -32,7 +33,7 @@ def lce_forward( cache_position: Optional[torch.LongTensor] = None, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: +) -> Union[Tuple, LigerQwen2VLCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -109,6 +110,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) loss = None logits = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -116,8 +118,9 @@ def lce_forward( if skip_logits is None: skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -125,6 +128,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(hidden_states) @@ -137,11 +141,19 @@ def lce_forward( vocab_size=self.config.vocab_size, ) - return Qwen2VLCausalLMOutputWithPast( + if not return_dict: + output_tuple = (logits,) + outputs[1:] + output = (loss,) + output_tuple if loss is not None else output_tuple + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + # Return Qwen2VL output with token accuracy + return LigerQwen2VLCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, rope_deltas=outputs.rope_deltas, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/qwen3.py b/src/liger_kernel/transformers/model/qwen3.py index 348a91380..2a68e159b 100644 --- a/src/liger_kernel/transformers/model/qwen3.py +++ b/src/liger_kernel/transformers/model/qwen3.py @@ -4,9 +4,9 @@ import torch -from transformers.modeling_outputs import CausalLMOutputWithPast - from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast def lce_forward( @@ -23,8 +23,9 @@ def lce_forward( cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs, -) -> CausalLMOutputWithPast: +) -> LigerCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -60,6 +61,7 @@ def lce_forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -83,6 +85,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -91,8 +94,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -100,6 +104,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -112,10 +117,18 @@ def lce_forward( **kwargs, ) - return CausalLMOutputWithPast( + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + # Return custom output class with accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/qwen3_moe.py b/src/liger_kernel/transformers/model/qwen3_moe.py index 133d65b4d..4ceffa6c5 100644 --- a/src/liger_kernel/transformers/model/qwen3_moe.py +++ b/src/liger_kernel/transformers/model/qwen3_moe.py @@ -4,11 +4,12 @@ import torch -from transformers.modeling_outputs import MoeCausalLMOutputWithPast from transformers.modeling_outputs import MoeModelOutputWithPast from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast def lce_forward( @@ -26,8 +27,9 @@ def lce_forward( cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs, -) -> MoeCausalLMOutputWithPast: +) -> LigerMoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -64,10 +66,10 @@ def lce_forward( output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) - output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: MoeModelOutputWithPast = self.model( @@ -92,12 +94,14 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits is None: skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -105,6 +109,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: # if in inference model materialize logits logits = self.lm_head(kept_hidden_states) if labels is not None or shift_labels is not None: @@ -127,7 +132,15 @@ def lce_forward( if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - return MoeCausalLMOutputWithPast( + if not return_dict: + output = (logits,) + outputs[1:] + output = ((aux_loss,) + output) if aux_loss is not None else output + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + # Return custom output class with accuracy field + return LigerMoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, @@ -135,4 +148,5 @@ def lce_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/qwen3_next.py b/src/liger_kernel/transformers/model/qwen3_next.py index 0f20f8478..1d7666703 100644 --- a/src/liger_kernel/transformers/model/qwen3_next.py +++ b/src/liger_kernel/transformers/model/qwen3_next.py @@ -5,13 +5,14 @@ import torch -from transformers.modeling_outputs import MoeCausalLMOutputWithPast from transformers.modeling_outputs import MoeModelOutputWithPast if TYPE_CHECKING: from transformers.models.qwen3_next.modeling_qwen3_next import load_balancing_loss_func from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast def lce_forward( @@ -29,8 +30,9 @@ def lce_forward( cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs, -) -> MoeCausalLMOutputWithPast: +) -> LigerMoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -66,10 +68,10 @@ def lce_forward( output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) - output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: MoeModelOutputWithPast = self.model( @@ -94,12 +96,13 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits is None: skip_logits = self.training and (labels is not None or shift_labels is not None) if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -107,6 +110,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: # if in inference model materialize logits logits = self.lm_head(kept_hidden_states) if labels is not None or shift_labels is not None: @@ -123,7 +127,14 @@ def lce_forward( if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - return MoeCausalLMOutputWithPast( + if not return_dict: + output = (logits,) + outputs[1:] + output = ((aux_loss,) + output) if aux_loss is not None else output + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + return LigerMoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, @@ -131,4 +142,5 @@ def lce_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/qwen3_vl.py b/src/liger_kernel/transformers/model/qwen3_vl.py index 62a554555..84eeea2b7 100644 --- a/src/liger_kernel/transformers/model/qwen3_vl.py +++ b/src/liger_kernel/transformers/model/qwen3_vl.py @@ -5,10 +5,11 @@ import torch -from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLCausalLMOutputWithPast from transformers.utils import can_return_tuple from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerQwen3VLCausalLMOutputWithPast @can_return_tuple @@ -33,7 +34,7 @@ def lce_forward( second_per_grid_ts: Optional[torch.Tensor] = None, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, Qwen3VLCausalLMOutputWithPast]: +) -> Union[Tuple, LigerQwen3VLCausalLMOutputWithPast]: """ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -107,6 +108,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) loss = None logits = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -115,7 +117,7 @@ def lce_forward( skip_logits = self.training and (labels is not None or shift_labels is not None) if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -123,6 +125,7 @@ def lce_forward( hidden_size=self.config.text_config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(hidden_states) @@ -132,13 +135,16 @@ def lce_forward( if not return_dict: output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = (loss,) + output if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return Qwen3VLCausalLMOutputWithPast( + return LigerQwen3VLCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, rope_deltas=outputs.rope_deltas, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/qwen3_vl_moe.py b/src/liger_kernel/transformers/model/qwen3_vl_moe.py index 4a82131f6..3e8423d09 100644 --- a/src/liger_kernel/transformers/model/qwen3_vl_moe.py +++ b/src/liger_kernel/transformers/model/qwen3_vl_moe.py @@ -5,11 +5,12 @@ import torch -from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeCausalLMOutputWithPast from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import load_balancing_loss_func from transformers.utils import can_return_tuple from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerQwen3VLMoeCausalLMOutputWithPast @can_return_tuple @@ -34,7 +35,7 @@ def lce_forward( second_per_grid_ts: Optional[torch.Tensor] = None, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, Qwen3VLMoeCausalLMOutputWithPast]: +) -> Union[Tuple, LigerQwen3VLMoeCausalLMOutputWithPast]: """ Qwen3-VL-MoE forward with fused linear cross entropy support mirroring Qwen3-VL behaviour. """ @@ -69,6 +70,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) loss = None logits = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -77,7 +79,7 @@ def lce_forward( skip_logits = self.training and (labels is not None or shift_labels is not None) if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -85,6 +87,7 @@ def lce_forward( hidden_size=self.config.text_config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(hidden_states) @@ -106,9 +109,12 @@ def lce_forward( if not return_dict: output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = (loss,) + output if loss is not None else output + output = output + (aux_loss,) if aux_loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return Qwen3VLMoeCausalLMOutputWithPast( + return LigerQwen3VLMoeCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, @@ -116,4 +122,5 @@ def lce_forward( attentions=outputs.attentions, rope_deltas=outputs.rope_deltas, aux_loss=aux_loss, + token_accuracy=token_accuracy, ) diff --git a/src/liger_kernel/transformers/model/smollm3.py b/src/liger_kernel/transformers/model/smollm3.py index 94bc63086..8d4dcec5b 100644 --- a/src/liger_kernel/transformers/model/smollm3.py +++ b/src/liger_kernel/transformers/model/smollm3.py @@ -7,11 +7,12 @@ import torch from torch.distributed.fsdp import FullyShardedDataParallel -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils.deprecation import deprecate_kwarg from liger_kernel.transformers.fsdp import _FSDPForwardRedirection from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast from liger_kernel.utils import PEFT_AVAILABLE if TYPE_CHECKING: @@ -38,7 +39,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -101,6 +102,8 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None + # if in training mode, don't materialize logits if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -109,8 +112,9 @@ def lce_forward( # By default, if in training mode, don't materialize logits skip_logits = self.training and (labels is not None or shift_labels is not None) + # Compute loss if skip_logits: - loss = lce_maybe_trainable_lm_head( + result = lce_maybe_trainable_lm_head( self, hidden_states=kept_hidden_states, hidden_size=self.config.hidden_size, @@ -118,6 +122,7 @@ def lce_forward( shift_labels=shift_labels, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -131,15 +136,19 @@ def lce_forward( ) if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output_tuple = (logits,) + outputs[1:] + output = (loss,) + output_tuple if loss is not None else output_tuple + output = output + (token_accuracy,) if token_accuracy is not None else output + return output - return CausalLMOutputWithPast( + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + token_accuracy=token_accuracy, ) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 5a98bb1f1..e07689770 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -11,6 +11,7 @@ from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel from liger_kernel.ops.utils import is_hip from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.transformers.functional import CrossEntropyOutput from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.utils import infer_device @@ -217,8 +218,12 @@ def _test_correctness_with_z_loss_once( target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) if return_z_loss: output, z_output = torch_ce(_input, target) - output2, z_output2 = target_ce(_input2, target) - + result2 = target_ce(_input2, target) + if isinstance(result2, CrossEntropyOutput): + output2 = result2.loss + z_output2 = result2.z_loss + else: + output2, z_output2 = result2 else: output = torch_ce(_input, target) output2 = target_ce(_input2, target) @@ -274,8 +279,9 @@ def _test_correctness_with_z_loss_with_other_params_once( if return_z_loss: output, z_output = torch_ce(_input, target) - output2, z_output2 = target_ce(_input2, target) - + result2 = target_ce(_input2, target) + output2 = result2.loss + z_output2 = result2.z_loss else: output = torch_ce(_input, target) output2 = target_ce(_input2, target) @@ -493,7 +499,7 @@ def _test_correctness_functional( target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) - y1, y1_z = liger_cross_entropy( + result = liger_cross_entropy( x1, target, None, @@ -504,7 +510,9 @@ def _test_correctness_functional( softcap=30.0, return_z_loss=True, ) - y2, y2_z = LigerCrossEntropyFunction.apply(x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True) + y1 = result.loss + y1_z = result.z_loss + y2, y2_z, _ = LigerCrossEntropyFunction.apply(x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True, False) assert torch.allclose(y1, y2, atol=atol, rtol=rtol) assert torch.allclose(y1_z, y2_z, atol=atol, rtol=rtol) @@ -1015,6 +1023,7 @@ def test_float32_internal(): # Run kernel for bfloat16 X_bf16 = X_init.clone() loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device=device) + token_accuracy_bf16 = torch.zeros(batch_size, dtype=torch.float32, device=device) liger_cross_entropy_kernel[(batch_size,)]( X_ptr=X_bf16, X_stride=X_bf16.stride(-2), @@ -1024,6 +1033,8 @@ def test_float32_internal(): z_loss_ptr=loss_bf16, # dummy ptr, not used loss_ptr=loss_bf16, loss_stride=loss_bf16.stride(-1), + token_accuracy_ptr=token_accuracy_bf16, + token_accuracy_stride=token_accuracy_bf16.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, sum_non_ignore_weight=n_non_ignore, # not used @@ -1034,6 +1045,7 @@ def test_float32_internal(): reduction=reduction, softcap=softcap, RETURN_Z_LOSS=0, # False + RETURN_TOKEN_ACCURACY=0, HAS_WEIGHT=False, HAS_SOFTCAPPING=False, HAS_GRADIENTS=True, @@ -1044,6 +1056,7 @@ def test_float32_internal(): # Run kernel for float32 X_fp32 = X_init.float() loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device=device) + token_accuracy_fp32 = torch.zeros(batch_size, dtype=torch.float32, device=device) liger_cross_entropy_kernel[(batch_size,)]( X_ptr=X_fp32, X_stride=X_fp32.stride(-2), @@ -1053,6 +1066,8 @@ def test_float32_internal(): loss_ptr=loss_fp32, z_loss_ptr=loss_fp32, # dummy ptr, not used loss_stride=loss_fp32.stride(-1), + token_accuracy_ptr=token_accuracy_fp32, + token_accuracy_stride=token_accuracy_fp32.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, sum_non_ignore_weight=n_non_ignore, # not used @@ -1063,6 +1078,7 @@ def test_float32_internal(): reduction=reduction, softcap=softcap, RETURN_Z_LOSS=0, # False + RETURN_TOKEN_ACCURACY=0, HAS_WEIGHT=False, HAS_SOFTCAPPING=False, HAS_GRADIENTS=True, @@ -1106,3 +1122,62 @@ def test_correctness_with_out_of_bounds_target_once(B, T, V, ignore_index): def test_correctness_with_forward_only(B, T, V, ignore_index, reduction, dtype, scalar, atol, rtol): liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) _test_correctness_with_forward_only(liger_ce, B, T, V, reduction, dtype, scalar, atol, rtol) + + +@pytest.mark.parametrize( + "return_z_loss, return_token_accuracy", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +def test_liger_cross_entropy_structured_output(return_z_loss, return_token_accuracy): + logits = torch.tensor( + [[2.0, 0.5, -1.0], [0.1, 1.5, 0.3], [0.7, -0.2, 0.9]], + device=device, + requires_grad=True, + ) + targets = torch.tensor([0, 1, 2], device=device) + + original_logits = logits.detach().clone() + + result = liger_cross_entropy( + logits, + targets, + reduction="mean", + return_z_loss=return_z_loss, + return_token_accuracy=return_token_accuracy, + ) + + if not return_z_loss and not return_token_accuracy: + assert isinstance(result, torch.Tensor) + assert result.shape == () + result.backward() + assert logits.grad is not None + logits.grad.zero_() + return + + assert isinstance(result, CrossEntropyOutput) + assert result.loss.shape == () + + if return_z_loss: + assert result.z_loss is not None + assert isinstance(result.z_loss, torch.Tensor) + else: + assert result.z_loss is None + + if return_token_accuracy: + assert result.token_accuracy is not None + with torch.no_grad(): + predictions = original_logits.argmax(dim=-1) + correct = (predictions == targets).float() + expected_accuracy = correct.mean() + assert torch.allclose(result.token_accuracy, expected_accuracy, atol=1e-6) + else: + assert result.token_accuracy is None + + result.loss.backward() + assert logits.grad is not None + logits.grad.zero_() diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index 3d409e745..d9dd41e61 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -8,6 +8,7 @@ from test.utils import set_seed from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction +from liger_kernel.transformers.functional import CrossEntropyOutput from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss from liger_kernel.utils import infer_device @@ -200,7 +201,10 @@ def test_correctness( if return_z_loss: output1, z_output1 = torch_lm_head_ce(_input1, target) - output2, z_output2 = liger_lm_head_ce(_input2, target) + result2 = liger_lm_head_ce(_input2, target) + assert isinstance(result2, CrossEntropyOutput) + output2 = result2.loss + z_output2 = result2.z_loss else: output1 = torch_lm_head_ce(_input1, target) output2 = liger_lm_head_ce(_input2, target) @@ -328,7 +332,10 @@ def test_correctness_with_forward_only( with torch.no_grad(): if return_z_loss: output1, z_output1 = torch_lm_head_ce(_input1, target) - output2, z_output2 = liger_lm_head_ce(_input2, target) + result2 = liger_lm_head_ce(_input2, target) + assert isinstance(result2, CrossEntropyOutput) + output2 = result2.loss + z_output2 = result2.z_loss else: output1 = torch_lm_head_ce(_input1, target) output2 = liger_lm_head_ce(_input2, target) @@ -372,7 +379,7 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, ce_weight, atol bias = torch.randn(V, device=device, dtype=dtype) if bias else None ce_weight = torch.randn(V, device=device) if ce_weight else None - y1, z1 = liger_fused_linear_cross_entropy( + result = liger_fused_linear_cross_entropy( input=x1, weight=weight, target=target, @@ -386,8 +393,14 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, ce_weight, atol return_z_loss=True, accum_dtype=torch.float32, ) - y2, z2 = LigerFusedLinearCrossEntropyFunction.apply( - x2, weight, target, bias, ce_weight, -100, 1e-4, 0.1, "mean", 30.0, True, torch.float32 + if isinstance(result, CrossEntropyOutput): + y1 = result.loss + z1 = result.z_loss + else: + y1, z1 = result + + y2, z2, _ = LigerFusedLinearCrossEntropyFunction.apply( + x2, weight, target, bias, ce_weight, -100, 1e-4, 0.1, "mean", 30.0, True, torch.float32, False, False ) assert torch.allclose(y1, y2, atol=atol, rtol=rtol) @@ -401,6 +414,130 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, ce_weight, atol assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (4, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "reduction, scalar, dtype, atol, rtol", + [ + ("mean", 1.0, torch.bfloat16, 5e-3, 5e-2), + ("mean", 1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("return_token_accuracy", [True, False]) +def test_correctness_with_token_accuracy( + B, + T, + H, + V, + scalar, + dtype, + bias, + return_token_accuracy, + reduction, + atol, + rtol, +): + """Test that return_token_accuracy flag works correctly.""" + torch_lm_head_ce = TorchLMHeadCE( + H=H, + V=V, + bias=bias, + reduction=reduction, + dtype=dtype, + ).to(device) + liger_lm_head_ce = LigerLMHeadCE( + H=H, + V=V, + bias=bias, + reduction=reduction, + dtype=dtype, + ).to(device) + + # init the linear in all CEs with the same weights + torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand(V, H, device=device, dtype=dtype) + + if bias: + torch_lm_head_ce.lin.bias.data = liger_lm_head_ce.lin.bias.data = torch.rand(V, device=device, dtype=dtype) + + _tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target[indices_to_assign] = -100 + + # Compute with torch (baseline - only loss) + output1 = torch_lm_head_ce(_input1, target) + + # Compute with liger using functional API with return_token_accuracy + result = liger_fused_linear_cross_entropy( + input=_input2, + weight=liger_lm_head_ce.lin.weight, + target=target, + bias=liger_lm_head_ce.lin.bias if bias else None, + ignore_index=-100, + reduction=reduction, + return_token_accuracy=return_token_accuracy, + ) + + if return_token_accuracy: + # Should return structured output with token_accuracy populated + assert isinstance(result, CrossEntropyOutput), "Expected CrossEntropyOutput when return_token_accuracy=True" + output2 = result.loss + token_accuracy = result.token_accuracy + assert token_accuracy is not None, "token_accuracy should not be None" + + # Verify token_accuracy is computed correctly + with torch.no_grad(): + # Compute expected accuracy + logits = _input2 @ liger_lm_head_ce.lin.weight.t() + if bias: + logits = logits + liger_lm_head_ce.lin.bias + predictions = torch.argmax(logits, dim=-1) + mask = target != -100 + correct = (predictions == target) & mask + expected_accuracy = correct.sum().float() / mask.sum().float() + + assert_verbose_allclose(token_accuracy, expected_accuracy, atol=atol, rtol=rtol) + else: + # Should return only loss + output2 = result + assert not isinstance(result, CrossEntropyOutput), "Expected scalar loss when return_token_accuracy=False" + + # Loss should match regardless of return_token_accuracy flag + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + grad_output = torch.ones_like(output1) + output1.backward(gradient=grad_output) + output2.backward(gradient=grad_output) + + assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + + assert_verbose_allclose( + torch_lm_head_ce.lin.weight.grad, + liger_lm_head_ce.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + if bias: + assert_verbose_allclose( + torch_lm_head_ce.lin.bias.grad, + liger_lm_head_ce.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + @pytest.mark.parametrize( "B, T, H, V", [ @@ -693,6 +830,93 @@ def test_correctness_token_scaling_module(): assert torch.allclose(x1.grad, x2.grad, atol=1e-5, rtol=1e-5) +@pytest.mark.parametrize( + "return_z_loss, return_token_accuracy", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +def test_liger_fused_linear_cross_entropy_structured_output(return_z_loss, return_token_accuracy): + hidden_states = torch.tensor( + [[0.2, -0.1], [1.0, 0.5], [-0.3, 0.7]], + device=device, + dtype=torch.float32, + requires_grad=True, + ) + weight = torch.tensor( + [[0.5, -0.4], [-0.2, 0.3], [0.1, 0.6]], + device=device, + dtype=torch.float32, + ) + bias = torch.tensor([0.1, -0.2, 0.05], device=device, dtype=torch.float32) + targets = torch.tensor([0, 1, 2], device=device) + + result = liger_fused_linear_cross_entropy( + input=hidden_states, + weight=weight, + target=targets, + bias=bias, + return_z_loss=return_z_loss, + return_token_accuracy=return_token_accuracy, + ) + + logits = hidden_states @ weight.t() + bias + expected_loss = torch.nn.functional.cross_entropy(logits, targets) + + if not return_z_loss and not return_token_accuracy: + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected_loss, atol=1e-6) + result.backward() + assert hidden_states.grad is not None + hidden_states.grad.zero_() + else: + assert isinstance(result, CrossEntropyOutput) + assert torch.allclose(result.loss, expected_loss, atol=1e-6) + + if return_z_loss: + assert result.z_loss is not None + else: + assert result.z_loss is None + + if return_token_accuracy: + assert result.token_accuracy is not None + with torch.no_grad(): + predictions = logits.argmax(dim=-1) + expected_accuracy = (predictions == targets).float().mean() + assert torch.allclose(result.token_accuracy, expected_accuracy, atol=1e-6) + else: + assert result.token_accuracy is None + + result.loss.backward() + assert hidden_states.grad is not None + hidden_states.grad.zero_() + + module = LigerFusedLinearCrossEntropyLoss( + return_z_loss=return_z_loss, + return_token_accuracy=return_token_accuracy, + ) + + module_result = module(weight, hidden_states, targets, bias) + + if not return_z_loss and not return_token_accuracy: + assert isinstance(module_result, torch.Tensor) + assert torch.allclose(module_result, expected_loss, atol=1e-6) + else: + assert isinstance(module_result, CrossEntropyOutput) + assert torch.allclose(module_result.loss, expected_loss, atol=1e-6) + if return_z_loss: + assert module_result.z_loss is not None + else: + assert module_result.z_loss is None + if return_token_accuracy: + assert module_result.token_accuracy is not None + else: + assert module_result.token_accuracy is None + + def test_token_scaling_with_ignore_index(): """Test token scaling when some targets have ignore_index values.""" B, T, H, V = 2, 4, 8, 1000