Add an int64 path for mlp kernels #3614

Merged
danielhanchen merged 3 commits into unslothai:main from mmathew23:tiled/contextlen on Nov 20, 2025

Conversation

@mmathew23

Contributor

The llama mlp kernels produce NaNs at extremely long context lengths. This happens when num_elements is greater than 2**31. In these cases it's best to calculate offsets with tl.int64 instead of int32. This PR routes to the int64 kernels when num_elements is large enough.
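To make the failure mode concrete, here is a minimal pure-Python sketch of how a block start offset behaves under 32-bit arithmetic. It is not the kernel code from this PR: `wrap_int32`, `block_start_int32` and `block_start_int64` are hypothetical helper names, and the block size of 1024 is taken from the review discussion. The sketch assumes the offset is computed as `pid * BLOCK_SIZE`, as is typical for elementwise Triton kernels.

```python
BLOCK_SIZE = 1024  # block size mentioned in the review thread


def wrap_int32(value: int) -> int:
    """Interpret a Python int as a two's-complement signed 32-bit value."""
    return (value + 2**31) % 2**32 - 2**31


def block_start_int32(pid: int, block_size: int = BLOCK_SIZE) -> int:
    """Start offset of block `pid` as 32-bit arithmetic would compute it."""
    return wrap_int32(pid * block_size)


def block_start_int64(pid: int, block_size: int = BLOCK_SIZE) -> int:
    """Start offset of block `pid` with 64-bit arithmetic (no wraparound here)."""
    return pid * block_size


# Last block whose start offset still fits in int32, and the first one that does not.
last_safe_pid = 2**31 // BLOCK_SIZE - 1
first_bad_pid = 2**31 // BLOCK_SIZE

print(block_start_int32(last_safe_pid), block_start_int64(last_safe_pid))
print(block_start_int32(first_bad_pid), block_start_int64(first_bad_pid))
```

Once the wrapped offset goes negative, loads and stores no longer address the intended elements, which is presumably how the NaNs described above arise.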

Comment thread unsloth/kernels/geglu.py Outdated
device = gate.device
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
if n_elements <= (2**31) - 1024:

Member

Why -1024? Is it maybe hd?

Contributor Author

Yes, I forgot to account for hd. The idea was to add a buffer just to be safe.

Contributor Author

Wait, actually it is 1024, i.e. the BLOCK_SIZE.

Comment thread unsloth/kernels/geglu.py Outdated
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
if n_elements <= (2**31) - 1024:

Member

Maybe move (2**31) to a global var

Comment thread unsloth/kernels/swiglu.py Outdated
e,
g,
n_elements,
BLOCK_SIZE: tl.constexpr,

Member

there is actually a way to use 1 kernel only and dispatch, but for now this is fine - we can refactor later

mmathew23 force-pushed the tiled/contextlen branch 2 times, most recently from c008eca to 262ada3 on November 19, 2025 17:24
mmathew23 marked this pull request as ready for review on November 19, 2025 22:16
@mmathew23

Contributor Author

Why -1024? Is it maybe hd?

The idea is that offsets cannot exceed 2**31 - 1, which means n_elements <= 2**31. I wanted to add a buffer below that limit, and since we process in BLOCK_SIZE blocks rather than hidden_dim blocks, I figured BLOCK_SIZE was the better margin. We also get the added benefit that the behavior stays consistent across models.
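The threshold reasoning can be checked with a few lines of plain Python. This is only a sketch of the routing decision, not the merged code: the constant names `INT32_SAFETY_BUFFER` and `MAX_INT32_ELEMENTS` and the helpers `largest_offset` and `needs_int64` are hypothetical, and it assumes each program touches offsets `pid * BLOCK_SIZE + [0, BLOCK_SIZE)`, so the last block may generate offsets past `n_elements` before masking.

```python
BLOCK_SIZE = 1024                  # block size from the review thread
INT32_MAX = 2**31 - 1              # largest value a signed 32-bit offset can hold
INT32_SAFETY_BUFFER = BLOCK_SIZE   # hypothetical name for the "-1024" buffer
MAX_INT32_ELEMENTS = 2**31 - INT32_SAFETY_BUFFER  # hypothetical module-level constant


def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same quantity triton.cdiv computes for the grid."""
    return -(-a // b)


def largest_offset(n_elements: int, block_size: int = BLOCK_SIZE) -> int:
    """Largest offset any program generates, including the masked tail of the last block."""
    return cdiv(n_elements, block_size) * block_size - 1


def needs_int64(n_elements: int) -> bool:
    """Route to the int64 kernels once n_elements exceeds the buffered threshold."""
    return n_elements > MAX_INT32_ELEMENTS


# At the threshold every offset still fits in int32, so the int32 path is safe.
print(largest_offset(MAX_INT32_ELEMENTS) <= INT32_MAX, needs_int64(MAX_INT32_ELEMENTS))
# Just past 2**31 elements the last block's offsets no longer fit.
print(largest_offset(2**31 + 1) <= INT32_MAX, needs_int64(2**31 + 1))
```

Because the buffer is expressed in terms of BLOCK_SIZE rather than the hidden dimension, the same threshold applies regardless of model shape, which matches the consistency argument above.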

I've updated the PR to reflect your comments and finalized it. Let me know if there's anything else to address.

danielhanchen merged commit ac82560 into unslothai:main on Nov 20, 2025
1 check passed
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
* Add an int64 path for mlp kernels

* move constant expressions to globals

* fix name

2 participants