diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py
index a032e0f7fc..fcc9cb923b 100644
--- a/unsloth/kernels/rope_embedding.py
+++ b/unsloth/kernels/rope_embedding.py
@@ -312,8 +312,8 @@ def forward(ctx, Q, K, cos, sin, rope_indices):
         _, n_heads_K, _, _ = K.shape

         # Inplace rotary embedding is generally fine
-        Q_out = Q.clone() if not Q.is_contiguous else Q
-        K_out = K.clone() if not K.is_contiguous else K
+        Q_out = Q.clone() if not Q.is_contiguous() else Q
+        K_out = K.clone() if not K.is_contiguous() else K

         if has_indices:
             # TRL's rotary indices are always in int32, so casting is just for safety
@@ -383,21 +383,21 @@ def backward(ctx, dQ, dK):
             else ctx.cos.new_empty(1, dtype = torch.int32)
         )

+        # Inplace rotary embedding is generally fine
+        dQ_out = dQ.clone() if not dQ.is_contiguous() else dQ
+        dK_out = dK.clone() if not dK.is_contiguous() else dK
+
         Q_batch_stride, Q_head_stride, Q_seq_stride = (
-            dQ.stride(0),
-            dQ.stride(1),
-            dQ.stride(2),
+            dQ_out.stride(0),
+            dQ_out.stride(1),
+            dQ_out.stride(2),
         )
         K_batch_stride, K_head_stride, K_seq_stride = (
-            dK.stride(0),
-            dK.stride(1),
-            dK.stride(2),
+            dK_out.stride(0),
+            dK_out.stride(1),
+            dK_out.stride(2),
         )
-        # Inplace rotary embedding is generally fine
-        dQ_out = dQ.clone() if not dQ.is_contiguous else dQ
-        dK_out = dK.clone() if not dK.is_contiguous else dK
-
         with torch_gpu_device(dQ.device):
             _rope_embedding_QK[(batch * ctx.seq_len, ctx.n_heads_Q)](
                 dQ_out,