Skip to content
347 changes: 191 additions & 156 deletions unsloth/kernels/fast_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,73 +135,84 @@ def backward(ctx, dY: torch.Tensor):
g = g.view(-1, g.shape[-1])
dtype = X.dtype

gateA, gateB, upA, upB, downA, downB = (
gateA.to(dtype),
gateB.to(dtype),
upA.to(dtype),
upB.to(dtype),
downA.to(dtype),
downB.to(dtype),
)

gateA, gateB, upA, upB, downA, downB = (
gateA.t(),
gateB.t(),
upA.t(),
upB.t(),
downA.t(),
downB.t(),
)

DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
DW, e, g = _backward_function(DW, e, g)
h, df, de = DW, e, g

d_downA = torch.empty_like(downA)
d_downB = torch.empty_like(downB)
d_gateA = torch.empty_like(gateA)
d_gateB = torch.empty_like(gateB)
d_upA = torch.empty_like(upA)
d_upB = torch.empty_like(upB)

# Down projection LoRA weights
# d_downA = h.t() @ (dY @ downB.t())
# d_downB = (downA.t() @ h.t()) @ dY
# d_downA *= downS
# d_downB *= downS
d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0)
d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0)

# Up projection LoRA weights
# d_upA = X.t() @ (df @ upB.t())
# d_upB = (upA.t() @ X.t()) @ df
# d_upA *= upS
# d_upB *= upS
d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0)
d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0)

# Gate projection LoRA weights
# d_gateA = X.t() @ (de @ gateB.t())
# d_gateB = (gateA.t() @ X.t()) @ de
# d_gateA *= gateS
# d_gateB *= gateS
d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0)
d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0)

# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
del upW
# dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
dX.addmm_(df @ upB.t(), upA.t(), alpha = upS)

gateW = fast_dequantize(gateW.t(), gateW_quant)
# dX += de @ gateW.t()
dX.addmm_(de, gateW.t())
del gateW
# dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS)
# Disable autocast for the entire backward pass.
# @torch_amp_custom_bwd inherits the float16 autocast context from
# TRL's compiled GRPO trainer, which silently downcasts float32 gradient
# tensors to float16 mid-computation, causing addmm_ dtype mismatches.

# Use the tensor's actual device type so this works on CUDA and XPU.
with torch.amp.autocast(X.device.type, enabled = False):
# Cast incoming gradient to the activation dtype.
if dY.dtype != dtype:
dY = dY.to(dtype)
Comment on lines +146 to +147

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid casting grads before quantized matmul dtype is aligned

Casting dY to X.dtype under disabled autocast sends bf16 gradients into matmul_lora, but fast_dequantize still emits weights in quant_state.dtype (often fp16 for bnb-4bit), so DW = matmul_lora(...) can now fail with bf16/fp16 matmul mismatches in the same GRPO+bnb4bit scenario this patch targets. This regression is introduced by the new explicit cast here without a corresponding weight/output cast in matmul_lora.

Useful? React with 👍 / 👎.


gateA, gateB, upA, upB, downA, downB = (
gateA.to(dtype),
gateB.to(dtype),
upA.to(dtype),
upB.to(dtype),
downA.to(dtype),
downB.to(dtype),
)

gateA, gateB, upA, upB, downA, downB = (
gateA.t(),
gateB.t(),
upA.t(),
upB.t(),
downA.t(),
downB.t(),
)

DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
DW, e, g = _backward_function(DW, e, g)
h, df, de = DW, e, g

d_downA = torch.empty_like(downA)
d_downB = torch.empty_like(downB)
d_gateA = torch.empty_like(gateA)
d_gateB = torch.empty_like(gateB)
d_upA = torch.empty_like(upA)
d_upB = torch.empty_like(upB)

# Down projection LoRA weights
# d_downA = h.t() @ (dY @ downB.t())
# d_downB = (downA.t() @ h.t()) @ dY
# d_downA *= downS
# d_downB *= downS
d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0)
d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0)

# Up projection LoRA weights
# d_upA = X.t() @ (df @ upB.t())
# d_upB = (upA.t() @ X.t()) @ df
# d_upA *= upS
# d_upB *= upS
d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0)
d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0)

# Gate projection LoRA weights
# d_gateA = X.t() @ (de @ gateB.t())
# d_gateB = (gateA.t() @ X.t()) @ de
# d_gateA *= gateS
# d_gateB *= gateS
d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0)
d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0)

# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
del upW
# dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
dX.addmm_(df @ upB.t(), upA.t(), alpha = upS)

gateW = fast_dequantize(gateW.t(), gateW_quant)
# dX += de @ gateW.t()
dX.addmm_(de, gateW.t())
del gateW
# dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS)

# gateW, gateW_quant, gateA, gateB, gateS,
# upW, upW_quant, upA, upB, upS,
Expand Down Expand Up @@ -440,73 +451,88 @@ def backward(ctx, dQ, dK, dV):
X = X.view(-1, X.shape[-1])
dtype = X.dtype

QA, QB, KA, KB, VA, VB = (
QA.to(dtype),
QB.to(dtype),
KA.to(dtype),
KB.to(dtype),
VA.to(dtype),
VB.to(dtype),
)

QA, QB, KA, KB, VA, VB = QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()

### Weight projection LoRA weights
# See our blogpost for more details.
d_QA = torch.empty_like(QA)
d_QB = torch.empty_like(QB)
d_KA = torch.empty_like(KA)
d_KB = torch.empty_like(KB)
d_VA = torch.empty_like(VA)
d_VB = torch.empty_like(VB)

