diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index cdd5b1cf12..d92229314f 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -438,6 +438,7 @@ def fast_cross_entropy_loss( batch, seq_len, d = logits.shape assert labels.shape == (batch, seq_len) + device = logits.device loss = Fast_CrossEntropyLoss.apply( logits.view(batch * seq_len, d), labels.view(-1), @@ -446,6 +447,8 @@ def fast_cross_entropy_loss( ) if n_items is None: n_items = torch.count_nonzero(labels != -100) + if torch.is_tensor(n_items): + n_items = n_items.to(device) return loss.sum() / n_items