python/tvm/contrib/cutlass: 1 file changed, +2 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -221,7 +221,7 @@ def instantiate_flash_attention_template(attrs):
221 221   int q_row_stride = row_stride;
222 222   int k_row_stride = row_stride;
223 223   int v_row_stride = row_stride;
224     - int o_row_stride = o_head_stride * ${num_heads };
    224 + int o_row_stride = o_head_stride * ${num_q_heads };
225 225
226 226   int q_batch_stride = q_row_stride * ${num_queries};
227 227   int k_batch_stride = k_row_stride * ${num_keys};
@@ -235,7 +235,7 @@ def instantiate_flash_attention_template(attrs):
235 235   flash_attn::flash_attention_forward(
236 236   static_cast<const cutlass::half_t*>(${qkv}->data),
237 237   static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_q_heads},
238     - static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads})
    238 + static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads}),
239 239   static_cast<cutlass::half_t*>(out0->data),
240 240   ${num_batches},
241 241   ${num_queries},
You can’t perform that action at this time.
0 commit comments