
Commit a54a0eb

Squashed commit of the following:
commit 99c2a59  Author: Masahiro Masuda <[email protected]>  Date: Wed Sep 27 09:57:21 2023 +0900
    Revert "Revert "[Unity] Avoid trivial `var2 = var1` bindings in pattern matcher (apache#15578)""
    This reverts commit 0a6a617.

commit 9a3ca64  Author: Masahiro Masuda <[email protected]>  Date: Wed Sep 27 09:55:02 2023 +0900
    wip

commit be01900  Author: Masahiro Masuda <[email protected]>  Date: Tue Sep 26 19:55:29 2023 +0900
    fix test

commit a026b65  Author: Masahiro Masuda <[email protected]>  Date: Thu Aug 31 22:24:38 2023 +0000
    wip

commit 233d2d0  Author: Masahiro Masuda <[email protected]>  Date: Tue Aug 29 17:42:11 2023 +0000
    wip

commit 0a6a617  Author: Masahiro Masuda <[email protected]>  Date: Tue Aug 29 17:28:25 2023 +0000
    Revert "[Unity] Avoid trivial `var2 = var1` bindings in pattern matcher (apache#15578)"
    This reverts commit 567848e.

commit 6c5a435  Author: Masahiro Masuda <[email protected]>  Date: Tue Aug 29 17:28:16 2023 +0000
    wip

commit 7926cbc  Author: Masahiro Masuda <[email protected]>  Date: Mon Aug 28 06:17:01 2023 +0000
    wip

commit 9828698  Author: Masahiro Masuda <[email protected]>  Date: Mon Aug 28 15:11:47 2023 +0900
    wip

commit 5d01fd1  Author: Masahiro Masuda <[email protected]>  Date: Mon Aug 28 06:05:56 2023 +0000
    wip

commit ae657b7  Author: Masahiro Masuda <[email protected]>  Date: Mon Aug 28 14:49:21 2023 +0900
    wip

commit ddcab38  Author: Masahiro Masuda <[email protected]>  Date: Mon Aug 28 05:42:41 2023 +0000
    wip

commit ab3572d  Author: Masahiro Masuda <[email protected]>  Date: Mon Aug 28 10:40:34 2023 +0900
    wip

commit 690b88e  Author: Masahiro Masuda <[email protected]>  Date: Mon Aug 28 10:25:33 2023 +0900
    update rev
1 parent dfc77eb commit a54a0eb

File tree (8 files changed: +105, -36 lines)

3rdparty/libflash_attn
python/tvm/contrib/cutlass/attention_operation.py
python/tvm/contrib/cutlass/build.py
python/tvm/contrib/cutlass/gen_tensor_op.py
python/tvm/relax/backend/contrib/cutlass.py
python/tvm/relax/backend/patterns.py
src/relax/op/nn/attention.cc
tests/python/relax/test_codegen_cutlass.py


3rdparty/libflash_attn

python/tvm/contrib/cutlass/attention_operation.py

Lines changed: 14 additions & 14 deletions
@@ -169,10 +169,10 @@ def instantiate_flash_attention_template(attrs):
 int k_head_stride = ${head_dim};
 int v_head_stride = ${head_dim};
 int o_head_stride = ${head_dim};
-int q_row_stride = q_head_stride * ${num_heads};
-int k_row_stride = k_head_stride * ${num_heads};
-int v_row_stride = v_head_stride * ${num_heads};
-int o_row_stride = o_head_stride * ${num_heads};
+int q_row_stride = q_head_stride * ${num_q_heads};
+int k_row_stride = k_head_stride * ${num_kv_heads};
+int v_row_stride = v_head_stride * ${num_kv_heads};
+int o_row_stride = o_head_stride * ${num_q_heads};
 int q_batch_stride = q_row_stride * ${num_queries};
 int k_batch_stride = k_row_stride * ${num_keys};
 int v_batch_stride = v_row_stride * ${num_keys};
@@ -190,8 +190,8 @@ def instantiate_flash_attention_template(attrs):
     ${num_batches},
     ${num_queries},
     ${num_keys},
-    ${num_heads},
-    ${num_heads},
+    ${num_q_heads},
+    ${num_kv_heads},
     ${head_dim},
     q_batch_stride,
     k_batch_stride,
@@ -215,13 +215,13 @@ def instantiate_flash_attention_template(attrs):
 int k_head_stride = ${head_dim};
 int v_head_stride = ${head_dim};
 int o_head_stride = ${head_dim};
-int row_stride = q_head_stride * ${num_heads} +
-                 k_head_stride * ${num_heads} +
-                 v_head_stride * ${num_heads};
+int row_stride = q_head_stride * ${num_q_heads} +
+                 k_head_stride * ${num_kv_heads} +
+                 v_head_stride * ${num_kv_heads};
 int q_row_stride = row_stride;
 int k_row_stride = row_stride;
 int v_row_stride = row_stride;
-int o_row_stride = o_head_stride * ${num_heads};
+int o_row_stride = o_head_stride * ${num_q_heads};

 int q_batch_stride = q_row_stride * ${num_queries};
 int k_batch_stride = k_row_stride * ${num_keys};
@@ -234,14 +234,14 @@ def instantiate_flash_attention_template(attrs):

 flash_attn::flash_attention_forward(
     static_cast<const cutlass::half_t*>(${qkv}->data),
-    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads},
-    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
+    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_q_heads},
+    static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads}),
     static_cast<cutlass::half_t*>(out0->data),
     ${num_batches},
     ${num_queries},
     ${num_keys},
-    ${num_heads},
-    ${num_heads},
+    ${num_q_heads},
+    ${num_kv_heads},
     ${head_dim},
     q_batch_stride,
     k_batch_stride,
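
For context, the rewritten template distinguishes query heads from key/value heads when computing strides. A minimal standalone sketch of that arithmetic, using hypothetical example values (not taken from the commit):

    # Stride arithmetic for a BSNH layout with separate Q and KV head counts.
    # All numbers below are illustrative examples.
    head_dim = 16
    num_q_heads, num_kv_heads = 32, 1   # multi-query: one KV head shared by all Q heads
    num_queries = num_keys = 16

    q_head_stride = k_head_stride = v_head_stride = o_head_stride = head_dim
    q_row_stride = q_head_stride * num_q_heads    # Q and O rows span all query heads
    o_row_stride = o_head_stride * num_q_heads
    k_row_stride = k_head_stride * num_kv_heads   # K and V rows span only the KV heads
    v_row_stride = v_head_stride * num_kv_heads

    q_batch_stride = q_row_stride * num_queries
    k_batch_stride = k_row_stride * num_keys
    v_batch_stride = v_row_stride * num_keys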

python/tvm/contrib/cutlass/build.py

Lines changed: 7 additions & 3 deletions
@@ -909,8 +909,8 @@ def handle_attention(self, f, op_type):

         out_shape = signature["ret_shape"]
         out_dtype = signature["ret_dtype"]
-        num_batches, num_queries, num_heads, head_dim = q_shape
-        _, num_keys, _, _ = k_shape
+        num_batches, num_queries, num_q_heads, head_dim = q_shape
+        _, num_keys, num_kv_heads, _ = k_shape
         _, _, _, head_dim_value = v_shape
         scale = op_attrs.scale

@@ -931,13 +931,15 @@ def handle_attention(self, f, op_type):
                 "num_batches": num_batches,
                 "num_queries": num_queries,
                 "num_keys": num_keys,
-                "num_heads": num_heads,
+                "num_q_heads": num_q_heads,
+                "num_kv_heads": num_kv_heads,
                 "head_dim": head_dim,
                 "head_dim_value": head_dim_value,
                 "scale": scale,
                 "arch": self.options["sm"],
                 "qkv_layout": qkv_layout,
                 "custom_mask_type": custom_mask_type,
+                "disable_flash": self.options.get("disable_flash", False),
                 **arg,
             }
         )
@@ -982,6 +984,8 @@ def profile_relax_function(functions, options):
     """Tune and annotate CUTLASS composite functions with shape, dtype and generated templates."""
     tmp_dir = options.get("tmp_dir", "./tmp")
     sm = options.get("sm", 80)
+    print(options)
+
     conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir)
     gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir)
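
A minimal sketch of the shape unpacking this change introduces, assuming BSNH-layout inputs; the example shapes are hypothetical (they match the MQA test added later in this commit):

    # q: (num_batches, num_queries, num_q_heads, head_dim)
    # k, v: (num_batches, num_keys, num_kv_heads, head_dim)
    q_shape = (4, 16, 32, 16)
    k_shape = (4, 16, 1, 16)
    v_shape = (4, 16, 1, 16)

    num_batches, num_queries, num_q_heads, head_dim = q_shape
    _, num_keys, num_kv_heads, _ = k_shape
    _, _, _, head_dim_value = v_shape
    # For grouped/multi-query attention, num_kv_heads is typically expected to
    # divide num_q_heads; the flash path receives both counts separately.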

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 9 additions & 6 deletions
@@ -745,7 +745,6 @@ def get_batch_on_arg(arg_name, arg_shape):

         attrs["data_type"] = DataTypeTag[data_type]
         attrs["num_batches"] = b = annotations["num_batches"]
-        attrs["num_heads"] = n = annotations["num_heads"]
         attrs["head_dim"] = h = annotations["head_dim"]
         attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
         attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))
@@ -754,25 +753,29 @@ def get_batch_on_arg(arg_name, arg_shape):
         )

         use_flash = (
-            annotations["ret_dtype"] == "float16"
+            not annotations["disable_flash"]
+            and annotations["ret_dtype"] == "float16"
             and "bias" not in attrs
             and int(attrs["head_dim"]) <= 256
             and int(attrs["head_dim"]) % 8 == 0
             and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
-            # We have not thoroughly validated flash with causal mask yet, so for now we support
-            # only non-causal cases.
-            and int(annotations["custom_mask_type"]) == 0
+            and int(annotations["custom_mask_type"]) in (0, 2)
            # Flash v2 is currently not supported for sm < 80
             and int(annotations["arch"]) >= 80
         )

         if use_flash:
             headers.append("flash.h")
-            attrs["is_causal"] = int(annotations["custom_mask_type"]) > 0
+            attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2
+            attrs["num_q_heads"] = annotations["num_q_heads"]
+            attrs["num_kv_heads"] = annotations["num_kv_heads"]
             code = instantiate_flash_attention_template(attrs)
         else:
             headers.append("kernel_forward.h")

+            assert annotations["num_q_heads"] == annotations["num_kv_heads"]
+            attrs["num_heads"] = n = annotations["num_q_heads"]
+
             data_type_size = DataTypeSize[data_type]
             if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
                 attrs["kIsAligned"] = True
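
Restated outside the template machinery, the updated flash gating amounts to the predicate below (a sketch; `annotations` and `attrs` are plain dicts standing in for the real ones):

    def can_use_flash(annotations, attrs):
        # Flash is used only when it is not explicitly disabled, the output is
        # fp16, there is no bias, head_dim is a multiple of 8 and at most 256,
        # the Q/K and V head dims match, the mask type is either none (0) or
        # bottom-right causal (2), and the target is sm80 or newer.
        return (
            not annotations["disable_flash"]
            and annotations["ret_dtype"] == "float16"
            and "bias" not in attrs
            and int(attrs["head_dim"]) <= 256
            and int(attrs["head_dim"]) % 8 == 0
            and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
            and int(annotations["custom_mask_type"]) in (0, 2)
            and int(annotations["arch"]) >= 80
        )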

python/tvm/relax/backend/contrib/cutlass.py

Lines changed: 6 additions & 1 deletion
@@ -576,7 +576,7 @@ def annotate_workspace(mod, _):
     return mod


-def partition_for_cutlass(mod, annotate_codegen=True):
+def partition_for_cutlass(mod, annotate_codegen=True, use_flash_attn=True):
     """
     Partition the input module into CUTLASS-supported subgraphs.

@@ -598,8 +598,13 @@ def partition_for_cutlass(mod, annotate_codegen=True):
     """
     for func_name, func in mod.functions.items():
         if isinstance(func, Function):
+            if use_flash_attn:
+                mqa_pattern, rewriter = make_attention_rewrite_pattern("BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True)
+                func = rewrite_call(mqa_pattern, rewriter, func)
+
             for pattern, rewriter in _REWRITE_PATTERNS:
                 func = rewrite_call(pattern, rewriter, func)
+
             mod[func_name] = func

     patterns = get_patterns_with_prefix("cutlass")
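
Taken together with the codegen option added in build.py, the intended offload flow looks roughly like the helper below (a sketch following the new test in this commit; `offload_to_cutlass` is a hypothetical name):

    from tvm import relax
    from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

    def offload_to_cutlass(mod):
        # Rewrite repeat-based MQA attention first (use_flash_attn=True), then
        # partition for CUTLASS and run codegen with the flash path enabled.
        mod = partition_for_cutlass(mod, use_flash_attn=True)
        codegen = relax.transform.RunCodegen(
            {"cutlass": {"sm": 80, "find_first_valid": True, "disable_flash": False}}
        )
        return codegen(mod)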

python/tvm/relax/backend/patterns.py

Lines changed: 7 additions & 4 deletions
@@ -318,7 +318,7 @@ def make_rms_norm_pattern():


 def make_attention_rewrite_pattern(
-    qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool
+    qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
 ):
     """
     Create pattern for implicit fused multi head attention rewriting.
@@ -350,7 +350,10 @@ def make_attention_rewrite_pattern(
     """

     # pylint: disable=invalid-name
-    def handle_input(tensor, layout, transpose):
+    def handle_input(tensor, layout, transpose, repeat=False):
+        if repeat:
+            tensor = is_op("relax.repeat")(tensor)
+
         if layout == "BSNH":
             permuted = is_op("relax.permute_dims")(tensor)
             shape = wildcard()
@@ -434,8 +437,8 @@ def rewriter(matchings, x):

     q_raw, k_raw, v_raw = wildcard(), wildcard(), wildcard()
     q, q_rewriter = handle_input(q_raw, qkv_layout, False)
-    k, k_rewriter = handle_input(k_raw, qkv_layout, True)
-    v, v_rewriter = handle_input(v_raw, qkv_layout, False)
+    k, k_rewriter = handle_input(k_raw, qkv_layout, True, repeat=with_kv_repeat)
+    v, v_rewriter = handle_input(v_raw, qkv_layout, False, repeat=with_kv_repeat)
     matmul_1 = is_op("relax.matmul")(q, k)
     scale = is_const()
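
The new flag can also be exercised directly when building a custom rewrite, mirroring the call added in cutlass.py (a sketch; `func` stands for a Relax function containing the repeat-based attention subgraph):

    from tvm.relax.backend.patterns import make_attention_rewrite_pattern
    from tvm.relax.dpl import rewrite_call

    # Pattern that additionally matches K/V broadcast via relax.repeat (MQA/GQA).
    pattern, rewriter = make_attention_rewrite_pattern(
        "BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True
    )
    # func = rewrite_call(pattern, rewriter, func)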

src/relax/op/nn/attention.cc

Lines changed: 2 additions & 2 deletions
@@ -79,8 +79,8 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
   };
   diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size");
   diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size");
-  diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads");
-  diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads");
+  // diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads");
+  // diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads");
   diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length");
   diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads");

tests/python/relax/test_codegen_cutlass.py

Lines changed: 59 additions & 5 deletions
@@ -113,7 +113,7 @@ def get_result_with_relax_cutlass_offload(
     if assert_all_bindings_fused:
         assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings

-    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
+    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True, "disable_flash": True}})
     mod = codegen_pass(mod)

     return build_and_run(mod, args, "cuda")
@@ -746,7 +746,7 @@ def attention_causal(request):
 def test_attention_causal_offload(attention_causal_size, attention_causal):
     b, (s, s_kv), n, (h, h_v), bias_shape = attention_causal_size
     q, k, v, bias, ref = get_numpy_attention_ref(
-        b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float32"
+        b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float16"
     )

     q_shape = (b, s, n, h)
@@ -757,10 +757,11 @@ def test_attention_causal_offload(attention_causal_size, attention_causal):
         q_shape,
         k_shape,
         v_shape,
-        dtype="float32",
+        dtype="float16",
         bias_shape=bias_shape,
         causal_mask=attention_causal,
     )
+
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
     else:
@@ -1930,7 +1931,10 @@ def main(
     ex = relax.build(mod_transform, target="llvm")
     vm = relax.vm.VirtualMachine(ex, tvm.cpu(0))

-    (packed_weight, scales,) = vm[
+    (
+        packed_weight,
+        scales,
+    ) = vm[
         transform_func_name
     ]((tvm.nd.array(y),))

@@ -1945,5 +1949,55 @@ def main(
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


+def test_attention_rewrite_multi_query():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            q: R.Tensor((4, 16, 32, 16), dtype="float16"),
+            k_single: R.Tensor((4, 16, 1, 16), dtype="float16"),
+            v_single: R.Tensor((4, 16, 1, 16), dtype="float16"),
+        ) -> R.Tensor((4, 16, 32, 8), dtype="float16"):
+            with R.dataflow():
+                k = R.repeat(k_single, 32, axis=2)
+                v = R.repeat(v_single, 32, axis=2)
+
+                lv = R.permute_dims(q, axes=[0, 2, 1, 3])
+                lv1 = R.reshape(lv, R.shape([128, 16, 16]))
+                lv2 = R.permute_dims(k, axes=[0, 2, 1, 3])
+                lv3 = R.reshape(lv2, R.shape([128, 16, 16]))
+                lv4 = R.permute_dims(v, axes=[0, 2, 1, 3])
+                lv5 = R.reshape(lv4, R.shape([128, 16, 16]))
+
+                lv6 = R.permute_dims(lv3, axes=[0, 2, 1])
+                lv7 = R.matmul(lv1, lv6, out_dtype="float16")
+                lv3_1 = R.astype(R.const(0.25, "float32"), "float16")
+                lv8 = R.multiply(lv7, lv3_1)
+                lv11 = R.astype(R.nn.softmax(R.astype(lv8, "float32"), axis=2), "float16")
+                lv12 = R.matmul(lv11, lv5, out_dtype="float16")
+                lv13 = R.reshape(lv12, R.shape([4, 32, 16, 16]))
+                lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3])
+                R.output(lv6_1)
+            return lv6_1
+
+    q_np = np.random.randn(4, 16, 32, 16).astype("float16")
+    k_np = np.random.randn(4, 16, 1, 16).astype("float16")
+    v_np = np.random.randn(4, 16, 1, 16).astype("float16")
+    args = [q_np, k_np, v_np]
+    ref = build_and_run(Module, args, "llvm", legalize=True)
+
+    mod = partition_for_cutlass(Module, use_flash_attn=True)
+    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True, "disable_flash": False}})
+    mod = codegen_pass(mod)
+
+    out = build_and_run(mod, args, "cuda")
+
+    print(np.max(np.abs(out - ref)), np.mean(np.abs(out - ref)))
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 if __name__ == "__main__":
-    tvm.testing.main()
+    # tvm.testing.main()
+    test_attention_rewrite_multi_query()
+    # test_attention_offload((4, (16, 16), 32, (8, 8)), "float16")
+    # test_attention_causal_offload((1, (1, 8), 4, (16, 16), "none"), "BottomRight")
