diff --git a/unsloth/kernels/geglu.py b/unsloth/kernels/geglu.py index 67c36dd8ad..50b4e521d3 100644 --- a/unsloth/kernels/geglu.py +++ b/unsloth/kernels/geglu.py @@ -21,6 +21,12 @@ torch_gpu_device, ) +# signed int32 max is 2**31-1 so num_elements cannot exceed 2**31 +NUM_INT32_ELEMENTS = 2**31 +SAFE_INT32_BUFFER_MULTIPLIER = 4 +BLOCK_SIZE = 1024 +INT32_SAFETY_BUFFER = NUM_INT32_ELEMENTS - BLOCK_SIZE * SAFE_INT32_BUFFER_MULTIPLIER + @triton.jit def _exact_forward_kernel( @@ -29,9 +35,16 @@ def _exact_forward_kernel( h, n_elements, BLOCK_SIZE: tl.constexpr, + LONG_INDEXING: tl.constexpr, ): block_idx = tl.program_id(0) - offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if LONG_INDEXING: + offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to( + tl.int64 + ) + n_elements = tl.cast(n_elements, tl.int64) + else: + offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements # f = 1/2 * e * (1 + erf(1/sqrt(2) * e)) @@ -59,7 +72,8 @@ def geglu_exact_forward_kernel(gate, up): up, out, n_elements, - BLOCK_SIZE = 1024, + BLOCK_SIZE = BLOCK_SIZE, + LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1, ) return out @@ -71,6 +85,7 @@ def _exact_backward_kernel( g, n_elements, BLOCK_SIZE: tl.constexpr, + LONG_INDEXING: tl.constexpr, ): """ f = 1/2 * e * (1 + erf(1/sqrt(2) * e)) @@ -83,7 +98,13 @@ def _exact_backward_kernel( f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e """ block_idx = tl.program_id(0) - offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if LONG_INDEXING: + offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to( + tl.int64 + ) + n_elements = tl.cast(n_elements, tl.int64) + else: + offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32) @@ -126,7 +147,8 @@ def geglu_exact_backward_kernel(DW, e, g): e, g, n_elements, - BLOCK_SIZE = 1024, + BLOCK_SIZE = BLOCK_SIZE, + LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1, ) return DW, e, g @@ -138,9 +160,16 @@ def _approx_forward_kernel( h, n_elements, BLOCK_SIZE: tl.constexpr, + LONG_INDEXING: tl.constexpr, ): block_idx = tl.program_id(0) - offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if LONG_INDEXING: + offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to( + tl.int64 + ) + n_elements = tl.cast(n_elements, tl.int64) + else: + offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) )) @@ -173,7 +202,8 @@ def geglu_approx_forward_kernel(gate, up): up, out, n_elements, - BLOCK_SIZE = 1024, + BLOCK_SIZE = BLOCK_SIZE, + LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1, ) return out @@ -185,6 +215,7 @@ def _approx_backward_kernel( g, n_elements, BLOCK_SIZE: tl.constexpr, + LONG_INDEXING: tl.constexpr, ): """ f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )) @@ -201,7 +232,13 @@ def _approx_backward_kernel( See https://www.desmos.com/calculator/nqprfoni6x """ block_idx = tl.program_id(0) - offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if LONG_INDEXING: + offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to( + tl.int64 + ) + n_elements = tl.cast(n_elements, tl.int64) + else: + offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32) @@ -247,6 +284,7 @@ def geglu_approx_backward_kernel(DW, e, g): e, g, n_elements, - BLOCK_SIZE = 1024, + BLOCK_SIZE = BLOCK_SIZE, + LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1, ) return DW, e, g diff --git a/unsloth/kernels/swiglu.py b/unsloth/kernels/swiglu.py index 81e62660bb..b321f5179e 100644 --- a/unsloth/kernels/swiglu.py +++ b/unsloth/kernels/swiglu.py @@ -17,6 +17,12 @@ import torch from .utils import calculate_settings, torch_gpu_device +# signed int32 max is 2**31-1 so num_elements cannot exceed 2**31 +NUM_INT32_ELEMENTS = 2**31 +SAFE_INT32_BUFFER_MULTIPLIER = 4 +BLOCK_SIZE = 1024 +INT32_SAFETY_BUFFER = NUM_INT32_ELEMENTS - BLOCK_SIZE * SAFE_INT32_BUFFER_MULTIPLIER + @triton.jit def _fg_kernel( @@ -25,9 +31,16 @@ def _fg_kernel( h, n_elements, BLOCK_SIZE: tl.constexpr, + LONG_INDEXING: tl.constexpr, ): block_idx = tl.program_id(0) - offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if LONG_INDEXING: + offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to( + tl.int64 + ) + n_elements = tl.cast(n_elements, tl.int64) + else: + offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) @@ -54,7 +67,8 @@ def swiglu_fg_kernel(e, g): g, h, n_elements, - BLOCK_SIZE = 1024, + BLOCK_SIZE = BLOCK_SIZE, + LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1, ) return h @@ -66,6 +80,7 @@ def _DWf_DW_dfg_kernel( g, n_elements, BLOCK_SIZE: tl.constexpr, + LONG_INDEXING: tl.constexpr, ): """ e = e.float() @@ -77,7 +92,13 @@ def _DWf_DW_dfg_kernel( de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype) """ block_idx = tl.program_id(0) - offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if LONG_INDEXING: + offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to( + tl.int64 + ) + n_elements = tl.cast(n_elements, tl.int64) + else: + offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32) @@ -116,6 +137,7 @@ def swiglu_DWf_DW_dfg_kernel(DW, e, g): e, g, n_elements, - BLOCK_SIZE = 1024, + BLOCK_SIZE = BLOCK_SIZE, + LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1, ) return DW, e, g