From a54a0ebebee705e496bab4efe5eae9ff22203f11 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Sep 2023 09:58:19 +0900 Subject: [PATCH 1/7] Squashed commit of the following: commit 99c2a59a1226f372c50c347c961d0c1201680a3e Author: Masahiro Masuda Date: Wed Sep 27 09:57:21 2023 +0900 Revert "Revert "[Unity] Avoid trivial `var2 = var1` bindings in pattern matcher (#15578)"" This reverts commit 0a6a617e1315f3bc1550e5dc0e4630495e7fe70d. commit 9a3ca64cfa2152628f5704a68383d36949403900 Author: Masahiro Masuda Date: Wed Sep 27 09:55:02 2023 +0900 wip commit be01900d59db94bbccc3d8142d95c302dade7ca2 Author: Masahiro Masuda Date: Tue Sep 26 19:55:29 2023 +0900 fix test commit a026b650002b07833808d078ede41243796f9a95 Author: Masahiro Masuda Date: Thu Aug 31 22:24:38 2023 +0000 wip commit 233d2d0fa7bb1a981f792645e8394d95e8d31cb4 Author: Masahiro Masuda Date: Tue Aug 29 17:42:11 2023 +0000 wip commit 0a6a617e1315f3bc1550e5dc0e4630495e7fe70d Author: Masahiro Masuda Date: Tue Aug 29 17:28:25 2023 +0000 Revert "[Unity] Avoid trivial `var2 = var1` bindings in pattern matcher (#15578)" This reverts commit 567848e3a08a3bcb1ed69344050bc648a101d9b9. commit 6c5a4355e4cd487434c598909b7338159161624e Author: Masahiro Masuda Date: Tue Aug 29 17:28:16 2023 +0000 wip commit 7926cbc9d890c0c07376176a26144b8603bb9732 Author: Masahiro Masuda Date: Mon Aug 28 06:17:01 2023 +0000 wip commit 9828698ca3d808da8a77a432686c8be5dd4dab38 Author: Masahiro Masuda Date: Mon Aug 28 15:11:47 2023 +0900 wip commit 5d01fd1310fd5df98bf5fd56986056e20352ad3d Author: Masahiro Masuda Date: Mon Aug 28 06:05:56 2023 +0000 wip commit ae657b7aed678fa7f7727aebcc9221940e45de26 Author: Masahiro Masuda Date: Mon Aug 28 14:49:21 2023 +0900 wip commit ddcab3887fef5c851689714f2f3924201165591d Author: Masahiro Masuda Date: Mon Aug 28 05:42:41 2023 +0000 wip commit ab3572d852e21af3d4b349afd999654f491dcee8 Author: Masahiro Masuda Date: Mon Aug 28 10:40:34 2023 +0900 wip commit 690b88ef2380fc3ab5e3e02fed61cdf2936e0811 Author: Masahiro Masuda Date: Mon Aug 28 10:25:33 2023 +0900 update rev --- 3rdparty/libflash_attn | 2 +- .../contrib/cutlass/attention_operation.py | 28 ++++---- python/tvm/contrib/cutlass/build.py | 10 ++- python/tvm/contrib/cutlass/gen_tensor_op.py | 15 +++-- python/tvm/relax/backend/contrib/cutlass.py | 7 +- python/tvm/relax/backend/patterns.py | 11 ++-- src/relax/op/nn/attention.cc | 4 +- tests/python/relax/test_codegen_cutlass.py | 64 +++++++++++++++++-- 8 files changed, 105 insertions(+), 36 deletions(-) diff --git a/3rdparty/libflash_attn b/3rdparty/libflash_attn index 58b343e57571..aaf92298f1a8 160000 --- a/3rdparty/libflash_attn +++ b/3rdparty/libflash_attn @@ -1 +1 @@ -Subproject commit 58b343e57571fe5e0a5b43b5eb721acef8b35dff +Subproject commit aaf92298f1a8d60f33824e059e70b6e72c73e0ff diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index 67a68df442f8..e59dbf032e6a 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -169,10 +169,10 @@ def instantiate_flash_attention_template(attrs): int k_head_stride = ${head_dim}; int v_head_stride = ${head_dim}; int o_head_stride = ${head_dim}; - int q_row_stride = q_head_stride * ${num_heads}; - int k_row_stride = k_head_stride * ${num_heads}; - int v_row_stride = v_head_stride * ${num_heads}; - int o_row_stride = o_head_stride * ${num_heads}; + int q_row_stride = q_head_stride * ${num_q_heads}; + int k_row_stride = k_head_stride * ${num_kv_heads}; + 
int v_row_stride = v_head_stride * ${num_kv_heads}; + int o_row_stride = o_head_stride * ${num_q_heads}; int q_batch_stride = q_row_stride * ${num_queries}; int k_batch_stride = k_row_stride * ${num_keys}; int v_batch_stride = v_row_stride * ${num_keys}; @@ -190,8 +190,8 @@ def instantiate_flash_attention_template(attrs): ${num_batches}, ${num_queries}, ${num_keys}, - ${num_heads}, - ${num_heads}, + ${num_q_heads}, + ${num_kv_heads}, ${head_dim}, q_batch_stride, k_batch_stride, @@ -215,13 +215,13 @@ def instantiate_flash_attention_template(attrs): int k_head_stride = ${head_dim}; int v_head_stride = ${head_dim}; int o_head_stride = ${head_dim}; - int row_stride = q_head_stride * ${num_heads} + - k_head_stride * ${num_heads} + - v_head_stride * ${num_heads}; + int row_stride = q_head_stride * ${num_q_heads} + + k_head_stride * ${num_kv_heads} + + v_head_stride * ${num_kv_heads}; int q_row_stride = row_stride; int k_row_stride = row_stride; int v_row_stride = row_stride; - int o_row_stride = o_head_stride * ${num_heads}; + int o_row_stride = o_head_stride * ${num_q_heads}; int q_batch_stride = q_row_stride * ${num_queries}; int k_batch_stride = k_row_stride * ${num_keys}; @@ -234,14 +234,14 @@ def instantiate_flash_attention_template(attrs): flash_attn::flash_attention_forward( static_cast(${qkv}->data), - static_cast(${qkv}->data) + ${head_dim} * ${num_heads}, - static_cast(${qkv}->data) + ${head_dim} * ${num_heads} * 2, + static_cast(${qkv}->data) + ${head_dim} * ${num_q_heads}, + static_cast(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads}), static_cast(out0->data), ${num_batches}, ${num_queries}, ${num_keys}, - ${num_heads}, - ${num_heads}, + ${num_q_heads}, + ${num_kv_heads}, ${head_dim}, q_batch_stride, k_batch_stride, diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 0c57c4750e87..9267c98cd29e 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -909,8 +909,8 @@ def handle_attention(self, f, op_type): out_shape = signature["ret_shape"] out_dtype = signature["ret_dtype"] - num_batches, num_queries, num_heads, head_dim = q_shape - _, num_keys, _, _ = k_shape + num_batches, num_queries, num_q_heads, head_dim = q_shape + _, num_keys, num_kv_heads, _ = k_shape _, _, _, head_dim_value = v_shape scale = op_attrs.scale @@ -931,13 +931,15 @@ def handle_attention(self, f, op_type): "num_batches": num_batches, "num_queries": num_queries, "num_keys": num_keys, - "num_heads": num_heads, + "num_q_heads": num_q_heads, + "num_kv_heads": num_kv_heads, "head_dim": head_dim, "head_dim_value": head_dim_value, "scale": scale, "arch": self.options["sm"], "qkv_layout": qkv_layout, "custom_mask_type": custom_mask_type, + "disable_flash": self.options.get("disable_flash", False), **arg, } ) @@ -982,6 +984,8 @@ def profile_relax_function(functions, options): """Tune and annotate CUTLASS composite functions with shape, dtype and generated templates.""" tmp_dir = options.get("tmp_dir", "./tmp") sm = options.get("sm", 80) + print(options) + conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir) gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 58bc91863dcc..727235f78d2e 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -745,7 +745,6 @@ def get_batch_on_arg(arg_name, arg_shape): attrs["data_type"] = 
DataTypeTag[data_type] attrs["num_batches"] = b = annotations["num_batches"] - attrs["num_heads"] = n = annotations["num_heads"] attrs["head_dim"] = h = annotations["head_dim"] attrs["head_dim_value"] = h_v = annotations["head_dim_value"] attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"])) @@ -754,25 +753,29 @@ def get_batch_on_arg(arg_name, arg_shape): ) use_flash = ( - annotations["ret_dtype"] == "float16" + not annotations["disable_flash"] + and annotations["ret_dtype"] == "float16" and "bias" not in attrs and int(attrs["head_dim"]) <= 256 and int(attrs["head_dim"]) % 8 == 0 and int(attrs["head_dim"]) == int(attrs["head_dim_value"]) - # We have not thoroughly validated flash with causal mask yet, so for now we support - # only non-causal cases. - and int(annotations["custom_mask_type"]) == 0 + and int(annotations["custom_mask_type"]) in (0, 2) # Flash v2 is currently not supported for sm < 80 and int(annotations["arch"]) >= 80 ) if use_flash: headers.append("flash.h") - attrs["is_causal"] = int(annotations["custom_mask_type"]) > 0 + attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2 + attrs["num_q_heads"] = annotations["num_q_heads"] + attrs["num_kv_heads"] = annotations["num_kv_heads"] code = instantiate_flash_attention_template(attrs) else: headers.append("kernel_forward.h") + assert annotations["num_q_heads"] == annotations["num_kv_heads"] + attrs["num_heads"] = n = annotations["num_q_heads"] + data_type_size = DataTypeSize[data_type] if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0: attrs["kIsAligned"] = True diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index fef6a1ec03c4..51594d41b6a9 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -576,7 +576,7 @@ def annotate_workspace(mod, _): return mod -def partition_for_cutlass(mod, annotate_codegen=True): +def partition_for_cutlass(mod, annotate_codegen=True, use_flash_attn=True): """ Partition the input module into CUTLASS-supported subgraphs. @@ -598,8 +598,13 @@ def partition_for_cutlass(mod, annotate_codegen=True): """ for func_name, func in mod.functions.items(): if isinstance(func, Function): + if use_flash_attn: + mqa_pattern, rewriter = make_attention_rewrite_pattern("BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True) + func = rewrite_call(mqa_pattern, rewriter, func) + for pattern, rewriter in _REWRITE_PATTERNS: func = rewrite_call(pattern, rewriter, func) + mod[func_name] = func patterns = get_patterns_with_prefix("cutlass") diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index 24edd0e7c950..b1a701dd91bd 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -318,7 +318,7 @@ def make_rms_norm_pattern(): def make_attention_rewrite_pattern( - qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool + qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False ): """ Create pattern for implicit fused multi head attention rewriting. 
@@ -350,7 +350,10 @@ def make_attention_rewrite_pattern( """ # pylint: disable=invalid-name - def handle_input(tensor, layout, transpose): + def handle_input(tensor, layout, transpose, repeat=False): + if repeat: + tensor = is_op("relax.repeat")(tensor) + if layout == "BSNH": permuted = is_op("relax.permute_dims")(tensor) shape = wildcard() @@ -434,8 +437,8 @@ def rewriter(matchings, x): q_raw, k_raw, v_raw = wildcard(), wildcard(), wildcard() q, q_rewriter = handle_input(q_raw, qkv_layout, False) - k, k_rewriter = handle_input(k_raw, qkv_layout, True) - v, v_rewriter = handle_input(v_raw, qkv_layout, False) + k, k_rewriter = handle_input(k_raw, qkv_layout, True, repeat=with_kv_repeat) + v, v_rewriter = handle_input(v_raw, qkv_layout, False, repeat=with_kv_repeat) matmul_1 = is_op("relax.matmul")(q, k) scale = is_const() diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index 4f37e3a33c29..e97e0df25b26 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -79,8 +79,8 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { }; diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size"); diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size"); - diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads"); - diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads"); + // diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads"); + // diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads"); diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length"); diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads"); diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index e8d4e83521b0..d298063732f3 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -113,7 +113,7 @@ def get_result_with_relax_cutlass_offload( if assert_all_bindings_fused: assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings - codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}}) + codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True, "disable_flash": True}}) mod = codegen_pass(mod) return build_and_run(mod, args, "cuda") @@ -746,7 +746,7 @@ def attention_causal(request): def test_attention_causal_offload(attention_causal_size, attention_causal): b, (s, s_kv), n, (h, h_v), bias_shape = attention_causal_size q, k, v, bias, ref = get_numpy_attention_ref( - b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float32" + b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float16" ) q_shape = (b, s, n, h) @@ -757,10 +757,11 @@ def test_attention_causal_offload(attention_causal_size, attention_causal): q_shape, k_shape, v_shape, - dtype="float32", + dtype="float16", bias_shape=bias_shape, causal_mask=attention_causal, ) + if bias is None: out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3) else: @@ -1930,7 +1931,10 @@ def main( ex = relax.build(mod_transform, target="llvm") vm = relax.vm.VirtualMachine(ex, tvm.cpu(0)) - (packed_weight, scales,) = vm[ + ( + packed_weight, + scales, + ) = vm[ transform_func_name ]((tvm.nd.array(y),)) @@ -1945,5 +1949,55 @@ def main( tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) +def test_attention_rewrite_multi_query(): + 
@I.ir_module + class Module: + @R.function + def main( + q: R.Tensor((4, 16, 32, 16), dtype="float16"), + k_single: R.Tensor((4, 16, 1, 16), dtype="float16"), + v_single: R.Tensor((4, 16, 1, 16), dtype="float16"), + ) -> R.Tensor((4, 16, 32, 8), dtype="float16"): + with R.dataflow(): + k = R.repeat(k_single, 32, axis=2) + v = R.repeat(v_single, 32, axis=2) + + lv = R.permute_dims(q, axes=[0, 2, 1, 3]) + lv1 = R.reshape(lv, R.shape([128, 16, 16])) + lv2 = R.permute_dims(k, axes=[0, 2, 1, 3]) + lv3 = R.reshape(lv2, R.shape([128, 16, 16])) + lv4 = R.permute_dims(v, axes=[0, 2, 1, 3]) + lv5 = R.reshape(lv4, R.shape([128, 16, 16])) + + lv6 = R.permute_dims(lv3, axes=[0, 2, 1]) + lv7 = R.matmul(lv1, lv6, out_dtype="float16") + lv3_1 = R.astype(R.const(0.25, "float32"), "float16") + lv8 = R.multiply(lv7, lv3_1) + lv11 = R.astype(R.nn.softmax(R.astype(lv8, "float32"), axis=2), "float16") + lv12 = R.matmul(lv11, lv5, out_dtype="float16") + lv13 = R.reshape(lv12, R.shape([4, 32, 16, 16])) + lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3]) + R.output(lv6_1) + return lv6_1 + + q_np = np.random.randn(4, 16, 32, 16).astype("float16") + k_np = np.random.randn(4, 16, 1, 16).astype("float16") + v_np = np.random.randn(4, 16, 1, 16).astype("float16") + args = [q_np, k_np, v_np] + ref = build_and_run(Module, args, "llvm", legalize=True) + + mod = partition_for_cutlass(Module, use_flash_attn=True) + codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True, "disable_flash": False}}) + mod = codegen_pass(mod) + + out = build_and_run(mod, args, "cuda") + + print(np.max(np.abs(out - ref)), np.mean(np.abs(out - ref))) + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_attention_rewrite_multi_query() + # test_attention_offload((4, (16, 16), 32, (8, 8)), "float16") + # test_attention_causal_offload((1, (1, 8), 4, (16, 16), "none"), "BottomRight") From cea294e2a05e82af8bb1468e7bca3c88ecf8d961 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Sep 2023 10:01:40 +0900 Subject: [PATCH 2/7] black --- python/tvm/relax/backend/contrib/cutlass.py | 4 +++- python/tvm/relax/backend/patterns.py | 2 +- tests/python/relax/test_codegen_cutlass.py | 8 ++++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 51594d41b6a9..459e3256192a 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -599,7 +599,9 @@ def partition_for_cutlass(mod, annotate_codegen=True, use_flash_attn=True): for func_name, func in mod.functions.items(): if isinstance(func, Function): if use_flash_attn: - mqa_pattern, rewriter = make_attention_rewrite_pattern("BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True) + mqa_pattern, rewriter = make_attention_rewrite_pattern( + "BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True + ) func = rewrite_call(mqa_pattern, rewriter, func) for pattern, rewriter in _REWRITE_PATTERNS: diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index b1a701dd91bd..5fc32b2ba045 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -318,7 +318,7 @@ def make_rms_norm_pattern(): def make_attention_rewrite_pattern( - qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False + qkv_layout: str, 
out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False ): """ Create pattern for implicit fused multi head attention rewriting. diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index d298063732f3..1c69d8a652c2 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -113,7 +113,9 @@ def get_result_with_relax_cutlass_offload( if assert_all_bindings_fused: assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings - codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True, "disable_flash": True}}) + codegen_pass = relax.transform.RunCodegen( + {"cutlass": {"sm": 80, "find_first_valid": True, "disable_flash": True}} + ) mod = codegen_pass(mod) return build_and_run(mod, args, "cuda") @@ -1987,7 +1989,9 @@ def main( ref = build_and_run(Module, args, "llvm", legalize=True) mod = partition_for_cutlass(Module, use_flash_attn=True) - codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True, "disable_flash": False}}) + codegen_pass = relax.transform.RunCodegen( + {"cutlass": {"sm": 80, "find_first_valid": True, "disable_flash": False}} + ) mod = codegen_pass(mod) out = build_and_run(mod, args, "cuda") From 3f659b73c861b95bd2d705cca75aa6e539985bf1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Sep 2023 10:17:17 +0900 Subject: [PATCH 3/7] clean --- python/tvm/contrib/cutlass/build.py | 2 -- python/tvm/contrib/cutlass/gen_tensor_op.py | 11 ++++++++--- python/tvm/relax/backend/contrib/cutlass.py | 4 ++-- tests/python/relax/test_codegen_cutlass.py | 13 +++---------- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 9267c98cd29e..9f5d50a0e30d 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -939,7 +939,6 @@ def handle_attention(self, f, op_type): "arch": self.options["sm"], "qkv_layout": qkv_layout, "custom_mask_type": custom_mask_type, - "disable_flash": self.options.get("disable_flash", False), **arg, } ) @@ -984,7 +983,6 @@ def profile_relax_function(functions, options): """Tune and annotate CUTLASS composite functions with shape, dtype and generated templates.""" tmp_dir = options.get("tmp_dir", "./tmp") sm = options.get("sm", 80) - print(options) conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir) gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 727235f78d2e..2e58d60035be 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -753,13 +753,18 @@ def get_batch_on_arg(arg_name, arg_shape): ) use_flash = ( - not annotations["disable_flash"] - and annotations["ret_dtype"] == "float16" + annotations["ret_dtype"] == "float16" and "bias" not in attrs and int(attrs["head_dim"]) <= 256 and int(attrs["head_dim"]) % 8 == 0 and int(attrs["head_dim"]) == int(attrs["head_dim_value"]) - and int(annotations["custom_mask_type"]) in (0, 2) + # For the causal case (custom mask = "BottomRight"), only use flash for multi-query + # attention workloads (indicated by the "repeat" op in the pattern). + # Otherwise, CUTLASS fMHA seems faster for causal attention with a single query. 
+ and ( + int(annotations["custom_mask_type"]) == 0 + or (int(annotations["custom_mask_type"]) == 2 and "repeat" in func_name) + ) # Flash v2 is currently not supported for sm < 80 and int(annotations["arch"]) >= 80 ) diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 459e3256192a..b6585d9706bf 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -576,7 +576,7 @@ def annotate_workspace(mod, _): return mod -def partition_for_cutlass(mod, annotate_codegen=True, use_flash_attn=True): +def partition_for_cutlass(mod, annotate_codegen=True, use_flash_mqa=True): """ Partition the input module into CUTLASS-supported subgraphs. @@ -598,7 +598,7 @@ def partition_for_cutlass(mod, annotate_codegen=True, use_flash_attn=True): """ for func_name, func in mod.functions.items(): if isinstance(func, Function): - if use_flash_attn: + if use_flash_mqa: mqa_pattern, rewriter = make_attention_rewrite_pattern( "BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True ) diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 1c69d8a652c2..bd40d0738b34 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -113,9 +113,7 @@ def get_result_with_relax_cutlass_offload( if assert_all_bindings_fused: assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings - codegen_pass = relax.transform.RunCodegen( - {"cutlass": {"sm": 80, "find_first_valid": True, "disable_flash": True}} - ) + codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}}) mod = codegen_pass(mod) return build_and_run(mod, args, "cuda") @@ -1988,20 +1986,15 @@ def main( args = [q_np, k_np, v_np] ref = build_and_run(Module, args, "llvm", legalize=True) - mod = partition_for_cutlass(Module, use_flash_attn=True) - codegen_pass = relax.transform.RunCodegen( - {"cutlass": {"sm": 80, "find_first_valid": True, "disable_flash": False}} - ) + mod = partition_for_cutlass(Module, use_flash_mqa=True) + codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80}}) mod = codegen_pass(mod) out = build_and_run(mod, args, "cuda") - print(np.max(np.abs(out - ref)), np.mean(np.abs(out - ref))) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) if __name__ == "__main__": # tvm.testing.main() test_attention_rewrite_multi_query() - # test_attention_offload((4, (16, 16), 32, (8, 8)), "float16") - # test_attention_causal_offload((1, (1, 8), 4, (16, 16), "none"), "BottomRight") From 7f010eab54b8e1e598201be09209b81564cb1920 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Sep 2023 10:29:42 +0900 Subject: [PATCH 4/7] add doc --- python/tvm/contrib/cutlass/build.py | 1 - python/tvm/contrib/cutlass/gen_tensor_op.py | 5 ++++- python/tvm/relax/backend/contrib/cutlass.py | 4 ++++ python/tvm/relax/backend/patterns.py | 4 ++++ src/relax/op/nn/attention.cc | 13 +++++++++++-- 5 files changed, 23 insertions(+), 4 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 9f5d50a0e30d..b97fc20008b4 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -983,7 +983,6 @@ def profile_relax_function(functions, options): """Tune and annotate CUTLASS composite functions with shape, dtype and generated templates.""" tmp_dir = options.get("tmp_dir", "./tmp") sm = options.get("sm", 80) - conv2d_profiler = 
CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir) gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 2e58d60035be..7f6141b1c839 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -778,7 +778,10 @@ def get_batch_on_arg(arg_name, arg_shape): else: headers.append("kernel_forward.h") - assert annotations["num_q_heads"] == annotations["num_kv_heads"] + assert ( + annotations["num_q_heads"] == annotations["num_kv_heads"] + ), "The number of query and KV heads need to be the same for CUTLASS fMHA." + attrs["num_heads"] = n = annotations["num_q_heads"] data_type_size = DataTypeSize[data_type] diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index b6585d9706bf..9efea3a0dccf 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -590,6 +590,10 @@ def partition_for_cutlass(mod, annotate_codegen=True, use_flash_mqa=True): body consists only of a call to the composite function. See the doc of FuseOpsByPattern for more detail. + use_flash_mqa: bool + Whether to consider a rewrite pattern for multi-query attention, which is supported by + the Flash Attention kernel. + Returns ------- mod: tvm.IRModule diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index 5fc32b2ba045..10a075647b5a 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -338,6 +338,10 @@ def make_attention_rewrite_pattern( Whether or not rewriting is intended to be applied to a module after the FP16 conversion pass. + with_kv_repeat: bool + Whether or not to include the Relax repeat op in the pattern, which is typically used + in a Relax module to support multi-query attention. + Returns ------- pattern: DFPattern diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index e97e0df25b26..2b89fb0faadf 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -77,10 +77,19 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { << v1 << " while the " << dim << " of " << m2 << " is " << v2); } }; + auto multiple_of = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) { + if (analyzer->CanProve(indexmod(v1, v2) != 0)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The " << m1 << " " << dim << " should be an multiple of " << m2 << " " + << dim << ". 
However, the " << dim << " of " << m1 << " is " << v1 + << " while the " << dim << " of " << m2 << " is " << v2); + } + }; + diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size"); diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size"); - // diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads"); - // diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads"); + multiple_of(num_heads, k_shape->values[2], "query", "key", "number of heads"); + multiple_of(num_heads, v_shape->values[2], "query", "value", "number of heads"); diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length"); diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads"); From 60d29dc92311ab9dcaa91c88630ebb5fd0fb266b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Sep 2023 15:08:00 +0900 Subject: [PATCH 5/7] update rev --- 3rdparty/libflash_attn | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/libflash_attn b/3rdparty/libflash_attn index aaf92298f1a8..63cce0ca8fa6 160000 --- a/3rdparty/libflash_attn +++ b/3rdparty/libflash_attn @@ -1 +1 @@ -Subproject commit aaf92298f1a8d60f33824e059e70b6e72c73e0ff +Subproject commit 63cce0ca8fa6bfca1982b342588273641cc5b86b From 91c0e17766f209e03e885535a91a312906a0536d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Sep 2023 15:09:27 +0900 Subject: [PATCH 6/7] update test --- src/relax/op/nn/attention.cc | 2 +- tests/python/relax/test_codegen_cutlass.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index 2b89fb0faadf..484137fecc40 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -80,7 +80,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { auto multiple_of = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) { if (analyzer->CanProve(indexmod(v1, v2) != 0)) { ctx->ReportFatal(Diagnostic::Error(call) - << "The " << m1 << " " << dim << " should be an multiple of " << m2 << " " + << "The " << m1 << " " << dim << " should be a multiple of " << m2 << " " << dim << ". 
However, the " << dim << " of " << m1 << " is " << v1 << " while the " << dim << " of " << m2 << " is " << v2); } diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index bd40d0738b34..83936ef9c99f 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -1931,10 +1931,7 @@ def main( ex = relax.build(mod_transform, target="llvm") vm = relax.vm.VirtualMachine(ex, tvm.cpu(0)) - ( - packed_weight, - scales, - ) = vm[ + (packed_weight, scales,) = vm[ transform_func_name ]((tvm.nd.array(y),)) @@ -1996,5 +1993,4 @@ def main( if __name__ == "__main__": - # tvm.testing.main() - test_attention_rewrite_multi_query() + tvm.testing.main() From 5baa451f17761ab348cb5113f015d732808c9321 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Sep 2023 19:25:35 +0000 Subject: [PATCH 7/7] fix --- python/tvm/contrib/cutlass/gen_tensor_op.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 7f6141b1c839..62e64549c2ae 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -752,6 +752,8 @@ def get_batch_on_arg(arg_name, arg_shape): float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"] ) + is_mqa = annotations["num_q_heads"] != annotations["num_kv_heads"] + use_flash = ( annotations["ret_dtype"] == "float16" and "bias" not in attrs @@ -759,11 +761,11 @@ def get_batch_on_arg(arg_name, arg_shape): and int(attrs["head_dim"]) % 8 == 0 and int(attrs["head_dim"]) == int(attrs["head_dim_value"]) # For the causal case (custom mask = "BottomRight"), only use flash for multi-query - # attention workloads (indicated by the "repeat" op in the pattern). - # Otherwise, CUTLASS fMHA seems faster for causal attention with a single query. + # attention workloads. Otherwise, CUTLASS fMHA seems faster for causal attention + # with a single query. and ( int(annotations["custom_mask_type"]) == 0 - or (int(annotations["custom_mask_type"]) == 2 and "repeat" in func_name) + or (int(annotations["custom_mask_type"]) == 2 and is_mqa) ) # Flash v2 is currently not supported for sm < 80 and int(annotations["arch"]) >= 80 @@ -779,7 +781,7 @@ def get_batch_on_arg(arg_name, arg_shape): headers.append("kernel_forward.h") assert ( - annotations["num_q_heads"] == annotations["num_kv_heads"] + not is_mqa ), "The number of query and KV heads need to be the same for CUTLASS fMHA." attrs["num_heads"] = n = annotations["num_q_heads"]