diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py index f1c0e298d9..f041c55cc4 100644 --- a/unsloth/kernels/fast_lora.py +++ b/unsloth/kernels/fast_lora.py @@ -135,73 +135,84 @@ def backward(ctx, dY: torch.Tensor): g = g.view(-1, g.shape[-1]) dtype = X.dtype - gateA, gateB, upA, upB, downA, downB = ( - gateA.to(dtype), - gateB.to(dtype), - upA.to(dtype), - upB.to(dtype), - downA.to(dtype), - downB.to(dtype), - ) - - gateA, gateB, upA, upB, downA, downB = ( - gateA.t(), - gateB.t(), - upA.t(), - upB.t(), - downA.t(), - downB.t(), - ) - - DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS) - DW, e, g = _backward_function(DW, e, g) - h, df, de = DW, e, g - - d_downA = torch.empty_like(downA) - d_downB = torch.empty_like(downB) - d_gateA = torch.empty_like(gateA) - d_gateB = torch.empty_like(gateB) - d_upA = torch.empty_like(upA) - d_upB = torch.empty_like(upB) - - # Down projection LoRA weights - # d_downA = h.t() @ (dY @ downB.t()) - # d_downB = (downA.t() @ h.t()) @ dY - # d_downA *= downS - # d_downB *= downS - d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0) - d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0) - - # Up projection LoRA weights - # d_upA = X.t() @ (df @ upB.t()) - # d_upB = (upA.t() @ X.t()) @ df - # d_upA *= upS - # d_upB *= upS - d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0) - d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0) - - # Gate projection LoRA weights - # d_gateA = X.t() @ (de @ gateB.t()) - # d_gateB = (gateA.t() @ X.t()) @ de - # d_gateA *= gateS - # d_gateB *= gateS - d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0) - d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0) - - # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS) - # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS) - upW = fast_dequantize(upW.t(), upW_quant) - dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None) - del upW - # dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()) - dX.addmm_(df @ upB.t(), upA.t(), alpha = upS) - - gateW = fast_dequantize(gateW.t(), gateW_quant) - # dX += de @ gateW.t() - dX.addmm_(de, gateW.t()) - del gateW - # dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t()) - dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS) + # Disable autocast for the entire backward pass. + # @torch_amp_custom_bwd inherits the float16 autocast context from + # TRL's compiled GRPO trainer, which silently downcasts float32 gradient + # tensors to float16 mid-computation, causing addmm_ dtype mismatches. + + # Use the tensor's actual device type so this works on CUDA and XPU. + with torch.amp.autocast(X.device.type, enabled = False): + # Cast incoming gradient to the activation dtype. + if dY.dtype != dtype: + dY = dY.to(dtype) + + gateA, gateB, upA, upB, downA, downB = ( + gateA.to(dtype), + gateB.to(dtype), + upA.to(dtype), + upB.to(dtype), + downA.to(dtype), + downB.to(dtype), + ) + + gateA, gateB, upA, upB, downA, downB = ( + gateA.t(), + gateB.t(), + upA.t(), + upB.t(), + downA.t(), + downB.t(), + ) + + DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS) + DW, e, g = _backward_function(DW, e, g) + h, df, de = DW, e, g + + d_downA = torch.empty_like(downA) + d_downB = torch.empty_like(downB) + d_gateA = torch.empty_like(gateA) + d_gateB = torch.empty_like(gateB) + d_upA = torch.empty_like(upA) + d_upB = torch.empty_like(upB) + + # Down projection LoRA weights + # d_downA = h.t() @ (dY @ downB.t()) + # d_downB = (downA.t() @ h.t()) @ dY + # d_downA *= downS + # d_downB *= downS + d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0) + d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0) + + # Up projection LoRA weights + # d_upA = X.t() @ (df @ upB.t()) + # d_upB = (upA.t() @ X.t()) @ df + # d_upA *= upS + # d_upB *= upS + d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0) + d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0) + + # Gate projection LoRA weights + # d_gateA = X.t() @ (de @ gateB.t()) + # d_gateB = (gateA.t() @ X.t()) @ de + # d_gateA *= gateS + # d_gateB *= gateS + d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0) + d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0) + + # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS) + # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS) + upW = fast_dequantize(upW.t(), upW_quant) + dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None) + del upW + # dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()) + dX.addmm_(df @ upB.t(), upA.t(), alpha = upS) + + gateW = fast_dequantize(gateW.t(), gateW_quant) + # dX += de @ gateW.t() + dX.addmm_(de, gateW.t()) + del gateW + # dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t()) + dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS) # gateW, gateW_quant, gateA, gateB, gateS, # upW, upW_quant, upA, upB, upS, @@ -440,73 +451,88 @@ def backward(ctx, dQ, dK, dV): X = X.view(-1, X.shape[-1]) dtype = X.dtype - QA, QB, KA, KB, VA, VB = ( - QA.to(dtype), - QB.to(dtype), - KA.to(dtype), - KB.to(dtype), - VA.to(dtype), - VB.to(dtype), - ) - - QA, QB, KA, KB, VA, VB = QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t() - - ### Weight projection LoRA weights - # See our blogpost for more details. - d_QA = torch.empty_like(QA) - d_QB = torch.empty_like(QB) - d_KA = torch.empty_like(KA) - d_KB = torch.empty_like(KB) - d_VA = torch.empty_like(VA) - d_VB = torch.empty_like(VB) - - # Q Projection - # d_QA = X.t() @ (dQ @ QB.t()) - # d_QB = (QA.t() @ X.t()) @ dQ - # d_QA *= QS - # d_QB *= QS - d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0) - d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0) - - # K Projection - # d_KA = X.t() @ (dK @ KB.t()) - # d_KB = (KA.t() @ X.t()) @ dK - # d_KA *= KS - # d_KB *= KS - d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0) - d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0) - - # V Projection - # d_VA = X.t() @ (dV @ VB.t()) - # d_VB = (VA.t() @ X.t()) @ dV - # d_VA *= VS - # d_VB *= VS - d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0) - d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0) - - # Combine derivatives to find dX - # dQ - QW = fast_dequantize(QW.t(), QW_quant) - dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None) - del QW - # dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())) - dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS) - - # dK - KW = fast_dequantize(KW.t(), KW_quant) - # dX += dK @ KW.t() - dX.addmm_(dK, KW.t()) - del KW - # dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t()) - dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS) - - # dV - VW = fast_dequantize(VW.t(), VW_quant) - # dX += dV @ VW.t() - dX.addmm_(dV, VW.t()) - del VW - # dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t()) - dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS) + # Disable autocast for the entire backward pass. + # @torch_amp_custom_bwd inherits the float16 autocast context from + # TRL's compiled GRPO trainer, which silently downcasts float32 gradient + # tensors to float16 mid-computation, causing addmm_ dtype mismatches. + + # Use the tensor's actual device type so this works on CUDA and XPU. + with torch.amp.autocast(X.device.type, enabled = False): + # Cast incoming gradients to the activation dtype. + if dQ.dtype != dtype: + dQ = dQ.to(dtype) + if dK.dtype != dtype: + dK = dK.to(dtype) + if dV.dtype != dtype: + dV = dV.to(dtype) + + QA, QB, KA, KB, VA, VB = ( + QA.to(dtype), + QB.to(dtype), + KA.to(dtype), + KB.to(dtype), + VA.to(dtype), + VB.to(dtype), + ) + + QA, QB, KA, KB, VA, VB = QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t() + + ### Weight projection LoRA weights + # See our blogpost for more details. + d_QA = torch.empty_like(QA) + d_QB = torch.empty_like(QB) + d_KA = torch.empty_like(KA) + d_KB = torch.empty_like(KB) + d_VA = torch.empty_like(VA) + d_VB = torch.empty_like(VB) + + # Q Projection + # d_QA = X.t() @ (dQ @ QB.t()) + # d_QB = (QA.t() @ X.t()) @ dQ + # d_QA *= QS + # d_QB *= QS + d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0) + d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0) + + # K Projection + # d_KA = X.t() @ (dK @ KB.t()) + # d_KB = (KA.t() @ X.t()) @ dK + # d_KA *= KS + # d_KB *= KS + d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0) + d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0) + + # V Projection + # d_VA = X.t() @ (dV @ VB.t()) + # d_VB = (VA.t() @ X.t()) @ dV + # d_VA *= VS + # d_VB *= VS + d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0) + d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0) + + # Combine derivatives to find dX + # dQ + QW = fast_dequantize(QW.t(), QW_quant) + dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None) + del QW + # dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())) + dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS) + + # dK + KW = fast_dequantize(KW.t(), KW_quant) + # dX += dK @ KW.t() + dX.addmm_(dK, KW.t()) + del KW + # dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t()) + dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS) + + # dV + VW = fast_dequantize(VW.t(), VW_quant) + # dX += dV @ VW.t() + dX.addmm_(dV, VW.t()) + del VW + # dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t()) + dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS) # QW, QW_quant, QA, QB, QS, # KW, KW_quant, KA, KB, KS, @@ -611,28 +637,37 @@ def backward(ctx, dY: torch.Tensor): X = X.reshape(-1, X.shape[-1]) # Must be reshape dtype = X.dtype - A, B = A.to(dtype), B.to(dtype) - - A, B = A.t(), B.t() - - d_A = torch.empty_like(A) - d_B = torch.empty_like(B) - - ### Weight projection LoRA weights - # Weight projection - # d_A = X.t() @ (dY @ B.t()) - # d_B = (A.t() @ X.t()) @ dY - # d_A *= S - # d_B *= S - d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0) - d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0) - - # Get derivative for dX - W = fast_dequantize(W.t(), W_quant) - dX = dY @ W.t() - del W - # dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t()) - dX.addmm_(dY @ B.t(), A.t(), alpha = S) + # Disable autocast for the entire backward pass. + # @torch_amp_custom_bwd inherits the float16 autocast context from + # TRL's compiled GRPO trainer, which silently downcasts float32 gradient + # tensors to float16 mid-computation, causing addmm_ dtype mismatches. + # Use the tensor's actual device type so this works on CUDA and XPU. + with torch.amp.autocast(X.device.type, enabled = False): + if dY.dtype != dtype: + dY = dY.to(dtype) + + A, B = A.to(dtype), B.to(dtype) + + A, B = A.t(), B.t() + + d_A = torch.empty_like(A) + d_B = torch.empty_like(B) + + ### Weight projection LoRA weights + # Weight projection + # d_A = X.t() @ (dY @ B.t()) + # d_B = (A.t() @ X.t()) @ dY + # d_A *= S + # d_B *= S + d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0) + d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0) + + # Get derivative for dX + W = fast_dequantize(W.t(), W_quant) + dX = dY @ W.t() + del W + # dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t()) + dX.addmm_(dY @ B.t(), A.t(), alpha = S) # W, W_quant, A, B, S return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index dd5a9cbf0e..90f2d5d238 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -88,26 +88,10 @@ def is_cdna(): @functools.lru_cache(1) def is_rdna(): - """Detect ROCm-supported RDNA consumer/workstation GPUs (RDNA2, RDNA3, RDNA3.5, RDNA4).""" + """Detect ROCm-supported RDNA consumer/workstation GPUs (RDNA3, RDNA4).""" return is_hip() and triton.runtime.driver.active.get_current_target().arch in ( - # RDNA2 (Navi 21-24) - "gfx1030", - "gfx1031", - "gfx1032", - "gfx1033", - "gfx1034", - "gfx1035", - "gfx1036", - # RDNA3 (Navi 31-33) "gfx1100", "gfx1101", - "gfx1102", - "gfx1103", - # RDNA3.5 (Strix Point / Strix Halo) - "gfx1150", - "gfx1151", - "gfx1152", - # RDNA4 (Navi 48-44) "gfx1200", "gfx1201", ) @@ -161,14 +145,8 @@ def torch_gpu_device(device): if DEVICE_TYPE == "xpu": _gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream # NVIDIA GPU Default Logic -elif hasattr(torch._C, "_cuda_getCurrentRawStream"): - _gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream else: - # CPU-only torch wheel (no compiled CUDA backend). _get_tensor_stream - # is only invoked during real GPU work, so a no-op binding is safe. - def _gpu_getCurrentRawStream(_index = 0): - return 0 - + _gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream c_void_p = ctypes.c_void_p @@ -183,49 +161,36 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: global WEIGHT_BUFFERS global ABSMAX_BUFFERS -# DEVICE_COUNT == 0 = no visible accelerator (e.g. CPU-only CI runner). -# The consumer functions below only index these arrays during real GPU -# work, so empty containers are safe -- they just need to be defined so -# the module imports cleanly. +# INTEL GPU Specific Logic if DEVICE_TYPE == "xpu": - if DEVICE_COUNT > 0: - _XPU_STREAMS = { - (index := torch.xpu.device(i).idx): ctypes.c_void_p( - torch._C._xpu_getCurrentRawStream(index) - ) - for i in range(DEVICE_COUNT) - } - XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1) - WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1) - ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1) - for k, v in _XPU_STREAMS.items(): - XPU_STREAMS[k] = v - XPU_STREAMS = tuple(XPU_STREAMS) - del _XPU_STREAMS - else: - XPU_STREAMS = () - WEIGHT_BUFFERS = [] - ABSMAX_BUFFERS = [] + _XPU_STREAMS = { + (index := torch.xpu.device(i).idx): ctypes.c_void_p( + torch._C._xpu_getCurrentRawStream(index) + ) + for i in range(DEVICE_COUNT) + } + XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1) + WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1) + ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1) + for k, v in _XPU_STREAMS.items(): + XPU_STREAMS[k] = v + XPU_STREAMS = tuple(XPU_STREAMS) + del _XPU_STREAMS else: # NVIDIA GPU Default Logic - if DEVICE_COUNT > 0: - _CUDA_STREAMS = { - (index := torch.cuda.device(i).idx): ctypes.c_void_p( - torch._C._cuda_getCurrentRawStream(index) - ) - for i in range(DEVICE_COUNT) - } - CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1) - WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1) - ABSMAX_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1) - for k, v in _CUDA_STREAMS.items(): - CUDA_STREAMS[k] = v - CUDA_STREAMS = tuple(CUDA_STREAMS) - del _CUDA_STREAMS - else: - CUDA_STREAMS = () - WEIGHT_BUFFERS = [] - ABSMAX_BUFFERS = [] + _CUDA_STREAMS = { + (index := torch.cuda.device(i).idx): ctypes.c_void_p( + torch._C._cuda_getCurrentRawStream(index) + ) + for i in range(DEVICE_COUNT) + } + CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1) + WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1) + ABSMAX_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1) + for k, v in _CUDA_STREAMS.items(): + CUDA_STREAMS[k] = v + CUDA_STREAMS = tuple(CUDA_STREAMS) + del _CUDA_STREAMS # Bitsandbytes operations ctypes_c_int = ctypes.c_int