
Commit 23146d6

[Unity][Op] Expose scale in R.nn.attention and add its legalize op (#14412)
This PR exposes the custom scale in `R.nn.attention` and adds its legalize op.
1 parent cd3e107 commit 23146d6
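
For illustration, a minimal end-to-end sketch built on the APIs touched by this commit (assumed usage, not part of the diff; the shapes and the 0.5 scale are made-up example values):

from tvm import relax
from tvm.tir import FloatImm

bb = relax.BlockBuilder()
# BSNH layout: (batch, seq_len, num_head, head_dim); head_dim_v = 64 here.
q = relax.Var("q", relax.TensorStructInfo((4, 16, 8, 32), "float32"))
k = relax.Var("k", relax.TensorStructInfo((4, 16, 8, 32), "float32"))
v = relax.Var("v", relax.TensorStructInfo((4, 16, 8, 64), "float32"))
with bb.function("main", [q, k, v]):
    # scale is optional; None keeps the previous 1 / sqrt(head_dim) behaviour.
    out = bb.emit(relax.op.nn.attention(q, k, v, scale=FloatImm("float32", 0.5)))
    bb.emit_func_output(out)

# The new legalize rule lowers the call to a TE attention kernel.
legalized = relax.transform.LegalizeOps()(bb.get())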

File tree: 11 files changed, +337 −85 lines


include/tvm/relax/attrs/nn.h

Lines changed: 10 additions & 0 deletions
@@ -295,6 +295,16 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
   }
 };  // struct DropoutAttrs
 
+/*! \brief Attributes used in attention operator */
+struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
+  Optional<FloatImm> scale;
+
+  TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
+    TVM_ATTR_FIELD(scale).describe(
+        "The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).");
+  }
+};  // struct AttentionAttrs
+
 }  // namespace relax
 }  // namespace tvm

python/tvm/contrib/cutlass/attention_operation.py

Lines changed: 7 additions & 7 deletions
@@ -90,7 +90,7 @@ def instantiate_attention_template(attrs, func_args):
         p.head_dim_value = ${head_dim_value}; // H'
         p.num_queries = ${num_queries}; // S
         p.num_keys = ${num_keys}; // S'
-        p.scale = 1.0f / sqrt(float(${head_dim}));
+        p.scale = ${scale};
 
         // stride for N
         p.q_strideH = p.head_dim; // H
@@ -123,12 +123,12 @@ def instantiate_attention_template(attrs, func_args):
     CHECK(Attention::check_supported(p));
     kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
     """
-    if attrs["kSupportsBias"]:
-        template = substitute_template(
-            template, {"bias_template": bias_template[attrs["bias_layout"]]}
-        )
-    else:
-        template = substitute_template(template, {"bias_template": ""})
+
+    template = substitute_template(
+        template,
+        {"bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" in attrs else ""},
+    )
+
     for i, arg in enumerate(func_args):
         attrs["arg{}".format(i)] = arg
     return substitute_template(template, attrs)

python/tvm/contrib/cutlass/build.py

Lines changed: 8 additions & 1 deletion
@@ -786,7 +786,12 @@ def handle_matmul(self, f, op_type):
     def handle_attention(self, f, op_type):
         """Tune and annotate a dense op."""
         signature = _extract_relax_function_signature(f)
-
+        if _get_call_node(f.body, "relax.nn.attention") is not None:
+            op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs
+        elif _get_call_node(f.body, "relax.nn.attention_bias") is not None:
+            op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs
+        else:
+            raise ValueError(f"Cannot find call node for attention")
         q_shape = signature["arg0_shape"]
         k_shape = signature["arg1_shape"]
         v_shape = signature["arg2_shape"]
@@ -798,6 +803,7 @@ def handle_attention(self, f, op_type):
         num_batches, num_queries, num_heads, head_dim = q_shape
         _, num_keys, _, _ = k_shape
         _, _, _, head_dim_value = v_shape
+        scale = op_attrs.scale
         bias = {}
         if "arg3_dtype" in signature:
             bias["arg3_dtype"] = signature["arg3_dtype"]
@@ -821,6 +827,7 @@ def handle_attention(self, f, op_type):
            "num_heads": num_heads,
            "head_dim": head_dim,
            "head_dim_value": head_dim_value,
+           "scale": scale,
            "arch": self.options["sm"],
            **bias,
        }

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 10 additions & 1 deletion
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name
 """Common functions and classes for CUTLASS GEMM and Conv2d geneator."""
 import logging
+import math
 import multiprocessing
 import os
 import re
@@ -722,6 +723,12 @@ def get_batch_on_arg(arg_name, arg_shape):
             attrs["kKeysPerBlock"] = 64
             attrs["kSingleValueIteration"] = True
             attrs["output_size"] = b * s * n * h_v
+            attrs["scale"] = (
+                float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
+            )
+            assert (
+                attrs["scale"] > 0 or attrs["scale"] < 0
+            ), "Cutlass may generate nan occasionally when scale == 0.0"
             attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
             attrs["kSupportsDropout"] = False
             if len(func_args) > 3:
@@ -735,7 +742,9 @@ def get_batch_on_arg(arg_name, arg_shape):
                 else:
                     raise NotImplementedError()
             else:
-                attrs["kSupportsBias"] = False
+                # To support a negative scale with the current CUTLASS implementation,
+                # kSupportsBias must be set to true, otherwise the result contains NaNs.
+                attrs["kSupportsBias"] = attrs["scale"] < 0
             code = instantiate_attention_template(attrs, func_args)
             return CodegenResult(code, headers)

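In other words, the CUTLASS codegen above resolves the effective scale as 1 / sqrt(head_dim) when no custom scale is set, and rejects an exact zero. A small standalone sketch of that logic (resolve_scale is a hypothetical helper, not part of the diff):

import math

def resolve_scale(scale, head_dim):
    # Mirror the codegen default above: fall back to 1 / sqrt(head_dim)
    # when no custom scale was set on the op.
    resolved = 1.0 / math.sqrt(head_dim) if scale is None else float(scale)
    # A scale of exactly 0.0 is rejected because CUTLASS may occasionally
    # produce NaNs in that case (see the assert in the diff).
    assert resolved != 0.0
    return resolved

print(resolve_scale(None, 64))  # 0.125
print(resolve_scale(0.5, 64))   # 0.5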
python/tvm/contrib/cutlass/library.py

Lines changed: 3 additions & 1 deletion
@@ -20,7 +20,7 @@
 import enum
 from enum import auto as enum_auto
 
-from tvm.tir.expr import IntImm
+from tvm.tir.expr import IntImm, FloatImm
 
 
 class GeneratorTarget(enum.Enum):
@@ -147,6 +147,8 @@ def substitute_template(template, values):
     for key, value in values.items():
         if isinstance(value, (int, IntImm)):
            value = str(int(value))
+        if isinstance(value, (float, FloatImm)):
+            value = str(float(value))
        elif isinstance(value, bool):
            value = str(value).lower()
        regex = "\\$\\{%s\\}" % key

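The practical effect is that a float (or FloatImm) value such as the new scale can now be spliced into the generated C++ kernel. A rough plain-Python stand-in for substitute_template (illustrative only; the template line is the one changed in attention_operation.py above):

import re

def substitute_sketch(template, values):
    # Simplified stand-in: stringify each value and splice it over ${key}.
    for key, value in values.items():
        template = re.sub(r"\$\{%s\}" % key, str(value), template)
    return template

print(substitute_sketch("p.scale = ${scale};", {"scale": 0.125}))
# -> p.scale = 0.125;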
python/tvm/relax/op/nn/nn.py

Lines changed: 12 additions & 2 deletions
@@ -18,6 +18,7 @@
 from typing import List, Optional, Tuple, Union
 
 from tvm import DataType
+from tvm.tir import FloatImm
 
 from . import _ffi_api
 from ...expr import Expr
@@ -913,7 +914,13 @@ def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr:
     return _ffi_api.cross_entropy_with_logits(predictions, labels)  # type: ignore
 
 
-def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None) -> Expr:
+def attention(
+    query: Expr,
+    key: Expr,
+    value: Expr,
+    bias: Optional[Expr] = None,
+    scale: Optional[FloatImm] = None,
+) -> Expr:
     r"""Computes fused multi head attention.
 
     All input tensors are of 4-D tensors with BSNH layout.
@@ -943,10 +950,13 @@ def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None)
         (batch_size, num_head, seq_len, seq_len_kv),
         (batch_size, seq_len, seq_len_kv) or (batch_size, seq_len_kv).
 
+    scale: Optional[FloatImm]
+        The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).
+
     Returns
     -------
     result : relax.Expr
         The computed result. The layout of the output should be
         (batch_size, seq_len, num_head, head_dim_v).
     """
-    return _ffi_api.attention(query, key, value, bias)  # type: ignore
+    return _ffi_api.attention(query, key, value, bias, scale)  # type: ignore

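Both call forms are valid after this change; a short sketch assuming the relax Python API (example shapes only):

from tvm import relax
from tvm.tir import FloatImm

# q, k, v as in the docstring above: 4-D tensors in BSNH layout.
q = relax.Var("q", relax.TensorStructInfo((4, 16, 8, 32), "float32"))
k = relax.Var("k", relax.TensorStructInfo((4, 16, 8, 32), "float32"))
v = relax.Var("v", relax.TensorStructInfo((4, 16, 8, 64), "float32"))

attn_default = relax.op.nn.attention(q, k, v)  # scale=None -> 1 / sqrt(head_dim)
attn_scaled = relax.op.nn.attention(q, k, v, scale=FloatImm("float32", 0.5))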
python/tvm/relax/transform/legalize_ops/nn.py

Lines changed: 56 additions & 0 deletions
@@ -312,3 +312,59 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
 def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
     logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
     return call
+
+
+def _te_attention(
+    q: te.Tensor, k: te.Tensor, v: te.Tensor, bias: te.Tensor, scale: tir.FloatImm
+) -> te.Tensor:
+    batch_size, seq_len, num_head, head_dim = q.shape
+    _, seq_len_kv, _, head_dim_v = v.shape
+    q = topi.transpose(q, [0, 2, 1, 3])
+    k = topi.transpose(k, [0, 2, 1, 3])
+    v = topi.transpose(v, [0, 2, 1, 3])
+    q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim])
+    k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim])
+    v = topi.reshape(v, [batch_size * num_head, seq_len_kv, head_dim_v])
+    p = topi.nn.batch_matmul(q, k)
+    if scale is not None:
+        p = topi.multiply(p, scale)
+    else:
+        p = topi.divide(p, tir.sqrt(tir.Cast(p.dtype, head_dim)))
+    if bias is not None:
+        p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv])
+        if len(bias.shape) == 2:
+            bias = topi.reshape(bias, [batch_size, 1, 1, seq_len_kv])
+        elif len(bias.shape) == 3:
+            bias = topi.reshape(bias, [batch_size, 1, seq_len, seq_len_kv])
+        p = topi.add(p, bias)
+        p = topi.reshape(p, [batch_size * num_head, seq_len, seq_len_kv])
+    s = topi.nn.softmax(p)
+    o = topi.nn.batch_matmul(s, v, transpose_b=False)
+    o = topi.reshape(o, [batch_size, num_head, seq_len, head_dim_v])
+    return topi.transpose(o, [0, 2, 1, 3])
+
+
+@register_legalize("relax.nn.attention")
+def _nn_attention(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        _te_attention,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        None,
+        call.attrs.scale,
+        primfunc_name_hint="attention",
+    )
+
+
+@register_legalize("relax.nn.attention_bias")
+def _nn_attention_bias(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        _te_attention,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        call.args[3],
+        call.attrs.scale,
+        primfunc_name_hint="attention_bias",
+    )

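For cross-checking the TE lowering above, a hedged NumPy reference of the same computation (illustration only, not code from this commit; bias, if given, is assumed already broadcastable to (batch, num_head, seq_len, seq_len_kv)):

import numpy as np

def attention_ref(q, k, v, bias=None, scale=None):
    # q, k: (batch, seq_len[_kv], num_head, head_dim); v: (..., head_dim_v); BSNH layout.
    _, _, _, head_dim = q.shape
    if scale is None:
        scale = 1.0 / np.sqrt(head_dim)      # same default as _te_attention
    # BSNH -> BNSH so each head is an independent batched matmul.
    qt = q.transpose(0, 2, 1, 3)
    kt = k.transpose(0, 2, 1, 3)
    vt = v.transpose(0, 2, 1, 3)
    p = qt @ kt.transpose(0, 1, 3, 2) * scale   # scores, scaled before softmax
    if bias is not None:
        p = p + bias
    p = p - p.max(axis=-1, keepdims=True)       # numerically stable softmax
    s = np.exp(p)
    s = s / s.sum(axis=-1, keepdims=True)
    o = s @ vt
    return o.transpose(0, 2, 1, 3)              # back to BSNH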
src/relax/op/nn/attention.cc

Lines changed: 10 additions & 4 deletions
@@ -26,14 +26,18 @@ namespace tvm {
 namespace relax {
 
 /* relax.nn.attention */
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias) {
+TVM_REGISTER_NODE_TYPE(AttentionAttrs);
+
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale) {
+  ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
+  attrs->scale = scale;
   if (bias.defined()) {
     return Call(Op::Get("relax.nn.attention_bias"),
-                {std::move(query), std::move(key), std::move(value), std::move(bias.value())}, {},
-                {});
+                {std::move(query), std::move(key), std::move(value), std::move(bias.value())},
+                Attrs(attrs), {});
   }
   return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)},
-              {}, {});
+              Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention);
@@ -105,13 +109,15 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
 }
 
 TVM_REGISTER_OP("relax.nn.attention")
+    .set_attrs_type<AttentionAttrs>()
     .set_num_inputs(3)
     .add_argument("query", "Tensor", "The input queries tensor.")
     .add_argument("key", "Tensor", "The input keys tensor.")
     .add_argument("value", "Tensor", "The input values tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention);
 
 TVM_REGISTER_OP("relax.nn.attention_bias")
+    .set_attrs_type<AttentionAttrs>()
     .set_num_inputs(4)
     .add_argument("query", "Tensor", "The input queries tensor.")
     .add_argument("key", "Tensor", "The input keys tensor.")

src/relax/op/nn/attention.h

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ namespace tvm {
 namespace relax {
 
 /*! \brief fused multi head attention */
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias);
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale);
 
 }  // namespace relax
 }  // namespace tvm
