Skip to content

Commit ae657b7

Browse files
committed
wip
1 parent ddcab38 commit ae657b7

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/relax/op/nn/attention.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
 };
 diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size");
 diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size");
-diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads");
-diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads");
+// diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads");
+// diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads");
 diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length");
 diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads");

tests/python/relax/test_codegen_cutlass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2006,7 +2006,8 @@ def main(
 return lv6_1

 Module["main"] = rewrite_attention(Module["main"])
-print(Module)
+mod = partition_for_cutlass(Module)
+print(mod)

 if __name__ == "__main__":
 # tvm.testing.main()

0 commit comments

Comments
 (0)