diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 1bb1691776..d92229314f 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -22,7 +22,6 @@ triton_cast, torch_gpu_device, is_cdna, - is_rdna, ) from transformers.models.llama.modeling_llama import logger from unsloth_zoo.utils import Version @@ -365,7 +364,7 @@ def forward( SOFTCAP = logit_softcapping, DO_LOGIT_SCALING = DO_LOGIT_SCALING, LOGIT_SCALE = logit_scaling, - num_warps = 16 if is_cdna() or is_rdna() else 32, + num_warps = 32 if not is_cdna() else 16, ) # logsumexp(chunked_logsumexp) - x # Do the -x separately