10 changes: 10 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -295,6 +295,16 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
}
}; // struct DropoutAttrs

/*! \brief Attributes used in attention operator */
struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
Optional<FloatImm> scale;

TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
TVM_ATTR_FIELD(scale).describe(
"The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).");
}
}; // struct AttentionAttrs

} // namespace relax
} // namespace tvm

14 changes: 7 additions & 7 deletions python/tvm/contrib/cutlass/attention_operation.py
@@ -90,7 +90,7 @@ def instantiate_attention_template(attrs, func_args):
p.head_dim_value = ${head_dim_value}; // H'
p.num_queries = ${num_queries}; // S
p.num_keys = ${num_keys}; // S'
p.scale = 1.0f / sqrt(float(${head_dim}));
p.scale = ${scale};

// stride for N
p.q_strideH = p.head_dim; // H
@@ -123,12 +123,12 @@ def instantiate_attention_template(attrs, func_args):
CHECK(Attention::check_supported(p));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
"""
if attrs["kSupportsBias"]:
template = substitute_template(
template, {"bias_template": bias_template[attrs["bias_layout"]]}
)
else:
template = substitute_template(template, {"bias_template": ""})

template = substitute_template(
template,
{"bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" in attrs else ""},
)

for i, arg in enumerate(func_args):
attrs["arg{}".format(i)] = arg
return substitute_template(template, attrs)
9 changes: 8 additions & 1 deletion python/tvm/contrib/cutlass/build.py
@@ -785,7 +785,12 @@ def handle_matmul(self, f, op_type):
def handle_attention(self, f, op_type):
"""Tune and annotate a dense op."""
signature = _extract_relax_function_signature(f)

if _get_call_node(f.body, "relax.nn.attention") is not None:
op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs
elif _get_call_node(f.body, "relax.nn.attention_bias") is not None:
op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs
else:
raise ValueError(f"Cannot find call node for attention")
q_shape = signature["arg0_shape"]
k_shape = signature["arg1_shape"]
v_shape = signature["arg2_shape"]
@@ -797,6 +802,7 @@ def handle_attention(self, f, op_type):
num_batches, num_queries, num_heads, head_dim = q_shape
_, num_keys, _, _ = k_shape
_, _, _, head_dim_value = v_shape
scale = op_attrs.scale
bias = {}
if "arg3_dtype" in signature:
bias["arg3_dtype"] = signature["arg3_dtype"]
@@ -820,6 +826,7 @@ def handle_attention(self, f, op_type):
"num_heads": num_heads,
"head_dim": head_dim,
"head_dim_value": head_dim_value,
"scale": scale,
"arch": self.options["sm"],
**bias,
}
11 changes: 10 additions & 1 deletion python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name
"""Common functions and classes for CUTLASS GEMM and Conv2d geneator."""
import logging
import math
import multiprocessing
import os
import re
@@ -722,6 +723,12 @@ def get_batch_on_arg(arg_name, arg_shape):
attrs["kKeysPerBlock"] = 64
attrs["kSingleValueIteration"] = True
attrs["output_size"] = b * s * n * h_v
attrs["scale"] = (
float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
)
assert (
attrs["scale"] > 0 or attrs["scale"] < 0
), "Cutlass may generate nan occasionally when scale == 0.0"
attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
attrs["kSupportsDropout"] = False
if len(func_args) > 3:
@@ -735,7 +742,9 @@ def get_batch_on_arg(arg_name, arg_shape):
else:
raise NotImplementedError()
else:
attrs["kSupportsBias"] = False
# The current CUTLASS implementation only supports a negative scale when
# kSupportsBias is set to true; otherwise the result contains NaNs.
attrs["kSupportsBias"] = attrs["scale"] < 0
code = instantiate_attention_template(attrs, func_args)
return CodegenResult(code, headers)

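As a rough illustration of the scale handling above, here is a hypothetical helper (assumed values, not code from this PR): when no scale is given, the default 1 / sqrt(head_dim) is used, the scale must be nonzero, and a negative scale routes through the bias-enabled kernel variant.

```python
# Hypothetical helper sketching the scale logic above; not part of this PR.
import math

def derive_scale_attrs(head_dim, user_scale=None):
    # Default to 1 / sqrt(head_dim) when no custom scale is provided.
    scale = 1.0 / math.sqrt(head_dim) if user_scale is None else user_scale
    assert scale != 0.0, "CUTLASS may occasionally generate NaN when scale == 0.0"
    # Negative scales only work through the bias-enabled kernel variant.
    return {"scale": scale, "kSupportsBias": scale < 0}

print(derive_scale_attrs(64))        # {'scale': 0.125, 'kSupportsBias': False}
print(derive_scale_attrs(64, -0.5))  # {'scale': -0.5, 'kSupportsBias': True}
```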
4 changes: 3 additions & 1 deletion python/tvm/contrib/cutlass/library.py
@@ -20,7 +20,7 @@
import enum
from enum import auto as enum_auto

from tvm.tir.expr import IntImm
from tvm.tir.expr import IntImm, FloatImm


class GeneratorTarget(enum.Enum):
@@ -147,6 +147,8 @@ def substitute_template(template, values):
for key, value in values.items():
if isinstance(value, (int, IntImm)):
value = str(int(value))
elif isinstance(value, (float, FloatImm)):
value = str(float(value))
elif isinstance(value, bool):
value = str(value).lower()
regex = "\\$\\{%s\\}" % key
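A quick sketch of what the float/FloatImm branch enables, assuming `substitute_template` is imported from this module; the printed values are illustrative:

```python
from tvm import tir
from tvm.contrib.cutlass.library import substitute_template

# Both a plain Python float and a tir.FloatImm now render into the template text.
print(substitute_template("p.scale = ${scale};", {"scale": 0.125}))
# p.scale = 0.125;
print(substitute_template("p.scale = ${scale};", {"scale": tir.FloatImm("float32", 0.125)}))
# p.scale = 0.125;
```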
14 changes: 12 additions & 2 deletions python/tvm/relax/op/nn/nn.py
@@ -18,6 +18,7 @@
from typing import List, Optional, Tuple, Union

from tvm import DataType
from tvm.tir import FloatImm

from . import _ffi_api
from ...expr import Expr
@@ -913,7 +914,13 @@ def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr:
return _ffi_api.cross_entropy_with_logits(predictions, labels) # type: ignore


def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None) -> Expr:
def attention(
query: Expr,
key: Expr,
value: Expr,
bias: Optional[Expr] = None,
scale: Optional[FloatImm] = None,
) -> Expr:
r"""Computes fused multi head attention.

All input tensors are 4-D tensors in BSNH layout.
@@ -943,10 +950,13 @@ def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None)
(batch_size, num_head, seq_len, seq_len_kv),
(batch_size, seq_len, seq_len_kv) or (batch_size, seq_len_kv).

scale: Optional[FloatImm]
The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).

Returns
-------
result : relax.Expr
The computed result. The layout of the output should be
(batch_size, seq_len, num_head, head_dim_v).
"""
return _ffi_api.attention(query, key, value, bias) # type: ignore
return _ffi_api.attention(query, key, value, bias, scale) # type: ignore
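A minimal usage sketch of the new `scale` parameter; the shapes and variable names below are illustrative assumptions, not taken from this PR:

```python
import tvm
from tvm import relax
from tvm.script import relax as R

q = relax.Var("q", R.Tensor((4, 16, 32, 8), "float32"))   # (batch, seq_len, num_head, head_dim)
k = relax.Var("k", R.Tensor((4, 16, 32, 8), "float32"))
v = relax.Var("v", R.Tensor((4, 16, 32, 16), "float32"))  # head_dim_v = 16

# Default behaviour: scale is None, so 1 / sqrt(head_dim) is applied.
out_default = relax.op.nn.attention(q, k, v)

# Custom scale passed as a FloatImm.
out_scaled = relax.op.nn.attention(q, k, v, scale=tvm.tir.FloatImm("float32", 0.1))
```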
56 changes: 56 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -312,3 +312,59 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
return call


def _te_attention(
q: te.Tensor, k: te.Tensor, v: te.Tensor, bias: te.Tensor, scale: tir.FloatImm
) -> te.Tensor:
batch_size, seq_len, num_head, head_dim = q.shape
_, seq_len_kv, _, head_dim_v = v.shape
q = topi.transpose(q, [0, 2, 1, 3])
k = topi.transpose(k, [0, 2, 1, 3])
v = topi.transpose(v, [0, 2, 1, 3])
q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim])
k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim])
v = topi.reshape(v, [batch_size * num_head, seq_len_kv, head_dim_v])
p = topi.nn.batch_matmul(q, k)
if scale is not None:
p = topi.multiply(p, scale)
else:
p = topi.divide(p, tir.sqrt(tir.Cast(p.dtype, head_dim)))
if bias is not None:
p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv])
if len(bias.shape) == 2:
bias = topi.reshape(bias, [batch_size, 1, 1, seq_len_kv])
elif len(bias.shape) == 3:
bias = topi.reshape(bias, [batch_size, 1, seq_len, seq_len_kv])
p = topi.add(p, bias)
p = topi.reshape(p, [batch_size * num_head, seq_len, seq_len_kv])
s = topi.nn.softmax(p)
o = topi.nn.batch_matmul(s, v, transpose_b=False)
o = topi.reshape(o, [batch_size, num_head, seq_len, head_dim_v])
return topi.transpose(o, [0, 2, 1, 3])


@register_legalize("relax.nn.attention")
def _nn_attention(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
_te_attention,
call.args[0],
call.args[1],
call.args[2],
None,
call.attrs.scale,
primfunc_name_hint="attention",
)


@register_legalize("relax.nn.attention_bias")
def _nn_attention_bias(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
_te_attention,
call.args[0],
call.args[1],
call.args[2],
call.args[3],
call.attrs.scale,
primfunc_name_hint="attention_bias",
)
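For reference, a NumPy sketch of what the legalized computation above evaluates, assuming a 4-D broadcastable bias; this is a simplified model of the semantics, not the TE implementation itself:

```python
import numpy as np

def attention_ref(q, k, v, bias=None, scale=None):
    # q, k: (B, S, N, H) / (B, S_kv, N, H); v: (B, S_kv, N, H_v), all in BSNH layout.
    b, s, n, h = q.shape
    _, s_kv, _, h_v = v.shape
    q = q.transpose(0, 2, 1, 3).reshape(b * n, s, h)
    k = k.transpose(0, 2, 1, 3).reshape(b * n, s_kv, h)
    v = v.transpose(0, 2, 1, 3).reshape(b * n, s_kv, h_v)
    p = q @ k.transpose(0, 2, 1)
    # Custom scale if given, otherwise the 1 / sqrt(head_dim) default.
    p = p * scale if scale is not None else p / np.sqrt(h)
    if bias is not None:
        p = (p.reshape(b, n, s, s_kv) + bias).reshape(b * n, s, s_kv)
    e = np.exp(p - p.max(axis=-1, keepdims=True))
    softmax = e / e.sum(axis=-1, keepdims=True)
    o = softmax @ v
    return o.reshape(b, n, s, h_v).transpose(0, 2, 1, 3)
```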
14 changes: 10 additions & 4 deletions src/relax/op/nn/attention.cc
@@ -26,14 +26,18 @@ namespace tvm {
namespace relax {

/* relax.nn.attention */
Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias) {
TVM_REGISTER_NODE_TYPE(AttentionAttrs);

Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale) {
ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
attrs->scale = scale;
if (bias.defined()) {
return Call(Op::Get("relax.nn.attention_bias"),
{std::move(query), std::move(key), std::move(value), std::move(bias.value())}, {},
{});
{std::move(query), std::move(key), std::move(value), std::move(bias.value())},
Attrs(attrs), {});
}
return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)},
{}, {});
Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention);
@@ -105,13 +109,15 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
}

TVM_REGISTER_OP("relax.nn.attention")
.set_attrs_type<AttentionAttrs>()
.set_num_inputs(3)
.add_argument("query", "Tensor", "The input queries tensor.")
.add_argument("key", "Tensor", "The input keys tensor.")
.add_argument("value", "Tensor", "The input values tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention);

TVM_REGISTER_OP("relax.nn.attention_bias")
.set_attrs_type<AttentionAttrs>()
.set_num_inputs(4)
.add_argument("query", "Tensor", "The input queries tensor.")
.add_argument("key", "Tensor", "The input keys tensor.")
2 changes: 1 addition & 1 deletion src/relax/op/nn/attention.h
@@ -33,7 +33,7 @@ namespace tvm {
namespace relax {

/*! \brief fused multi head attention */
Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias);
Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale);

} // namespace relax
} // namespace tvm