diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index b1d94adfe..0122a0a76 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -25,7 +25,8 @@ def __init__( assert reduction in { "mean", "sum", - }, f"reduction must be 'mean' or 'sum'. Got: {reduction}" + "none", + }, f"reduction must be 'mean' or 'sum' or 'none'. Got: {reduction}" assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}" self.ce_weight = ce_weight self.ignore_index = ignore_index