@@ -201,6 +201,11 @@ void main() {
201201    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
202202    uint32_t k_stride = p.nb11;
203203    uint32_t v_stride = p.nb21;
204+     // When using grouped query attention, all rows use the same mask (stride 0).
205+     // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
206+     // that prevents the compiler from folding the "&" through the select
207+     // and breaking the alignment detection.
208+     uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
204209    // hint to the compiler that strides are aligned for the aligned variant of the shader
205210    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
206211    {
@@ -209,6 +214,7 @@ void main() {
209214        k_stride &= ~7;
210215        v_stride &= ~7;
211216#endif
217+         m_stride &= ~7;
212218    }
213219    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
214220    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
@@ -261,10 +267,7 @@ void main() {
261267        if (p.mask != 0) {
262268            tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
263269            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
264-             // When using grouped query attention, all rows use the same mask.
265-             if (p.gqa_ratio > 1) {
266-                 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
267-             }
270+             tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
268271
269272            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
270273
0 commit comments