Skip to content

Commit 233d2d0

Browse files
committed
wip
1 parent 0a6a617 commit 233d2d0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

python/tvm/contrib/cutlass/attention_operation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def instantiate_flash_attention_template(attrs):
221221
int q_row_stride = row_stride;
222222
int k_row_stride = row_stride;
223223
int v_row_stride = row_stride;
224-
int o_row_stride = o_head_stride * ${num_heads};
224+
int o_row_stride = o_head_stride * ${num_q_heads};
225225
226226
int q_batch_stride = q_row_stride * ${num_queries};
227227
int k_batch_stride = k_row_stride * ${num_keys};
@@ -235,7 +235,7 @@ def instantiate_flash_attention_template(attrs):
235235
flash_attn::flash_attention_forward(
236236
static_cast<const cutlass::half_t*>(${qkv}->data),
237237
static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_q_heads},
238-
static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads})
238+
static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads}),
239239
static_cast<cutlass::half_t*>(out0->data),
240240
${num_batches},
241241
${num_queries},

0 commit comments

Comments
 (0)