Merged
28 changes: 14 additions & 14 deletions python/tvm/contrib/cutlass/attention_operation.py
@@ -169,10 +169,10 @@ def instantiate_flash_attention_template(attrs):
int k_head_stride = ${head_dim};
int v_head_stride = ${head_dim};
int o_head_stride = ${head_dim};
- int q_row_stride = q_head_stride * ${num_heads};
- int k_row_stride = k_head_stride * ${num_heads};
- int v_row_stride = v_head_stride * ${num_heads};
- int o_row_stride = o_head_stride * ${num_heads};
+ int q_row_stride = q_head_stride * ${num_q_heads};
+ int k_row_stride = k_head_stride * ${num_kv_heads};
+ int v_row_stride = v_head_stride * ${num_kv_heads};
+ int o_row_stride = o_head_stride * ${num_q_heads};
int q_batch_stride = q_row_stride * ${num_queries};
int k_batch_stride = k_row_stride * ${num_keys};
int v_batch_stride = v_row_stride * ${num_keys};
@@ -190,8 +190,8 @@ def instantiate_flash_attention_template(attrs):
${num_batches},
${num_queries},
${num_keys},
- ${num_heads},
- ${num_heads},
+ ${num_q_heads},
+ ${num_kv_heads},
${head_dim},
q_batch_stride,
k_batch_stride,
@@ -215,13 +215,13 @@ def instantiate_flash_attention_template(attrs):
int k_head_stride = ${head_dim};
int v_head_stride = ${head_dim};
int o_head_stride = ${head_dim};
- int row_stride = q_head_stride * ${num_heads} +
- k_head_stride * ${num_heads} +
- v_head_stride * ${num_heads};
+ int row_stride = q_head_stride * ${num_q_heads} +
+ k_head_stride * ${num_kv_heads} +
+ v_head_stride * ${num_kv_heads};
int q_row_stride = row_stride;
int k_row_stride = row_stride;
int v_row_stride = row_stride;
- int o_row_stride = o_head_stride * ${num_heads};
+ int o_row_stride = o_head_stride * ${num_q_heads};

int q_batch_stride = q_row_stride * ${num_queries};
int k_batch_stride = k_row_stride * ${num_keys};
@@ -234,14 +234,14 @@ def instantiate_flash_attention_template(attrs):

flash_attn::flash_attention_forward(
static_cast<const cutlass::half_t*>(${qkv}->data),
- static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads},
- static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
+ static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_q_heads},
+ static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads}),
static_cast<cutlass::half_t*>(out0->data),
${num_batches},
${num_queries},
${num_keys},
- ${num_heads},
- ${num_heads},
+ ${num_q_heads},
+ ${num_kv_heads},
${head_dim},
q_batch_stride,
k_batch_stride,
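The stride changes above are the heart of the MQA support in this template: query and output rows advance by the number of query heads, key and value rows advance by the (typically smaller) number of KV heads, and in the packed-QKV variant the K and V pointers are offset past the query heads. A minimal sketch of that arithmetic with made-up example values (not the template's actual substitution mechanism):

```python
# Illustrative stride arithmetic only; num_q_heads=32, num_kv_heads=1 is the
# classic multi-query layout, head_dim/num_queries/num_keys are arbitrary here.
num_q_heads, num_kv_heads, head_dim = 32, 1, 128
num_queries = num_keys = 16

# Separate Q/K/V tensors (first template above).
q_row_stride = head_dim * num_q_heads
k_row_stride = head_dim * num_kv_heads
v_row_stride = head_dim * num_kv_heads
o_row_stride = head_dim * num_q_heads
q_batch_stride = q_row_stride * num_queries
k_batch_stride = k_row_stride * num_keys

# Packed QKV tensor (second template): Q, K and V share one row.
row_stride = head_dim * (num_q_heads + num_kv_heads + num_kv_heads)
k_offset = head_dim * num_q_heads                    # where K starts within a row
v_offset = head_dim * (num_q_heads + num_kv_heads)   # where V starts within a row
```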
7 changes: 4 additions & 3 deletions python/tvm/contrib/cutlass/build.py
@@ -909,8 +909,8 @@ def handle_attention(self, f, op_type):

out_shape = signature["ret_shape"]
out_dtype = signature["ret_dtype"]
- num_batches, num_queries, num_heads, head_dim = q_shape
- _, num_keys, _, _ = k_shape
+ num_batches, num_queries, num_q_heads, head_dim = q_shape
+ _, num_keys, num_kv_heads, _ = k_shape
_, _, _, head_dim_value = v_shape
scale = op_attrs.scale

@@ -931,7 +931,8 @@ def handle_attention(self, f, op_type):
"num_batches": num_batches,
"num_queries": num_queries,
"num_keys": num_keys,
"num_heads": num_heads,
"num_q_heads": num_q_heads,
"num_kv_heads": num_kv_heads,
"head_dim": head_dim,
"head_dim_value": head_dim_value,
"scale": scale,
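In handle_attention, the query and KV head counts now come from the third axis of the BSNH-shaped query and key tensors respectively, and both are forwarded in the annotations passed to the kernel generator. A small illustration with hypothetical shapes:

```python
# Hypothetical BSNH shapes for a multi-query attention workload.
q_shape = (4, 16, 32, 16)  # (num_batches, num_queries, num_q_heads, head_dim)
k_shape = (4, 16, 1, 16)   # (num_batches, num_keys,    num_kv_heads, head_dim)

num_batches, num_queries, num_q_heads, head_dim = q_shape
_, num_keys, num_kv_heads, _ = k_shape

# Both head counts end up in the annotations consumed by gen_tensor_op below.
annotations = {"num_q_heads": num_q_heads, "num_kv_heads": num_kv_heads}
```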
23 changes: 18 additions & 5 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -745,34 +745,47 @@ def get_batch_on_arg(arg_name, arg_shape):

attrs["data_type"] = DataTypeTag[data_type]
attrs["num_batches"] = b = annotations["num_batches"]
attrs["num_heads"] = n = annotations["num_heads"]
attrs["head_dim"] = h = annotations["head_dim"]
attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))
attrs["scale"] = (
float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
)

+ is_mqa = annotations["num_q_heads"] != annotations["num_kv_heads"]
+
use_flash = (
annotations["ret_dtype"] == "float16"
and "bias" not in attrs
and int(attrs["head_dim"]) <= 256
and int(attrs["head_dim"]) % 8 == 0
and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
- # We have not thoroughly validated flash with causal mask yet, so for now we support
- # only non-causal cases.
- and int(annotations["custom_mask_type"]) == 0
+ # For the causal case (custom mask = "BottomRight"), only use flash for multi-query
+ # attention workloads. Otherwise, CUTLASS fMHA seems faster for causal attention
+ # with a single query.
+ and (
+ int(annotations["custom_mask_type"]) == 0
+ or (int(annotations["custom_mask_type"]) == 2 and is_mqa)
+ )
# Flash v2 is currently not supported for sm < 80
and int(annotations["arch"]) >= 80
)

if use_flash:
headers.append("flash.h")
attrs["is_causal"] = int(annotations["custom_mask_type"]) > 0
attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2
attrs["num_q_heads"] = annotations["num_q_heads"]
attrs["num_kv_heads"] = annotations["num_kv_heads"]
code = instantiate_flash_attention_template(attrs)
else:
headers.append("kernel_forward.h")

+ assert (
+ not is_mqa
+ ), "The number of query and KV heads need to be the same for CUTLASS fMHA."
+
+ attrs["num_heads"] = n = annotations["num_q_heads"]
+
data_type_size = DataTypeSize[data_type]
if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
attrs["kIsAligned"] = True
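Pulled out of the template-instantiation code, the dispatch above boils down to roughly the following predicate (a restated sketch, not the actual helper): flash handles fp16, bias-free, aligned head dimensions, and either non-causal masks or causal "BottomRight" masks when the workload is multi-query; everything else falls back to CUTLASS fMHA, which still requires equal query and KV head counts.

```python
def should_use_flash(annotations, attrs):
    """Sketch of the dispatch logic above (illustrative restatement only)."""
    head_dim = int(attrs["head_dim"])
    is_mqa = annotations["num_q_heads"] != annotations["num_kv_heads"]
    mask = int(annotations["custom_mask_type"])  # 0 = no mask, 2 = causal "BottomRight"
    return (
        annotations["ret_dtype"] == "float16"
        and "bias" not in attrs
        and head_dim <= 256
        and head_dim % 8 == 0
        and head_dim == int(attrs["head_dim_value"])
        and (mask == 0 or (mask == 2 and is_mqa))
        and int(annotations["arch"]) >= 80  # Flash v2 is not supported for sm < 80
    )
```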
13 changes: 12 additions & 1 deletion python/tvm/relax/backend/contrib/cutlass.py
@@ -576,7 +576,7 @@ def annotate_workspace(mod, _):
return mod


- def partition_for_cutlass(mod, annotate_codegen=True):
+ def partition_for_cutlass(mod, annotate_codegen=True, use_flash_mqa=True):
"""
Partition the input module into CUTLASS-supported subgraphs.

@@ -590,6 +590,10 @@ def partition_for_cutlass(mod, annotate_codegen=True):
body consists only of a call to the composite function. See the doc of FuseOpsByPattern
for more detail.

+ use_flash_mqa: bool
+ Whether to consider a rewrite pattern for multi-query attention, which is supported by
+ the Flash Attention kernel.
+
Returns
-------
mod: tvm.IRModule
@@ -598,8 +602,15 @@ def partition_for_cutlass(mod, annotate_codegen=True):
"""
for func_name, func in mod.functions.items():
if isinstance(func, Function):
+ if use_flash_mqa:
+ mqa_pattern, rewriter = make_attention_rewrite_pattern(
+ "BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True
+ )
+ func = rewrite_call(mqa_pattern, rewriter, func)
+
for pattern, rewriter in _REWRITE_PATTERNS:
func = rewrite_call(pattern, rewriter, func)

mod[func_name] = func

patterns = get_patterns_with_prefix("cutlass")
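Since use_flash_mqa defaults to True, existing callers pick up the MQA rewrite automatically. A typical offload flow (mirroring the new test at the bottom of this PR; `Module` here stands for any Relax IRModule containing an attention pattern) looks roughly like this:

```python
# Sketch of the partition-and-codegen flow exercised by the new test below.
from tvm import relax
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

mod = partition_for_cutlass(Module, use_flash_mqa=True)         # pattern rewrite + BYOC partition
mod = relax.transform.RunCodegen({"cutlass": {"sm": 80}})(mod)  # generate CUTLASS/flash kernels
```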
15 changes: 11 additions & 4 deletions python/tvm/relax/backend/patterns.py
@@ -318,7 +318,7 @@ def make_rms_norm_pattern():


def make_attention_rewrite_pattern(
- qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool
+ qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
):
"""
Create pattern for implicit fused multi head attention rewriting.
@@ -338,6 +338,10 @@ def make_attention_rewrite_pattern(
Whether or not rewriting is intended to be applied to a module after the FP16 conversion
pass.

+ with_kv_repeat: bool
+ Whether or not to include the Relax repeat op in the pattern, which is typically used
+ in a Relax module to support multi-query attention.
+
Returns
-------
pattern: DFPattern
@@ -350,7 +354,10 @@ def make_attention_rewrite_pattern(
"""

# pylint: disable=invalid-name
- def handle_input(tensor, layout, transpose):
+ def handle_input(tensor, layout, transpose, repeat=False):
+ if repeat:
+ tensor = is_op("relax.repeat")(tensor)
+
if layout == "BSNH":
permuted = is_op("relax.permute_dims")(tensor)
shape = wildcard()
@@ -434,8 +441,8 @@ def rewriter(matchings, x):

q_raw, k_raw, v_raw = wildcard(), wildcard(), wildcard()
q, q_rewriter = handle_input(q_raw, qkv_layout, False)
- k, k_rewriter = handle_input(k_raw, qkv_layout, True)
- v, v_rewriter = handle_input(v_raw, qkv_layout, False)
+ k, k_rewriter = handle_input(k_raw, qkv_layout, True, repeat=with_kv_repeat)
+ v, v_rewriter = handle_input(v_raw, qkv_layout, False, repeat=with_kv_repeat)
matmul_1 = is_op("relax.matmul")(q, k)
scale = is_const()

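For reference, this is how the CUTLASS backend consumes the new flag (see contrib/cutlass.py above): with with_kv_repeat=True the pattern additionally matches a relax.repeat on K and V, so a module that broadcasts a single KV head across all query heads is still recognized and rewritten to the fused attention op. A sketch; the rewrite_call import location is an assumption:

```python
# Sketch: build the MQA-aware attention pattern and apply it to a Relax function.
from tvm.relax.backend.patterns import make_attention_rewrite_pattern
from tvm.relax.dpl import rewrite_call  # assumed location of rewrite_call

pattern, rewriter = make_attention_rewrite_pattern(
    "BSNH", "BSNH", with_bias=False, with_cast=True, with_kv_repeat=True
)
func = rewrite_call(pattern, rewriter, func)  # func: a relax.Function containing the pattern
```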
13 changes: 11 additions & 2 deletions src/relax/op/nn/attention.cc
@@ -77,10 +77,19 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
<< v1 << " while the " << dim << " of " << m2 << " is " << v2);
}
};
+ auto multiple_of = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) {
+ if (analyzer->CanProve(indexmod(v1, v2) != 0)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "The " << m1 << " " << dim << " should be a multiple of " << m2 << " "
+ << dim << ". However, the " << dim << " of " << m1 << " is " << v1
+ << " while the " << dim << " of " << m2 << " is " << v2);
+ }
+ };

diag_equal(num_batches, k_shape->values[0], "query", "key", "batch size");
diag_equal(num_batches, v_shape->values[0], "query", "value", "batch size");
- diag_equal(num_heads, k_shape->values[2], "query", "key", "number of heads");
- diag_equal(num_heads, v_shape->values[2], "query", "value", "number of heads");
+ multiple_of(num_heads, k_shape->values[2], "query", "key", "number of heads");
+ multiple_of(num_heads, v_shape->values[2], "query", "value", "number of heads");
diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length");
diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads");

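The shape check on the head axis is relaxed from strict equality to divisibility, which is what admits multi-query (and grouped-query) layouts at the Relax op level. Ignoring the symbolic-shape analyzer, the rule is simply the following (an illustrative mirror, not the op implementation):

```python
def heads_compatible(num_q_heads: int, num_kv_heads: int) -> bool:
    """Simplified mirror of the multiple_of check above (concrete shapes only)."""
    return num_q_heads % num_kv_heads == 0

assert heads_compatible(32, 1)       # multi-query attention
assert heads_compatible(32, 8)       # grouped-query attention
assert not heads_compatible(32, 5)   # reported as an error by attention.cc
```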
51 changes: 49 additions & 2 deletions tests/python/relax/test_codegen_cutlass.py
@@ -746,7 +746,7 @@ def attention_causal(request):
def test_attention_causal_offload(attention_causal_size, attention_causal):
b, (s, s_kv), n, (h, h_v), bias_shape = attention_causal_size
q, k, v, bias, ref = get_numpy_attention_ref(
- b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float32"
+ b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float16"
)

q_shape = (b, s, n, h)
@@ -757,10 +757,11 @@ def test_attention_causal_offload(attention_causal_size, attention_causal):
q_shape,
k_shape,
v_shape,
dtype="float32",
dtype="float16",
bias_shape=bias_shape,
causal_mask=attention_causal,
)

if bias is None:
out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
else:
@@ -1945,5 +1946,51 @@ def main(
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


+ def test_attention_rewrite_multi_query():
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(
+ q: R.Tensor((4, 16, 32, 16), dtype="float16"),
+ k_single: R.Tensor((4, 16, 1, 16), dtype="float16"),
+ v_single: R.Tensor((4, 16, 1, 16), dtype="float16"),
+ ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+ with R.dataflow():
+ k = R.repeat(k_single, 32, axis=2)
+ v = R.repeat(v_single, 32, axis=2)
+
+ lv = R.permute_dims(q, axes=[0, 2, 1, 3])
+ lv1 = R.reshape(lv, R.shape([128, 16, 16]))
+ lv2 = R.permute_dims(k, axes=[0, 2, 1, 3])
+ lv3 = R.reshape(lv2, R.shape([128, 16, 16]))
+ lv4 = R.permute_dims(v, axes=[0, 2, 1, 3])
+ lv5 = R.reshape(lv4, R.shape([128, 16, 16]))
+
+ lv6 = R.permute_dims(lv3, axes=[0, 2, 1])
+ lv7 = R.matmul(lv1, lv6, out_dtype="float16")
+ lv3_1 = R.astype(R.const(0.25, "float32"), "float16")
+ lv8 = R.multiply(lv7, lv3_1)
+ lv11 = R.astype(R.nn.softmax(R.astype(lv8, "float32"), axis=2), "float16")
+ lv12 = R.matmul(lv11, lv5, out_dtype="float16")
+ lv13 = R.reshape(lv12, R.shape([4, 32, 16, 16]))
+ lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3])
+ R.output(lv6_1)
+ return lv6_1
+
+ q_np = np.random.randn(4, 16, 32, 16).astype("float16")
+ k_np = np.random.randn(4, 16, 1, 16).astype("float16")
+ v_np = np.random.randn(4, 16, 1, 16).astype("float16")
+ args = [q_np, k_np, v_np]
+ ref = build_and_run(Module, args, "llvm", legalize=True)
+
+ mod = partition_for_cutlass(Module, use_flash_mqa=True)
+ codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80}})
+ mod = codegen_pass(mod)
+
+ out = build_and_run(mod, args, "cuda")
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
if __name__ == "__main__":
tvm.testing.main()