
Commit 5d01fd1
Parent: ae657b7
Commit message: wip

File tree: 4 files changed, +28 −19 lines


python/tvm/contrib/cutlass/attention_operation.py

Lines changed: 13 additions & 13 deletions
@@ -169,10 +169,10 @@ def instantiate_flash_attention_template(attrs):
 int k_head_stride = ${head_dim};
 int v_head_stride = ${head_dim};
 int o_head_stride = ${head_dim};
-int q_row_stride = q_head_stride * ${num_heads};
-int k_row_stride = k_head_stride * ${num_heads};
-int v_row_stride = v_head_stride * ${num_heads};
-int o_row_stride = o_head_stride * ${num_heads};
+int q_row_stride = q_head_stride * ${num_q_heads};
+int k_row_stride = k_head_stride * ${num_kv_heads};
+int v_row_stride = v_head_stride * ${num_kv_heads};
+int o_row_stride = o_head_stride * ${num_q_heads};
 int q_batch_stride = q_row_stride * ${num_queries};
 int k_batch_stride = k_row_stride * ${num_keys};
 int v_batch_stride = v_row_stride * ${num_keys};
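The stride arithmetic in the hunk above can be mirrored in plain Python for sanity checking. This is a hedged sketch, not code from the commit: the parameter names stand in for the substituted template variables, and it assumes the separate Q/K/V tensors use the (batch, seq_len, heads, head_dim) layout these strides imply.

def flash_attention_strides(num_queries, num_keys, num_q_heads, num_kv_heads, head_dim):
    """Mirror of the template's stride math for separate Q/K/V tensors
    laid out as (batch, seq, heads, head_dim)."""
    q_head_stride = k_head_stride = v_head_stride = o_head_stride = head_dim
    # Q and the output carry num_q_heads heads per token; K/V carry num_kv_heads.
    q_row_stride = q_head_stride * num_q_heads
    k_row_stride = k_head_stride * num_kv_heads
    v_row_stride = v_head_stride * num_kv_heads
    o_row_stride = o_head_stride * num_q_heads
    return {
        "q": (q_row_stride * num_queries, q_row_stride, q_head_stride),
        "k": (k_row_stride * num_keys, k_row_stride, k_head_stride),
        "v": (v_row_stride * num_keys, v_row_stride, v_head_stride),
        "o": (o_row_stride * num_queries, o_row_stride, o_head_stride),
    }

# Hypothetical GQA configuration: 32 query heads sharing 8 KV heads, head_dim 128.
print(flash_attention_strides(2048, 2048, 32, 8, 128))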
@@ -190,8 +190,8 @@ def instantiate_flash_attention_template(attrs):
 ${num_batches},
 ${num_queries},
 ${num_keys},
-${num_heads},
-${num_heads},
+${num_q_heads},
+${num_kv_heads},
 ${head_dim},
 q_batch_stride,
 k_batch_stride,
@@ -215,9 +215,9 @@ def instantiate_flash_attention_template(attrs):
 int k_head_stride = ${head_dim};
 int v_head_stride = ${head_dim};
 int o_head_stride = ${head_dim};
-int row_stride = q_head_stride * ${num_heads} +
-                 k_head_stride * ${num_heads} +
-                 v_head_stride * ${num_heads};
+int row_stride = q_head_stride * ${num_q_heads} +
+                 k_head_stride * ${num_kv_heads} +
+                 v_head_stride * ${num_kv_heads};
 int q_row_stride = row_stride;
 int k_row_stride = row_stride;
 int v_row_stride = row_stride;
@@ -234,14 +234,14 @@ def instantiate_flash_attention_template(attrs):
 
 flash_attn::flash_attention_forward(
     static_cast<const cutlass::half_t*>(${qkv}->data),
-    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads},
-    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
+    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_q_heads},
+    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads}),
     static_cast<cutlass::half_t*>(out0->data),
     ${num_batches},
     ${num_queries},
     ${num_keys},
-    ${num_heads},
-    ${num_heads},
+    ${num_q_heads},
+    ${num_kv_heads},
     ${head_dim},
     q_batch_stride,
     k_batch_stride,
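The packed-QKV offsets in the hunk above can be checked with a small NumPy sketch (illustration only, not part of the commit, and with hypothetical sizes). It assumes one token's packed row is laid out as [Q | K | V], so the K pointer starts head_dim * num_q_heads elements in and the V pointer head_dim * (num_q_heads + num_kv_heads) elements in, matching the offsets passed to flash_attention_forward.

import numpy as np

# Hypothetical GQA sizes: 32 query heads, 8 KV heads, head_dim 128.
num_q_heads, num_kv_heads, head_dim = 32, 8, 128
row_stride = head_dim * (num_q_heads + 2 * num_kv_heads)  # one packed QKV row

qkv = np.arange(row_stride, dtype=np.float32)       # stand-in for one (batch, token) row
k_off = head_dim * num_q_heads                       # K base offset used in the template
v_off = head_dim * (num_q_heads + num_kv_heads)      # V base offset used in the template

q = qkv[:k_off].reshape(num_q_heads, head_dim)
k = qkv[k_off:v_off].reshape(num_kv_heads, head_dim)
v = qkv[v_off:].reshape(num_kv_heads, head_dim)
assert q.size + k.size + v.size == qkv.size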

python/tvm/contrib/cutlass/build.py

Lines changed: 4 additions & 3 deletions
@@ -909,8 +909,8 @@ def handle_attention(self, f, op_type):
 
 out_shape = signature["ret_shape"]
 out_dtype = signature["ret_dtype"]
-num_batches, num_queries, num_heads, head_dim = q_shape
-_, num_keys, _, _ = k_shape
+num_batches, num_queries, num_q_heads, head_dim = q_shape
+_, num_keys, num_kv_heads, _ = k_shape
 _, _, _, head_dim_value = v_shape
 scale = op_attrs.scale
 
@@ -931,7 +931,8 @@ def handle_attention(self, f, op_type):
 "num_batches": num_batches,
 "num_queries": num_queries,
 "num_keys": num_keys,
-"num_heads": num_heads,
+"num_q_heads": num_q_heads,
+"num_kv_heads": num_kv_heads,
 "head_dim": head_dim,
 "head_dim_value": head_dim_value,
 "scale": scale,

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 8 additions & 1 deletion
@@ -745,7 +745,6 @@ def get_batch_on_arg(arg_name, arg_shape):
 
 attrs["data_type"] = DataTypeTag[data_type]
 attrs["num_batches"] = b = annotations["num_batches"]
-attrs["num_heads"] = n = annotations["num_heads"]
 attrs["head_dim"] = h = annotations["head_dim"]
 attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
 attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))
@@ -766,13 +765,21 @@ def get_batch_on_arg(arg_name, arg_shape):
     and int(annotations["arch"]) >= 80
 )
 
+print(int(attrs["head_dim"]) <= 256, int(attrs["head_dim"]) % 8 == 0, int(attrs["head_dim"]) == int(attrs["head_dim_value"]),int(annotations["arch"]) >= 80, annotations["ret_dtype"] == "float16", "bias" not in attrs, int(annotations["arch"]) >= 80)
+
+
 if use_flash:
     headers.append("flash.h")
     attrs["is_causal"] = int(annotations["custom_mask_type"]) > 0
+    attrs["num_q_heads"] = annotations["num_q_heads"]
+    attrs["num_kv_heads"] = annotations["num_kv_heads"]
     code = instantiate_flash_attention_template(attrs)
 else:
     headers.append("kernel_forward.h")
 
+    assert annotations["num_q_heads"] == annotations["num_kv_heads"]
+    attrs["num_heads"] = n = annotations["num_q_heads"]
+
     data_type_size = DataTypeSize[data_type]
     if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
         attrs["kIsAligned"] = True

tests/python/relax/test_codegen_cutlass.py

Lines changed: 3 additions & 2 deletions
@@ -1968,7 +1968,7 @@ def rewrite_attention(f):
 
     def callback(_, matchings):
         return R.nn.attention(
-            matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight"
+            matchings[Q], matchings[K], matchings[V],
         )
 
     return rewrite_call(pattern, callback, f)
@@ -2007,7 +2007,8 @@ def main(
 
     Module["main"] = rewrite_attention(Module["main"])
     mod = partition_for_cutlass(Module)
-    print(mod)
+    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
+    print(codegen_pass(mod))
 
 if __name__ == "__main__":
     # tvm.testing.main()
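For local debugging outside the test, the same codegen pass can be applied to any partitioned module. A minimal sketch, assuming mod was already produced by partition_for_cutlass and an sm80 GPU is targeted:

from tvm import relax

codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
compiled_mod = codegen_pass(mod)  # run CUTLASS BYOC codegen on the offloaded functions
print(compiled_mod)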
