
CUDA: fix padding of GQA to power of 2 in FA #19115

Merged
JohannesGaessler merged 1 commit into ggml-org:master from JohannesGaessler:cuda-fa-fix-gqa-padding
Jan 26, 2026
Conversation

@JohannesGaessler
Contributor

Fixes #19112; the issue was introduced by #19092.

The MMA CUDA FlashAttention kernel uses a stream-k decomposition that treats the four-dimensional input tensors as one contiguous dimension to split across streaming multiprocessors. However, in conjunction with the GQA-specific optimizations in the MMA kernel this is only correct if the number of Q columns per CUDA block exactly divides n_gqa. Otherwise the wrong Q and K/V heads are associated with each other and the result is wrong (with only a single K/V head this doesn't matter, which is why the bug was not caught in testing).
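
To illustrate the failure mode, here is a toy sketch (not the actual kernel indexing; all names are made up): if the Q heads are flattened into a single dimension and each CUDA block covers a fixed number of consecutive Q columns, a block can straddle a K/V head boundary whenever that column count does not divide n_gqa, and all of its columns end up paired with the K/V head of its first column.

```cpp
// Toy reproduction of the mis-association; illustrative only.
#include <cstdio>

int main() {
    const int n_kv_heads     = 2;  // number of K/V heads
    const int n_gqa          = 12; // Q heads per K/V head
    const int cols_per_block = 8;  // Q columns handled by one CUDA block

    // Flattened "z" space: n_kv_heads*n_gqa Q heads in a row.
    for (int q = 0; q < n_kv_heads*n_gqa; ++q) {
        const int kv_correct = q / n_gqa;                      // intended pairing
        const int block      = q / cols_per_block;             // block this column lands in
        const int kv_used    = (block*cols_per_block) / n_gqa; // pairing the whole block uses
        if (kv_used != kv_correct) {
            printf("Q head %2d: should use K/V head %d but gets %d\n",
                   q, kv_correct, kv_used);
        }
    }
    return 0;
}
```

With these numbers, the block covering Q heads 8..15 derives K/V head 0 from its first column, so Q heads 12..15 are silently paired with the wrong K/V head.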

This PR extends the 4D space on master to a 5D space by splitting the "z" dimension (the number of Q heads) into one dimension for the number of K/V heads and another for the number of Q heads per K/V head. This then makes it possible to simply pad the Q columns per CUDA block to a power of 2.
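
A rough sketch of the resulting index math under these assumptions (illustrative names, not the PR's actual code): pad n_gqa to the next power of 2, give the grid a separate extent per K/V head, and skip the padded slots, so padding can never cross a K/V head boundary.

```cpp
// Toy version of the 5D split: "z" decomposes into (K/V head, padded GQA group).
#include <cstdio>

static int next_pow2(int x) {
    int p = 1;
    while (p < x) p <<= 1;
    return p;
}

int main() {
    const int n_kv_heads = 2;
    const int n_gqa      = 12;               // Q heads per K/V head
    const int n_gqa_pad  = next_pow2(n_gqa); // 16

    for (int z = 0; z < n_kv_heads*n_gqa_pad; ++z) {
        const int kv_head = z / n_gqa_pad;
        const int group   = z % n_gqa_pad;
        if (group >= n_gqa) {
            continue; // padded slot, no work to do
        }
        const int q_head = kv_head*n_gqa + group;
        printf("z = %2d -> kv_head = %d, q_head = %2d\n", z, kv_head, q_head);
    }
    return 0;
}
```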

I modified one of the test cases in test-backend-ops to check for this fix. On master n_gqa is set to 1, 4, and 16; I chose these values to cover no GQA optimizations, GQA optimizations with a single CUDA block in the z direction, and GQA optimizations with more than one CUDA block in the z direction. Changing the last value from 16 to 12 still covers that last case while also exercising the padding logic, since 12 is not a power of 2.
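
A minimal sketch of what such a parameterization could look like (hypothetical; the actual test-backend-ops code is structured differently):

```cpp
// Hypothetical test parameterization over GQA ratios; illustrative only.
#include <cstdio>

int main() {
    // 1:  no GQA optimizations
    // 4:  GQA optimizations, single CUDA block in the z direction
    // 12: GQA optimizations, >1 block in the z direction and non-power-of-2 padding
    const int n_gqa_values[] = {1, 4, 12};
    for (const int n_gqa : n_gqa_values) {
        printf("adding FlashAttention test case with n_gqa = %d\n", n_gqa);
        // ... construct the FA test case with this GQA ratio ...
    }
    return 0;
}
```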

@github-actions github-actions bot added the testing (Everything test related), Nvidia GPU (Issues specific to Nvidia GPUs), and ggml (changes relating to the ggml tensor library for machine learning) labels on Jan 26, 2026
@RodriMora
Contributor

Confirmed working now with https://huggingface.co/bartowski/MiniMaxAI_MiniMax-M2.1-GGUF at Q6.
It was outputting gibberish with the latest commit 8f80d1b

@JohannesGaessler JohannesGaessler merged commit b0311c1 into ggml-org:master Jan 26, 2026
76 of 78 checks passed
shaofeiqi pushed a commit to qualcomm/llama.cpp that referenced this pull request Feb 6, 2026
alielfilali01 added a commit to alielfilali01/llama.cpp that referenced this pull request Feb 12, 2026
alielfilali01 added a commit to alielfilali01/llama.cpp that referenced this pull request Feb 19, 2026
CISC added a commit that referenced this pull request Feb 19, 2026
* model: add JAIS-2 architecture support

Add support for the JAIS-2 family of Arabic-English bilingual models
from Inception AI (https://huggingface.co/inceptionai/Jais-2-8B-Chat).

Architecture characteristics:
- LayerNorm (not RMSNorm) with biases
- ReLU² (ReLU squared) activation function
- Separate Q/K/V projections with biases
- Simple MLP without gate projection (up -> act -> down)
- RoPE positional embeddings
- GPT-2 BPE tokenizer

Supported model sizes:
- Jais-2-8B (32 layers, 26 heads, 3328 hidden)
- Jais-2-70B (68 layers, 56 heads, 7168 hidden)

Tested with quantizations: BF16, Q8_0, Q6_K, Q5_K_M, Q5_0, Q4_K_M, Q4_0, Q3_K_M, Q2_K

Note: JAIS-2 requires F32 precision accumulators for numerical stability
and uses standard attention (not flash attention) on CUDA backends.

* fix: run convert_hf_to_gguf_update.py for jais-2 tokenizer hash

* fix: use NEOX RoPE type for JAIS2

* fix: remove Q/K permutation (NEOX RoPE doesn't need it)

* fix: enable flash attention for JAIS2 (fixed by #19115)

* fix: add dedicated JAIS2 pre-tokenizer type and control vector support

- Add LLAMA_VOCAB_PRE_TYPE_JAIS2 with cascading whitespace regex
- Include original regex from tokenizer.json as comment
- Add build_cvec call for control vector support

* no longer necessary to override set_vocab

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
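
For reference, a minimal sketch of the MLP shape described in the commit message above (up projection, ReLU² activation, down projection, no gate); plain C++ with made-up sizes, not the actual ggml graph:

```cpp
// Sketch of a JAIS-2-style MLP block: up -> ReLU^2 -> down, no gate projection.
// Sizes and data are made up; the real implementation builds a ggml graph.
#include <cstdio>
#include <vector>

// y = W_down * relu2(W_up * x), with relu2(v) = max(v, 0)^2
std::vector<float> mlp(const std::vector<std::vector<float>> & w_up,
                       const std::vector<std::vector<float>> & w_down,
                       const std::vector<float> & x) {
    std::vector<float> h(w_up.size(), 0.0f);
    for (size_t i = 0; i < w_up.size(); ++i) {
        for (size_t j = 0; j < x.size(); ++j) {
            h[i] += w_up[i][j]*x[j];
        }
        const float r = h[i] > 0.0f ? h[i] : 0.0f;
        h[i] = r*r; // ReLU squared, as used by JAIS-2
    }
    std::vector<float> y(w_down.size(), 0.0f);
    for (size_t i = 0; i < w_down.size(); ++i) {
        for (size_t j = 0; j < h.size(); ++j) {
            y[i] += w_down[i][j]*h[j];
        }
    }
    return y;
}

int main() {
    const std::vector<std::vector<float>> w_up   = {{1.0f, -1.0f}, {0.5f, 0.5f}};
    const std::vector<std::vector<float>> w_down = {{1.0f, 1.0f}};
    const std::vector<float> y = mlp(w_up, w_down, {2.0f, 1.0f});
    printf("y[0] = %f\n", y[0]); // (2-1)^2 + (1.5)^2 = 3.25
    return 0;
}
```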
liparetejas pushed a commit to liparetejas/llama.cpp that referenced this pull request Feb 23, 2026
bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request Mar 2, 2026
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Mar 3, 2026

Labels

ggml (changes relating to the ggml tensor library for machine learning)
Nvidia GPU (Issues specific to Nvidia GPUs)
testing (Everything test related)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Eval bug: Minimax M2.1 outputs gibberish after updating to 7836

3 participants