Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions unsloth/kernels/rope_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@ def forward(ctx, Q, K, cos, sin, rope_indices):
_, n_heads_K, _, _ = K.shape

# Inplace rotary embedding is generally fine
- Q_out = Q.clone() if not Q.is_contiguous else Q
- K_out = K.clone() if not K.is_contiguous else K
+ Q_out = Q.clone() if not Q.is_contiguous() else Q
+ K_out = K.clone() if not K.is_contiguous() else K

if has_indices:
# TRL's rotary indices are always in int32, so casting is just for safety
Expand Down Expand Up @@ -383,21 +383,21 @@ def backward(ctx, dQ, dK):
else ctx.cos.new_empty(1, dtype = torch.int32)
)

+ # Inplace rotary embedding is generally fine
+ dQ_out = dQ.clone() if not dQ.is_contiguous() else dQ
+ dK_out = dK.clone() if not dK.is_contiguous() else dK
+
Q_batch_stride, Q_head_stride, Q_seq_stride = (
- dQ.stride(0),
- dQ.stride(1),
- dQ.stride(2),
+ dQ_out.stride(0),
+ dQ_out.stride(1),
+ dQ_out.stride(2),
)
K_batch_stride, K_head_stride, K_seq_stride = (
- dK.stride(0),
- dK.stride(1),
- dK.stride(2),
+ dK_out.stride(0),
+ dK_out.stride(1),
+ dK_out.stride(2),
)

- # Inplace rotary embedding is generally fine
- dQ_out = dQ.clone() if not dQ.is_contiguous else dQ
- dK_out = dK.clone() if not dK.is_contiguous else dK
-
with torch_gpu_device(dQ.device):
with torch_gpu_device(dQ.device):
_rope_embedding_QK[(batch * ctx.seq_len, ctx.n_heads_Q)](
dQ_out,
Expand Down