Add an int64 path for mlp kernels #3614
device = gate.device
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
if n_elements <= (2**31) - 1024:
Why -1024? Is it maybe hd?
yes I forgot to account for hd. The idea is that I wanted to add a buffer just to be safe.
Wait, actually it is 1024, i.e. the BLOCK_SIZE.
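The buffer discussed in this thread can be checked with plain integer arithmetic. The sketch below assumes the usual Triton indexing pattern, `pid * BLOCK_SIZE + arange(0, BLOCK_SIZE)`, which the diff excerpts here do not show; the helper names are hypothetical. Because offsets are computed before the mask is applied, the last block can reach up to `BLOCK_SIZE - 1` past `n_elements`.

```python
BLOCK_SIZE = 1024        # block size named in the review comment
INT32_MAX = 2**31 - 1    # largest value a signed 32-bit offset can hold


def max_unmasked_offset(n_elements: int, block_size: int = BLOCK_SIZE) -> int:
    """Largest offset the last block computes before masking (assumed pattern)."""
    num_blocks = -(-n_elements // block_size)  # ceiling division, like triton.cdiv
    return num_blocks * block_size - 1


def int32_offsets_are_safe(n_elements: int, block_size: int = BLOCK_SIZE) -> bool:
    """True when every unmasked offset still fits in a signed 32-bit integer."""
    return max_unmasked_offset(n_elements, block_size) <= INT32_MAX


if __name__ == "__main__":
    for n in (1025, 2**31 - 1024, 2**31 + 1):
        print(n, max_unmasked_offset(n), int32_offsets_are_safe(n))
```

Under this assumption, `n_elements <= 2**31 - 1024` guarantees `n_elements + BLOCK_SIZE - 1 <= 2**31 - 1`, so the guard is a conservative bound rather than the exact overflow point.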
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
if n_elements <= (2**31) - 1024:
Maybe move (2**31) to a global variable.
e,
g,
n_elements,
BLOCK_SIZE: tl.constexpr,
There is actually a way to use only one kernel and dispatch inside it, but for now this is fine. We can refactor later.
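One way the single-kernel idea might look is a flag that selects the index width inside one code path. The following is a plain-Python emulation of that offset arithmetic, not Triton code: the `long_indexing` flag and both function names are hypothetical, and `ctypes.c_int32` stands in for 32-bit wraparound.

```python
import ctypes

BLOCK_SIZE = 1024
INT32_SAFE_LIMIT = 2**31 - BLOCK_SIZE  # same threshold as the guard in the diff


def block_start(pid: int, long_indexing: bool, block_size: int = BLOCK_SIZE) -> int:
    """Emulate the first offset of block `pid` under either index width."""
    start = pid * block_size
    if long_indexing:
        return start                    # 64-bit path: no wraparound at this scale
    return ctypes.c_int32(start).value  # 32-bit path: wraps past 2**31 - 1


def launch(n_elements: int, pid: int) -> int:
    """Single entry point: the flag is derived from the problem size."""
    long_indexing = n_elements > INT32_SAFE_LIMIT
    return block_start(pid, long_indexing)
```

With this shape, callers never pick a kernel themselves. The size check lives in one place and the wide path is only taken when needed.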
c008eca to 262ada3
262ada3 to 833d91f
So the idea is that offsets cannot be more than …
I've updated the PR to reflect your comments and finalized it. Let me know if there's anything else to address.
* Add an int64 path for mlp kernels
* move constant expressions to globals
* fix name
The Llama MLP kernels produce NaNs at extremely long context lengths. This happens when num_elements is greater than 2**31, where offsets computed in int32 overflow. In these cases it's best to calculate offsets with tl.int64 instead of int32. This PR routes to the int64 kernels when num_elements is large enough.
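To give a sense of scale, the sketch below estimates where the int64 path would start for a tensor of shape `(batch, seq_len, hd)`. It reuses the `2**31 - 1024` threshold from the diff. The function names are hypothetical, and `hd = 14336` is only an illustrative hidden size, not a value taken from this PR.

```python
BLOCK_SIZE = 1024
INT32_SAFE_LIMIT = 2**31 - BLOCK_SIZE  # threshold used in the diff


def offset_dtype(batch: int, seq_len: int, hd: int) -> str:
    """Name the index width the routing check would select for this shape."""
    n_elements = batch * seq_len * hd
    return "int32" if n_elements <= INT32_SAFE_LIMIT else "int64"


def max_int32_seq_len(batch: int, hd: int) -> int:
    """Longest sequence that still stays on the int32 path."""
    return INT32_SAFE_LIMIT // (batch * hd)


if __name__ == "__main__":
    hd = 14336  # illustrative value only
    print(max_int32_seq_len(1, hd))
    print(offset_dtype(1, 8192, hd), offset_dtype(1, 200_000, hd))
```

For a single sequence with this illustrative `hd`, the switch happens at roughly 150k tokens, which matches the "extremely long context" framing of the description.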