Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 46 additions & 8 deletions unsloth/kernels/geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
torch_gpu_device,
)

# Offsets are computed as block_idx * BLOCK_SIZE + arange(0, BLOCK_SIZE), so the
# last block can index past n_elements (those lanes are masked out, but the
# arithmetic still happens). With signed int32 the largest representable value
# is 2**31 - 1, hence the launchers switch to int64 indexing a few blocks
# *before* the element count reaches 2**31.
NUM_INT32_ELEMENTS = 1 << 31
# Elements handled by one program instance.
BLOCK_SIZE = 1024
# How many whole blocks of headroom to keep below the int32 limit.
SAFE_INT32_BUFFER_MULTIPLIER = 4
# Largest n_elements for which 32-bit offsets are still used.
INT32_SAFETY_BUFFER = NUM_INT32_ELEMENTS - SAFE_INT32_BUFFER_MULTIPLIER * BLOCK_SIZE


@triton.jit
def _exact_forward_kernel(
Expand All @@ -29,9 +35,16 @@ def _exact_forward_kernel(
h,
n_elements,
BLOCK_SIZE: tl.constexpr,
LONG_INDEXING: tl.constexpr,
):
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if LONG_INDEXING:
offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
tl.int64
)
n_elements = tl.cast(n_elements, tl.int64)
else:
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
Expand Down Expand Up @@ -59,7 +72,8 @@ def geglu_exact_forward_kernel(gate, up):
up,
out,
n_elements,
BLOCK_SIZE = 1024,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return out

Expand All @@ -71,6 +85,7 @@ def _exact_backward_kernel(
g,
n_elements,
BLOCK_SIZE: tl.constexpr,
LONG_INDEXING: tl.constexpr,
):
"""
f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
Expand All @@ -83,7 +98,13 @@ def _exact_backward_kernel(
f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e
"""
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if LONG_INDEXING:
offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
tl.int64
)
n_elements = tl.cast(n_elements, tl.int64)
else:
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
Expand Down Expand Up @@ -126,7 +147,8 @@ def geglu_exact_backward_kernel(DW, e, g):
e,
g,
n_elements,
BLOCK_SIZE = 1024,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return DW, e, g

Expand All @@ -138,9 +160,16 @@ def _approx_forward_kernel(
h,
n_elements,
BLOCK_SIZE: tl.constexpr,
LONG_INDEXING: tl.constexpr,
):
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if LONG_INDEXING:
offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
tl.int64
)
n_elements = tl.cast(n_elements, tl.int64)
else:
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
Expand Down Expand Up @@ -173,7 +202,8 @@ def geglu_approx_forward_kernel(gate, up):
up,
out,
n_elements,
BLOCK_SIZE = 1024,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return out

Expand All @@ -185,6 +215,7 @@ def _approx_backward_kernel(
g,
n_elements,
BLOCK_SIZE: tl.constexpr,
LONG_INDEXING: tl.constexpr,
):
"""
f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
Expand All @@ -201,7 +232,13 @@ def _approx_backward_kernel(
See https://www.desmos.com/calculator/nqprfoni6x
"""
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if LONG_INDEXING:
offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
tl.int64
)
n_elements = tl.cast(n_elements, tl.int64)
else:
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
Expand Down Expand Up @@ -247,6 +284,7 @@ def geglu_approx_backward_kernel(DW, e, g):
e,
g,
n_elements,
BLOCK_SIZE = 1024,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return DW, e, g
30 changes: 26 additions & 4 deletions unsloth/kernels/swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
import torch
from .utils import calculate_settings, torch_gpu_device

# Offsets are computed as block_idx * BLOCK_SIZE + arange(0, BLOCK_SIZE), so the
# last block can index past n_elements (those lanes are masked out, but the
# arithmetic still happens). With signed int32 the largest representable value
# is 2**31 - 1, hence the launchers switch to int64 indexing a few blocks
# *before* the element count reaches 2**31.
NUM_INT32_ELEMENTS = 1 << 31
# Elements handled by one program instance.
BLOCK_SIZE = 1024
# How many whole blocks of headroom to keep below the int32 limit.
SAFE_INT32_BUFFER_MULTIPLIER = 4
# Largest n_elements for which 32-bit offsets are still used.
INT32_SAFETY_BUFFER = NUM_INT32_ELEMENTS - SAFE_INT32_BUFFER_MULTIPLIER * BLOCK_SIZE


@triton.jit
def _fg_kernel(
Expand All @@ -25,9 +31,16 @@ def _fg_kernel(
h,
n_elements,
BLOCK_SIZE: tl.constexpr,
LONG_INDEXING: tl.constexpr,
):
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if LONG_INDEXING:
offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
tl.int64
)
n_elements = tl.cast(n_elements, tl.int64)
else:
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
Expand All @@ -54,7 +67,8 @@ def swiglu_fg_kernel(e, g):
g,
h,
n_elements,
BLOCK_SIZE = 1024,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return h

Expand All @@ -66,6 +80,7 @@ def _DWf_DW_dfg_kernel(
g,
n_elements,
BLOCK_SIZE: tl.constexpr,
LONG_INDEXING: tl.constexpr,
):
"""
e = e.float()
Expand All @@ -77,7 +92,13 @@ def _DWf_DW_dfg_kernel(
de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
"""
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if LONG_INDEXING:
offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
tl.int64
)
n_elements = tl.cast(n_elements, tl.int64)
else:
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
Expand Down Expand Up @@ -116,6 +137,7 @@ def swiglu_DWf_DW_dfg_kernel(DW, e, g):
e,
g,
n_elements,
BLOCK_SIZE = 1024,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return DW, e, g