# Q Projection
# d_QA = X.t() @ (dQ @ QB.t())
# d_QB = (QA.t() @ X.t()) @ dQ
# d_QA *= QS
# d_QB *= QS
d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0)
d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0)

# K Projection
# d_KA = X.t() @ (dK @ KB.t())
# d_KB = (KA.t() @ X.t()) @ dK
# d_KA *= KS
# d_KB *= KS
d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0)
d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0)

# V Projection
# d_VA = X.t() @ (dV @ VB.t())
# d_VB = (VA.t() @ X.t()) @ dV
# d_VA *= VS
# d_VB *= VS
d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0)
d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0)

# Combine derivatives to find dX
# dQ
QW = fast_dequantize(QW.t(), QW_quant)
dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
del QW
# dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS)

# dK
KW = fast_dequantize(KW.t(), KW_quant)
# dX += dK @ KW.t()
dX.addmm_(dK, KW.t())
del KW
# dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS)

# dV
VW = fast_dequantize(VW.t(), VW_quant)
# dX += dV @ VW.t()
dX.addmm_(dV, VW.t())
del VW
# dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS)
# Disable autocast for the entire backward pass.
# @torch_amp_custom_bwd inherits the float16 autocast context from
# TRL's compiled GRPO trainer, which silently downcasts float32 gradient
# tensors to float16 mid-computation, causing addmm_ dtype mismatches.

# Use the tensor's actual device type so this works on CUDA and XPU.
with torch.amp.autocast(X.device.type, enabled = False):
# Cast incoming gradients to the activation dtype.
if dQ.dtype != dtype:
dQ = dQ.to(dtype)
if dK.dtype != dtype:
dK = dK.to(dtype)
if dV.dtype != dtype:
dV = dV.to(dtype)

QA, QB, KA, KB, VA, VB = (
QA.to(dtype),
QB.to(dtype),
KA.to(dtype),
KB.to(dtype),
VA.to(dtype),
VB.to(dtype),
)

QA, QB, KA, KB, VA, VB = QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()

### Weight projection LoRA weights
# See our blogpost for more details.
d_QA = torch.empty_like(QA)
d_QB = torch.empty_like(QB)
d_KA = torch.empty_like(KA)
d_KB = torch.empty_like(KB)
d_VA = torch.empty_like(VA)
d_VB = torch.empty_like(VB)

# Q Projection
# d_QA = X.t() @ (dQ @ QB.t())
# d_QB = (QA.t() @ X.t()) @ dQ
# d_QA *= QS
# d_QB *= QS
d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0)
d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0)

# K Projection
# d_KA = X.t() @ (dK @ KB.t())
# d_KB = (KA.t() @ X.t()) @ dK
# d_KA *= KS
# d_KB *= KS
d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0)
d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0)

# V Projection
# d_VA = X.t() @ (dV @ VB.t())
# d_VB = (VA.t() @ X.t()) @ dV
# d_VA *= VS
# d_VB *= VS
d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0)
d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0)

# Combine derivatives to find dX
# dQ
QW = fast_dequantize(QW.t(), QW_quant)
dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
del QW
# dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS)

# dK
KW = fast_dequantize(KW.t(), KW_quant)
# dX += dK @ KW.t()
dX.addmm_(dK, KW.t())
del KW
# dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS)

# dV
VW = fast_dequantize(VW.t(), VW_quant)
# dX += dV @ VW.t()
dX.addmm_(dV, VW.t())
del VW
# dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS)

# QW, QW_quant, QA, QB, QS,
# KW, KW_quant, KA, KB, KS,
Expand Down Expand Up @@ -611,28 +637,37 @@ def backward(ctx, dY: torch.Tensor):
X = X.reshape(-1, X.shape[-1]) # Must be reshape
dtype = X.dtype

A, B = A.to(dtype), B.to(dtype)

A, B = A.t(), B.t()

d_A = torch.empty_like(A)
d_B = torch.empty_like(B)

### Weight projection LoRA weights
# Weight projection
# d_A = X.t() @ (dY @ B.t())
# d_B = (A.t() @ X.t()) @ dY
# d_A *= S
# d_B *= S
d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0)
d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0)

# Get derivative for dX
W = fast_dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
# dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
dX.addmm_(dY @ B.t(), A.t(), alpha = S)
# Disable autocast for the entire backward pass.
# @torch_amp_custom_bwd inherits the float16 autocast context from
# TRL's compiled GRPO trainer, which silently downcasts float32 gradient
# tensors to float16 mid-computation, causing addmm_ dtype mismatches.
# Use the tensor's actual device type so this works on CUDA and XPU.
with torch.amp.autocast(X.device.type, enabled = False):
if dY.dtype != dtype:
dY = dY.to(dtype)

A, B = A.to(dtype), B.to(dtype)

A, B = A.t(), B.t()

d_A = torch.empty_like(A)
d_B = torch.empty_like(B)

### Weight projection LoRA weights
# Weight projection
# d_A = X.t() @ (dY @ B.t())
# d_B = (A.t() @ X.t()) @ dY
# d_A *= S
# d_B *= S
d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0)
d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0)

# Get derivative for dX
W = fast_dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
# dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
dX.addmm_(dY @ B.t(), A.t(), alpha = S)

# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
Expand Down
Loading