diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 3daa32fd76b6..bcfe3207bcef 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -295,6 +295,16 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
   }
 };  // struct DropoutAttrs
 
+/*! \brief Attributes used in attention operator */
+struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
+  Optional<FloatImm> scale;
+
+  TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
+    TVM_ATTR_FIELD(scale).describe(
+        "The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).");
+  }
+};  // struct AttentionAttrs
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 9093a03dd6ed..f7dee4e3b80a 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -90,7 +90,7 @@ def instantiate_attention_template(attrs, func_args):
       p.head_dim_value = ${head_dim_value}; // H'
       p.num_queries = ${num_queries}; // S
       p.num_keys = ${num_keys}; // S'
-      p.scale = 1.0f / sqrt(float(${head_dim}));
+      p.scale = ${scale};
 
       // stride for N
       p.q_strideH = p.head_dim; // H
@@ -123,12 +123,12 @@ def instantiate_attention_template(attrs, func_args):
       CHECK(Attention::check_supported(p));
       kernel_fn<<>>(p);
     """
-    if attrs["kSupportsBias"]:
-        template = substitute_template(
-            template, {"bias_template": bias_template[attrs["bias_layout"]]}
-        )
-    else:
-        template = substitute_template(template, {"bias_template": ""})
+
+    template = substitute_template(
+        template,
+        {"bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" in attrs else ""},
+    )
+
     for i, arg in enumerate(func_args):
         attrs["arg{}".format(i)] = arg
     return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 47bdcaa790b8..a7c20a226e35 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -785,7 +785,12 @@ def handle_matmul(self, f, op_type):
     def handle_attention(self, f, op_type):
         """Tune and annotate an attention op."""
         signature = _extract_relax_function_signature(f)
-
+        if _get_call_node(f.body, "relax.nn.attention") is not None:
+            op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs
+        elif _get_call_node(f.body, "relax.nn.attention_bias") is not None:
+            op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs
+        else:
+            raise ValueError("Cannot find call node for attention")
         q_shape = signature["arg0_shape"]
         k_shape = signature["arg1_shape"]
         v_shape = signature["arg2_shape"]
@@ -797,6 +802,7 @@ def handle_attention(self, f, op_type):
         num_batches, num_queries, num_heads, head_dim = q_shape
         _, num_keys, _, _ = k_shape
         _, _, _, head_dim_value = v_shape
+        scale = op_attrs.scale
         bias = {}
         if "arg3_dtype" in signature:
             bias["arg3_dtype"] = signature["arg3_dtype"]
@@ -820,6 +826,7 @@ def handle_attention(self, f, op_type):
             "num_heads": num_heads,
            "head_dim": head_dim,
             "head_dim_value": head_dim_value,
+            "scale": scale,
             "arch": self.options["sm"],
             **bias,
         }
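For reference, the net effect of the two changes above is that the CUTLASS kernel no longer hard-codes 1.0f / sqrt(float(head_dim)); the scale annotated on the Relax op is substituted into the ${scale} placeholder instead, with the old value kept as a fallback (the fallback itself is added in gen_tensor_op.py below). A minimal stand-alone sketch of that substitution, where render is a hypothetical stand-in for the real substitute_template:

import math
import re

def render(template, values):
    # Hypothetical stand-in for cutlass.library.substitute_template:
    # replace each ${key} placeholder with the stringified value.
    for key, value in values.items():
        template = re.sub(r"\$\{%s\}" % key, str(value), template)
    return template

head_dim = 64
annotations = {"scale": None}  # the op carried no explicit scale

# Mirror of the fallback: use 1 / sqrt(head_dim) when no scale is given.
scale = float(1 / math.sqrt(head_dim)) if annotations["scale"] is None else annotations["scale"]

print(render("p.scale = ${scale};", {"scale": scale}))  # p.scale = 0.125;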
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 1de193858036..61c88c657f05 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name
 """Common functions and classes for CUTLASS GEMM and Conv2d generator."""
 import logging
+import math
 import multiprocessing
 import os
 import re
@@ -722,6 +723,12 @@ def get_batch_on_arg(arg_name, arg_shape):
         attrs["kKeysPerBlock"] = 64
         attrs["kSingleValueIteration"] = True
         attrs["output_size"] = b * s * n * h_v
+        attrs["scale"] = (
+            float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
+        )
+        assert (
+            attrs["scale"] > 0 or attrs["scale"] < 0
+        ), "CUTLASS may occasionally generate NaN when scale == 0.0"
         attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
         attrs["kSupportsDropout"] = False
         if len(func_args) > 3:
@@ -735,7 +742,9 @@ def get_batch_on_arg(arg_name, arg_shape):
             else:
                 raise NotImplementedError()
         else:
-            attrs["kSupportsBias"] = False
+            # To support a negative scale with the current CUTLASS implementation,
+            # kSupportsBias must be set to true; otherwise the result contains NaNs.
+            attrs["kSupportsBias"] = attrs["scale"] < 0
         code = instantiate_attention_template(attrs, func_args)
         return CodegenResult(code, headers)
diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py
index b72553ef6052..ead5804b59a0 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -20,7 +20,7 @@
 import enum
 from enum import auto as enum_auto
 
-from tvm.tir.expr import IntImm
+from tvm.tir.expr import IntImm, FloatImm
 
 
 class GeneratorTarget(enum.Enum):
@@ -147,6 +147,8 @@ def substitute_template(template, values):
     for key, value in values.items():
         if isinstance(value, (int, IntImm)):
             value = str(int(value))
+        if isinstance(value, (float, FloatImm)):
+            value = str(float(value))
         elif isinstance(value, bool):
             value = str(value).lower()
         regex = "\\$\\{%s\\}" % key
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index e1d41c6cdfd6..02468637e0f9 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -18,6 +18,7 @@
 from typing import List, Optional, Tuple, Union
 
 from tvm import DataType
+from tvm.tir import FloatImm
 
 from . import _ffi_api
 from ...expr import Expr
@@ -913,7 +914,13 @@ def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr:
     return _ffi_api.cross_entropy_with_logits(predictions, labels)  # type: ignore
 
 
-def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None) -> Expr:
+def attention(
+    query: Expr,
+    key: Expr,
+    value: Expr,
+    bias: Optional[Expr] = None,
+    scale: Optional[FloatImm] = None,
+) -> Expr:
     r"""Computes fused multi head attention.
 
     All input tensors are of 4-D tensors with BSNH layout.
@@ -943,10 +950,13 @@ def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None)
         (batch_size, num_head, seq_len, seq_len_kv),
         (batch_size, seq_len, seq_len_kv) or
         (batch_size, seq_len_kv).
 
+    scale: Optional[FloatImm]
+        The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).
+
     Returns
     -------
     result : relax.Expr
         The computed result. The layout of the output should be
         (batch_size, seq_len, num_head, head_dim_v).
     """
-    return _ffi_api.attention(query, key, value, bias)  # type: ignore
+    return _ffi_api.attention(query, key, value, bias, scale)  # type: ignore
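With the Python wrapper above in place, a custom softmax scale can be passed as a tir.FloatImm when building a Relax function. A minimal sketch using the BlockBuilder API (the BSNH shapes here are assumptions chosen to match the tests later in this patch):

import tvm
from tvm import relax
from tvm.tir import FloatImm

# Assumed BSNH shapes: batch=4, seq_len=16 (8 for K/V), num_head=32, head_dim=8, head_dim_v=16.
q = relax.Var("q", relax.TensorStructInfo((4, 16, 32, 8), "float32"))
k = relax.Var("k", relax.TensorStructInfo((4, 8, 32, 8), "float32"))
v = relax.Var("v", relax.TensorStructInfo((4, 8, 32, 16), "float32"))

bb = relax.BlockBuilder()
with bb.function("main", [q, k, v]):
    # Override the default 1 / sqrt(head_dim) scale; omit `scale` to keep the default.
    out = bb.emit(relax.op.nn.attention(q, k, v, scale=FloatImm("float32", 0.125)))
    bb.emit_func_output(out)

print(bb.get())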
""" - return _ffi_api.attention(query, key, value, bias) # type: ignore + return _ffi_api.attention(query, key, value, bias, scale) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 889e6e09417b..1ce45206354d 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -312,3 +312,59 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr: def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr: logging.info("Dropout is handled by frontend translator at this moment and is not legalized.") return call + + +def _te_attention( + q: te.Tensor, k: te.Tensor, v: te.Tensor, bias: te.Tensor, scale: tir.FloatImm +) -> te.Tensor: + batch_size, seq_len, num_head, head_dim = q.shape + _, seq_len_kv, _, head_dim_v = v.shape + q = topi.transpose(q, [0, 2, 1, 3]) + k = topi.transpose(k, [0, 2, 1, 3]) + v = topi.transpose(v, [0, 2, 1, 3]) + q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim]) + k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim]) + v = topi.reshape(v, [batch_size * num_head, seq_len_kv, head_dim_v]) + p = topi.nn.batch_matmul(q, k) + if scale is not None: + p = topi.multiply(p, scale) + else: + p = topi.divide(p, tir.sqrt(tir.Cast(p.dtype, head_dim))) + if bias is not None: + p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv]) + if len(bias.shape) == 2: + bias = topi.reshape(bias, [batch_size, 1, 1, seq_len_kv]) + elif len(bias.shape) == 3: + bias = topi.reshape(bias, [batch_size, 1, seq_len, seq_len_kv]) + p = topi.add(p, bias) + p = topi.reshape(p, [batch_size * num_head, seq_len, seq_len_kv]) + s = topi.nn.softmax(p) + o = topi.nn.batch_matmul(s, v, transpose_b=False) + o = topi.reshape(o, [batch_size, num_head, seq_len, head_dim_v]) + return topi.transpose(o, [0, 2, 1, 3]) + + +@register_legalize("relax.nn.attention") +def _nn_attention(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + _te_attention, + call.args[0], + call.args[1], + call.args[2], + None, + call.attrs.scale, + primfunc_name_hint="attention", + ) + + +@register_legalize("relax.nn.attention_bias") +def _nn_attention_bias(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + _te_attention, + call.args[0], + call.args[1], + call.args[2], + call.args[3], + call.attrs.scale, + primfunc_name_hint="attention_bias", + ) diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index e139aa09d692..c27e8b68d0bc 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -26,14 +26,18 @@ namespace tvm { namespace relax { /* relax.nn.attention */ -Expr attention(Expr query, Expr key, Expr value, Optional bias) { +TVM_REGISTER_NODE_TYPE(AttentionAttrs); + +Expr attention(Expr query, Expr key, Expr value, Optional bias, Optional scale) { + ObjectPtr attrs = make_object(); + attrs->scale = scale; if (bias.defined()) { return Call(Op::Get("relax.nn.attention_bias"), - {std::move(query), std::move(key), std::move(value), std::move(bias.value())}, {}, - {}); + {std::move(query), std::move(key), std::move(value), std::move(bias.value())}, + Attrs(attrs), {}); } return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)}, - {}, {}); + Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention); @@ -105,6 +109,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { } TVM_REGISTER_OP("relax.nn.attention") + 
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index e139aa09d692..c27e8b68d0bc 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -26,14 +26,18 @@ namespace tvm {
 namespace relax {
 
 /* relax.nn.attention */
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias) {
+TVM_REGISTER_NODE_TYPE(AttentionAttrs);
+
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale) {
+  ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
+  attrs->scale = scale;
   if (bias.defined()) {
     return Call(Op::Get("relax.nn.attention_bias"),
-                {std::move(query), std::move(key), std::move(value), std::move(bias.value())}, {},
-                {});
+                {std::move(query), std::move(key), std::move(value), std::move(bias.value())},
+                Attrs(attrs), {});
   }
   return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)},
-              {}, {});
+              Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention);
@@ -105,6 +109,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
 }
 
 TVM_REGISTER_OP("relax.nn.attention")
+    .set_attrs_type<AttentionAttrs>()
     .set_num_inputs(3)
     .add_argument("query", "Tensor", "The input queries tensor.")
     .add_argument("key", "Tensor", "The input keys tensor.")
@@ -112,6 +117,7 @@ TVM_REGISTER_OP("relax.nn.attention")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention);
 
 TVM_REGISTER_OP("relax.nn.attention_bias")
+    .set_attrs_type<AttentionAttrs>()
     .set_num_inputs(4)
     .add_argument("query", "Tensor", "The input queries tensor.")
     .add_argument("key", "Tensor", "The input keys tensor.")
diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h
index 662e0b7e7b81..7eda30b40813 100644
--- a/src/relax/op/nn/attention.h
+++ b/src/relax/op/nn/attention.h
@@ -33,7 +33,7 @@ namespace tvm {
 namespace relax {
 
 /*! \brief fused multi head attention */
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias);
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale);
 
 }  // namespace relax
 }  // namespace tvm
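As the constructor above shows, the presence of a bias selects between the relax.nn.attention and relax.nn.attention_bias ops, and both now carry AttentionAttrs. A quick way to observe this from Python (a sketch; shapes as in the earlier examples):

import tvm
from tvm import relax
from tvm.tir import FloatImm

q = relax.Var("q", relax.TensorStructInfo((4, 16, 32, 8), "float32"))
k = relax.Var("k", relax.TensorStructInfo((4, 8, 32, 8), "float32"))
v = relax.Var("v", relax.TensorStructInfo((4, 8, 32, 16), "float32"))
bias = relax.Var("bias", relax.TensorStructInfo((4, 32, 16, 8), "float32"))

with_bias = relax.op.nn.attention(q, k, v, bias, FloatImm("float32", 0.1))
without_bias = relax.op.nn.attention(q, k, v, scale=FloatImm("float32", 0.1))

print(with_bias.op.name)         # relax.nn.attention_bias
print(without_bias.op.name)      # relax.nn.attention
print(without_bias.attrs.scale)  # 0.1, stored on AttentionAttrs
print(relax.op.nn.attention(q, k, v).attrs.scale)  # None -> default 1 / sqrt(head_dim)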
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 5ea9a9d04017..c8ca44311de5 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -566,11 +566,14 @@ def attention_size(request):
     return request.param
 
 
-def get_relax_attention_module(q, k, v, bias=None):
+def get_relax_attention_module(q, k, v, bias=None, qk_scale=None):
     dtype = str(q.dtype)
 
     from tvm.script.ir_builder import IRBuilder
-    from tvm.script.ir_builder import relax as relax_builder
+    from tvm.script.ir_builder import relax as relax_builder, tir as T
+
+    if qk_scale is not None:
+        qk_scale = T.FloatImm("float32", qk_scale)
 
     with IRBuilder() as builder:
         with relax_builder.function():
@@ -581,7 +584,7 @@ def get_relax_attention_module(q, k, v, bias=None):
             if bias is not None:
                 bias = R.arg("bias", R.Tensor(bias.shape, dtype))
             with R.dataflow() as frame:
-                result = R.emit(R.nn.attention(q, k, v, bias))
+                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
                 R.output(result)
 
             R.func_ret_value(frame.output_vars[0])
@@ -591,22 +594,32 @@ def get_relax_attention_module(q, k, v, bias=None):
 @memoize("topi.tests.test_codegen_cutlass.test_attention_offload")
-def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, dtype):
+def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, qk_scale, dtype):
     q = np.random.randn(b, s, n, h).astype(dtype)
     k = np.random.randn(b, s_kv, n, h).astype(dtype)
     v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
     qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
     kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
-    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    if qk_scale != "none":
+        score = qt @ kt * qk_scale  # b, n, s, s_kv
+    else:
+        score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    if bias_shape != "none":
+        bias = np.random.randn(*bias_shape).astype(dtype)
+        score = score + bias.reshape(*bias_reshape)  # b, n, s, s_kv
+    else:
+        bias = None
     attn = tvm.topi.testing.softmax_python(score, -1)
     vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
     ref = attn @ vt  # b, n, s, h_v
-    return q, k, v, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
 
 
 def test_attention_offload(attention_size, attention_dtype):
     b, (s, s_kv), n, (h, h_v) = attention_size
-    q, k, v, ref = get_numpy_attention_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+    q, k, v, _, ref = get_numpy_attention_ref(
+        b, s, s_kv, n, h, h_v, "none", "none", "none", attention_dtype
+    )
 
     mod = get_relax_attention_module(q, k, v)
     out = get_result_with_relax_cutlass_offload(mod, q, k, v)
@@ -614,25 +627,23 @@ def test_attention_offload(attention_size, attention_dtype):
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
-@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_4d_offload")
-def get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, dtype):
-    q = np.random.randn(b, s, n, h).astype(dtype)
-    k = np.random.randn(b, s_kv, n, h).astype(dtype)
-    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
-    bias = np.random.randn(b, n, s, s_kv).astype(dtype)
-    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
-    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
-    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
-    score_bias = score + bias  # b, n, s, s_kv
-    attn = tvm.topi.testing.softmax_python(score_bias, -1)
-    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
-    ref = attn @ vt  # b, n, s, h_v
-    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+@pytest.fixture(
+    params=[
+        # B, S, N, H, bias_shape, bias_reshape
+        (4, (16, 8), 32, (8, 16), (4, 32, 16, 8), (4, 32, 16, 8)),
+        (4, (16, 8), 32, (8, 16), (4, 16, 8), (4, 1, 16, 8)),
+        (4, (16, 8), 32, (8, 16), (4, 8), (4, 1, 1, 8)),
+    ]
+)
+def attention_bias_size(request):
+    return request.param
 
 
-def test_attention_bias_4d_offload(attention_size, attention_dtype):
-    b, (s, s_kv), n, (h, h_v) = attention_size
-    q, k, v, bias, ref = get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+def test_attention_bias_offload(attention_bias_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_bias_size
+    q, k, v, bias, ref = get_numpy_attention_ref(
+        b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, "none", attention_dtype
+    )
 
     mod = get_relax_attention_module(q, k, v, bias)
     out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
@@ -640,55 +651,33 @@ def test_attention_bias_4d_offload(attention_size, attention_dtype):
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
-@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_3d_offload")
-def get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, dtype):
-    q = np.random.randn(b, s, n, h).astype(dtype)
-    k = np.random.randn(b, s_kv, n, h).astype(dtype)
-    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
-    bias = np.random.randn(b, s, s_kv).astype(dtype)
-    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
-    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
-    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
-    score_bias = score + bias.reshape(b, 1, s, s_kv)  # b, n, s, s_kv
-    attn = tvm.topi.testing.softmax_python(score_bias, -1)
-    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
-    ref = attn @ vt  # b, n, s, h_v
-    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
-
-
-def test_attention_bias_3d_offload(attention_size, attention_dtype):
-    b, (s, s_kv), n, (h, h_v) = attention_size
-    q, k, v, bias, ref = get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
-
-    mod = get_relax_attention_module(q, k, v, bias)
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
-
-    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-
+@pytest.fixture(
+    params=[
+        # B, S, N, H, bias_shape, bias_reshape
+        (4, (16, 8), 32, (8, 16), (4, 32, 16, 8), (4, 32, 16, 8)),
+        (4, (16, 8), 32, (8, 16), "none", "none"),
+    ]
+)
+def attention_scale_size(request):
+    return request.param
 
-@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_2d_offload")
-def get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, dtype):
-    q = np.random.randn(b, s, n, h).astype(dtype)
-    k = np.random.randn(b, s_kv, n, h).astype(dtype)
-    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
-    bias = np.random.randn(b, s_kv).astype(dtype)
-    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
-    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
-    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
-    score_bias = score + bias.reshape(b, 1, 1, s_kv)  # b, n, s, s_kv
-    attn = tvm.topi.testing.softmax_python(score_bias, -1)
-    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
-    ref = attn @ vt  # b, n, s, h_v
-    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+@pytest.fixture(params=[0.01, 1e-8, -0.5, 1.23])
+def attention_scale(request):
+    return request.param
 
 
-def test_attention_bias_2d_offload(attention_size, attention_dtype):
-    b, (s, s_kv), n, (h, h_v) = attention_size
-    q, k, v, bias, ref = get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
-    mod = get_relax_attention_module(q, k, v, bias)
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+def test_attention_scale_offload(attention_scale_size, attention_scale, attention_dtype):
+    b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_scale_size
+    q, k, v, bias, ref = get_numpy_attention_ref(
+        b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, attention_scale, attention_dtype
+    )
 
+    mod = get_relax_attention_module(q, k, v, bias, attention_scale)
+    if bias is None:
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v)
+    else:
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
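The attention_bias_size fixture above drives three bias layouts through a single reference; the bias_reshape column is what makes each of them broadcast against the (b, n, s, s_kv) score tensor. A small NumPy check of those broadcasts (values taken from the fixture):

import numpy as np

b, s, s_kv, n = 4, 16, 8, 32
score = np.random.randn(b, n, s, s_kv).astype("float32")

cases = [
    ((b, n, s, s_kv), (b, n, s, s_kv)),  # full 4-D bias
    ((b, s, s_kv), (b, 1, s, s_kv)),     # 3-D bias, broadcast over heads
    ((b, s_kv), (b, 1, 1, s_kv)),        # 2-D bias, broadcast over heads and queries
]
for bias_shape, bias_reshape in cases:
    bias = np.random.randn(*bias_shape).astype("float32")
    assert (score + bias.reshape(*bias_reshape)).shape == score.shape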
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index e944b8d76ebe..e807082e3526 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2280,5 +2280,168 @@ def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), dtype="float32
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_attention():
+    # fmt: off
+    @tvm.script.ir_module
+    class Attention:
+        @R.function
+        def main(q: R.Tensor((4, 16, 32, 8), "float32"), k: R.Tensor((4, 8, 32, 8), "float32"), v: R.Tensor((4, 8, 32, 16), "float32"), bias: R.Tensor((4, 32, 16, 8), "float32")):
+            scale = T.FloatImm("float32", 0.1)
+            gv: R.Tensor((4, 16, 32, 16), "float32") = R.nn.attention(q, k, v, bias, scale)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def attention_bias(rxplaceholder: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(8)), "float32"), rxplaceholder_2: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(16)), "float32"), rxplaceholder_3: T.Buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)), "float32"), T_transpose: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(16)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            T_transpose_1 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)))
+            T_reshape = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_transpose_2 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(8), T.int64(8)))
+            T_reshape_1 = T.alloc_buffer((T.int64(128), T.int64(8), T.int64(8)))
+            T_batch_matmul_NT = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_multiply = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_reshape_2 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)))
+            T_add = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)))
+            T_reshape_3 = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_softmax_maxelem = T.alloc_buffer((T.int64(128), T.int64(16)))
+            T_softmax_exp = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_softmax_expsum = T.alloc_buffer((T.int64(128), T.int64(16)))
+            T_softmax_norm = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_transpose_3 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(8), T.int64(16)))
+            T_reshape_4 = T.alloc_buffer((T.int64(128), T.int64(8), T.int64(16)))
+            T_batch_matmul_NN = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(16)))
+            T_reshape_5 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(16)))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax2, v_ax1, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)])
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+                    T_reshape[v_ax0, v_ax1, v_ax2] = T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(8)):
+                with T.block("T_transpose_1"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder_1[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[v_ax0, v_ax2, v_ax1, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(8)):
+                with T.block("T_reshape_1"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
+                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2])
+                    T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
+            for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(8), T.int64(8)):
+                with T.block("T_batch_matmul_NT"):
+                    v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
+                    T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j, v_k])
+                    T.writes(T_batch_matmul_NT[v_b, v_i, v_j])
+                    T.block_attr({"layout_free_placeholders": [T_reshape_1]})
+                    with T.init():
+                        T_batch_matmul_NT[v_b, v_i, v_j] = T.float32(0)
+                    T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b, v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k]
+            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_batch_matmul_NT[v_ax0, v_ax1, v_ax2])
+                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+                    T_multiply[v_ax0, v_ax1, v_ax2] = T_batch_matmul_NT[v_ax0, v_ax1, v_ax2] * T.float32(0.10000000000000001)
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)):
+                with T.block("T_reshape_2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_multiply[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(8) + v_ax2) % T.int64(16), v_ax3 % T.int64(8)])
+                    T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(8) + v_ax2) % T.int64(16), v_ax3 % T.int64(8)]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_3[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] + rxplaceholder_3[v_ax0, v_ax1, v_ax2, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_reshape_3"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_add[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)])
+                    T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2])
+                    T_reshape_3[v_ax0, v_ax1, v_ax2] = T_add[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]
+            for i0, i1, k in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_softmax_maxelem"):
+                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                    T.reads(T_reshape_3[v_i0, v_i1, v_k])
+                    T.writes(T_softmax_maxelem[v_i0, v_i1])
+                    with T.init():
+                        T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38)
+                    T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], T_reshape_3[v_i0, v_i1, v_k])
+            for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_softmax_exp"):
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(T_reshape_3[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])
+                    T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
+                    T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(T_reshape_3[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
+            for i0, i1, k in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_softmax_expsum"):
+                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                    T.reads(T_softmax_exp[v_i0, v_i1, v_k])
+                    T.writes(T_softmax_expsum[v_i0, v_i1])
+                    with T.init():
+                        T_softmax_expsum[v_i0, v_i1] = T.float32(0)
+                    T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
+            for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_softmax_norm"):
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1])
+                    T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
+                    T.block_attr({"axis": 2})
+                    T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(16)):
+                with T.block("T_transpose_2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder_2[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[v_ax0, v_ax2, v_ax1, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(16)):
+                with T.block("T_reshape_4"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)])
+                    T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2])
+                    T_reshape_4[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)]
+            for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(16), T.int64(8)):
+                with T.block("T_batch_matmul_NN"):
+                    v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
+                    T.reads(T_softmax_norm[v_b, v_i, v_k], T_reshape_4[v_b, v_k, v_j])
+                    T.writes(T_batch_matmul_NN[v_b, v_i, v_j])
+                    T.block_attr({"layout_free_placeholders": [T_reshape_4]})
+                    with T.init():
+                        T_batch_matmul_NN[v_b, v_i, v_j] = T.float32(0)
+                    T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b, v_i, v_j] + T_softmax_norm[v_b, v_i, v_k] * T_reshape_4[v_b, v_k, v_j]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(16)):
+                with T.block("T_reshape_5"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(16) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(16) + v_ax2) % T.int64(16), v_ax3 % T.int64(16)])
+                    T.writes(T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(16) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(16) + v_ax2) % T.int64(16), v_ax3 % T.int64(16)]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(16), T.int64(32), T.int64(16)):
+                with T.block("T_transpose_3"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_reshape_5[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_5[v_ax0, v_ax2, v_ax1, v_ax3]
+
+        @R.function
+        def main(q: R.Tensor((4, 16, 32, 8), dtype="float32"), k: R.Tensor((4, 8, 32, 8), dtype="float32"), v: R.Tensor((4, 8, 32, 16), dtype="float32"), bias: R.Tensor((4, 32, 16, 8), dtype="float32")) -> R.Tensor((4, 16, 32, 16), dtype="float32"):
+            gv = R.call_tir(Expected.attention_bias, (q, k, v, bias), out_sinfo=R.Tensor((4, 16, 32, 16), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Attention)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
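Putting the pieces together, the legalization exercised by test_attention above can also be reached without TVMScript; a minimal sketch using the BlockBuilder (same shapes and the same 0.1 scale as the test, requiring only a TVM build with the Relax modules):

import tvm
from tvm import relax
from tvm.relax.transform import LegalizeOps
from tvm.tir import FloatImm

q = relax.Var("q", relax.TensorStructInfo((4, 16, 32, 8), "float32"))
k = relax.Var("k", relax.TensorStructInfo((4, 8, 32, 8), "float32"))
v = relax.Var("v", relax.TensorStructInfo((4, 8, 32, 16), "float32"))
bias = relax.Var("bias", relax.TensorStructInfo((4, 32, 16, 8), "float32"))

bb = relax.BlockBuilder()
with bb.function("main", [q, k, v, bias]):
    out = bb.emit(relax.op.nn.attention(q, k, v, bias, FloatImm("float32", 0.1)))
    bb.emit_func_output(out)

mod = LegalizeOps()(bb.get())
# After legalization, main calls a generated PrimFunc (named via
# primfunc_name_hint="attention_bias") through call_tir, as in Expected above.
print(mod)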