
Commit 850abb0

[TOPI] Add transpose_a/b & dynamic shape support for batch matmul (#8527)
* Add basic support for batch matmul transpose
* Update
* Lint fix & add tf convert support
* Update; lint fix
* Bug fix for qnn.batch_matmul
* Bug fix for tensorflow test
* Add grad support for batch_matmul
* Lint fix; bug fix; re-trigger CI
1 parent cb395ff commit 850abb0

24 files changed: 673 additions, 277 deletions

include/tvm/relay/attrs/nn.h

Lines changed: 12 additions & 2 deletions
@@ -1003,16 +1003,26 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
   }
 };
 
-/*! \brief Attributes for batch matmul operator */
+/*! \brief Attributes for batch matmul operator. */
 struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
-  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
   DataType out_dtype;
+  bool transpose_a;
+  bool transpose_b;
+  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
 
   TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {
     // use 0 bits to indicate none.
     TVM_ATTR_FIELD(out_dtype)
         .set_default(NullValue<DataType>())
         .describe("Output data type, set to explicit type under mixed precision setting");
+
+    TVM_ATTR_FIELD(transpose_a)
+        .set_default(false)
+        .describe("Whether the first input tensor is in transposed format.");
+
+    TVM_ATTR_FIELD(transpose_b)
+        .set_default(false)
+        .describe("Whether the second input tensor is in transposed format.");
   }
 };
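The two new flags only describe how the operands are laid out; the product itself is unchanged. A minimal NumPy sketch of the semantics these attributes encode (illustration only, not part of the patch):

import numpy as np

def batch_matmul_ref(a, b, transpose_a=False, transpose_b=True):
    # NT is the legacy default: a is (batch, M, K), b is (batch, N, K).
    a = a.transpose(0, 2, 1) if transpose_a else a   # bring a to (batch, M, K)
    b = b if transpose_b else b.transpose(0, 2, 1)   # bring b to (batch, N, K)
    return np.einsum("bmk,bnk->bmn", a, b)

a_nn = np.random.rand(2, 3, 4)      # (batch, M, K)
b_nn = np.random.rand(2, 4, 5)      # (batch, K, N)
b_nt = b_nn.transpose(0, 2, 1)      # (batch, N, K), the legacy NT layout
assert np.allclose(
    batch_matmul_ref(a_nn, b_nn, transpose_a=False, transpose_b=False),
    batch_matmul_ref(a_nn, b_nt),   # defaults: transpose_a=False, transpose_b=True
)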

python/tvm/relay/frontend/tensorflow.py

Lines changed: 17 additions & 6 deletions
@@ -52,6 +52,11 @@
     # However, please note that `nn.matmul` is in experimental so it may have some performance
     # issues.
     "use_dense": True,
+    # By default, TVM converts `tf.batch_matmul` to `transpose(weight) + nn.batch_matmul_NT`.
+    # Change this flag to False to directly convert to `nn.batch_matmul`.
+    # Note that `nn.batch_matmul` with a format other than NT is experimental; it may have some
+    # performance issues.
+    "use_nt_batch_matmul": True,
 }
 
 # compatible operators that do NOT require any conversion.
@@ -1214,7 +1219,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         return func, self._params
 
 
-def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op=True):
+def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, convert_config=None):
     """Load tensorflow graph which is a python tensorflow graph object into relay.
     The companion parameters will be handled automatically.
 
@@ -1232,10 +1237,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op
     outputs : List of output tensor names (Optional)
         if not specified then the last node is assumed as graph output.
 
-    use_dense_op : bool (Optional) = True
-        True to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`.
-        The `nn.dense` op requires the data tensor to be non-transposed and weight tensor to be
-        transposed, may insert extra `transpose` to the original graph.
+    convert_config : Optional[Dict[str, Any]]
+        Default config:
+            use_dense : bool = True
+                True to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`.
+                The `nn.dense` op requires the data tensor to be non-transposed and the weight
+                tensor to be transposed, which may insert extra `transpose` ops into the
+                original graph.
+            use_nt_batch_matmul : bool = True
+                True to convert `tf.batch_matmul` to `nn.batch_matmul` strictly in NT format
+                (transpose_a=False, transpose_b=True).
 
     Returns
     -------
@@ -1246,7 +1256,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op
         Dict of converted parameters stored in tvm.nd.NDArray format
     """
     global TF_DEFAULT_CONFIGS
-    TF_DEFAULT_CONFIGS["use_dense"] = use_dense_op
+    if convert_config is not None:
+        TF_DEFAULT_CONFIGS.update(convert_config)
 
     g = GraphProto()
     mod, params = g.from_tensorflow(graph, layout, shape, outputs)
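A usage sketch for the new convert_config hook; the tiny TF1-style graph below is only illustrative, while `convert_config` and `use_nt_batch_matmul` are the names added in this commit:

import tensorflow.compat.v1 as tf
from tvm import relay

tf.disable_eager_execution()
with tf.Graph().as_default() as graph:
    x = tf.placeholder(tf.float32, shape=(2, 3, 4), name="x")
    y = tf.placeholder(tf.float32, shape=(2, 5, 4), name="y")
    tf.linalg.matmul(x, y, adjoint_b=True, name="out")   # BatchMatMulV2 with adj_y=True
    graph_def = graph.as_graph_def()

# With use_nt_batch_matmul=False the adj_x/adj_y flags are forwarded as
# transpose_a/transpose_b instead of being rewritten into transpose + NT matmul.
mod, params = relay.frontend.from_tensorflow(
    graph_def,
    shape={"x": (2, 3, 4), "y": (2, 5, 4)},
    outputs=["out"],
    convert_config={"use_nt_batch_matmul": False},
)
print(mod)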

python/tvm/relay/frontend/tensorflow_ops.py

Lines changed: 12 additions & 3 deletions
@@ -1149,6 +1149,8 @@ def _impl(inputs, attr, params, mod):
 
 def _batch_matmul():
     def _impl(inputs, attr, params, mod):
+        from .tensorflow import TF_DEFAULT_CONFIGS
+
         input_x = inputs[0]
         input_y = inputs[1]
         orig_shape_x = _infer_shape(input_x, mod)
@@ -1185,9 +1187,16 @@ def _impl(inputs, attr, params, mod):
             input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1]))
         adj_x = attr["adj_x"]
         adj_y = attr["adj_y"]
-        input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
-        input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y
-        ret = get_relay_op("batch_matmul")(input_x, input_y)
+
+        if TF_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
+            # Strictly convert all batch_matmul to NT format
+            input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
+            input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y
+            ret = get_relay_op("batch_matmul")(input_x, input_y)
+        else:
+            ret = get_relay_op("batch_matmul")(
+                input_x, input_y, transpose_a=adj_x, transpose_b=adj_y
+            )
 
         # reshape result back to n-dimensional
         if ndim > 3:
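The legacy branch keeps the old behaviour: regardless of adj_x/adj_y, the second operand is first brought into NT layout. It relies on the identity below (NumPy, illustration only):

import numpy as np

a = np.random.rand(2, 3, 4)   # adj_x=False
b = np.random.rand(2, 4, 5)   # adj_y=False
direct = np.einsum("bik,bkj->bij", a, b)                  # what the non-NT branch expresses directly
nt = np.einsum("bik,bjk->bij", a, b.transpose(0, 2, 1))   # transpose(b) followed by NT-format matmul
assert np.allclose(direct, nt)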

python/tvm/relay/op/_tensor_grad.py

Lines changed: 52 additions & 4 deletions
@@ -590,11 +590,59 @@ def batch_matmul_grad(orig, grad):
     GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk
     """
     lhs, rhs = orig.args
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True):
+        # ki, jk -> ij
+        # jk, ij -> ki
+        # ij, ki -> jk
+        return [
+            collapse_sum_like(_nn.batch_matmul(rhs, grad, transpose_a=True, transpose_b=True), lhs),
+            collapse_sum_like(_nn.batch_matmul(grad, lhs, transpose_a=True, transpose_b=True), rhs),
+        ]
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False):
+        # ki, kj -> ij
+        # kj, ij -> ki
+        # ki, ij -> kj
+        return [
+            collapse_sum_like(
+                _nn.batch_matmul(rhs, grad, transpose_a=False, transpose_b=True), lhs
+            ),
+            collapse_sum_like(
+                _nn.batch_matmul(lhs, grad, transpose_a=False, transpose_b=False), rhs
+            ),
+        ]
+    if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True):
+        # ik, jk -> ij
+        # ij, jk -> ik
+        # ij, ik -> jk
+        # Keep using NT-format batch_matmul here to avoid introducing extra ops
+        # TODO(jcf94): Merge all cases into the normal batch_matmul when it is finally ready
+        return [
+            collapse_sum_like(
+                _nn.batch_matmul(
+                    grad,
+                    transpose(rhs, [0, 2, 1]),
+                    transpose_a=False,
+                    transpose_b=True,
+                ),
+                lhs,
+            ),
+            collapse_sum_like(
+                _nn.batch_matmul(
+                    transpose(grad, [0, 2, 1]),
+                    transpose(lhs, [0, 2, 1]),
+                    transpose_a=False,
+                    transpose_b=True,
+                ),
+                rhs,
+            ),
+        ]
+    # (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False)
+    # ik, kj -> ij
+    # ij, kj -> ik
+    # ik, ij -> kj
     return [
-        collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs),
-        collapse_sum_like(
-            _nn.batch_matmul(transpose(grad, [0, 2, 1]), transpose(lhs, [0, 2, 1])), rhs
-        ),
+        collapse_sum_like(_nn.batch_matmul(grad, rhs, transpose_a=False, transpose_b=True), lhs),
+        collapse_sum_like(_nn.batch_matmul(lhs, grad, transpose_a=True, transpose_b=False), rhs),
     ]
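A quick NumPy sanity check of the first rule above (transpose_a=True, transpose_b=True); the index comments translate directly into einsum strings (illustration only, not part of the patch):

import numpy as np

rng = np.random.default_rng(0)
batch, i, j, k = 2, 3, 4, 5
A = rng.standard_normal((batch, k, i))    # transposed first operand: "ki"
B = rng.standard_normal((batch, j, k))    # transposed second operand: "jk"
dC = rng.standard_normal((batch, i, j))   # upstream gradient of the "ij" output

# Gradient rules from the comments: jk, ij -> ki  and  ij, ki -> jk
dA = np.einsum("bjk,bij->bki", B, dC)
dB = np.einsum("bij,bki->bjk", dC, A)

# Check dA against a numerical gradient of loss = sum(C * dC)
def loss(A_):
    return np.sum(np.einsum("bki,bjk->bij", A_, B) * dC)

eps, num_dA = 1e-5, np.zeros_like(A)
it = np.nditer(A, flags=["multi_index"])
for _ in it:
    idx = it.multi_index
    Ap, Am = A.copy(), A.copy()
    Ap[idx] += eps
    Am[idx] -= eps
    num_dA[idx] = (loss(Ap) - loss(Am)) / (2 * eps)
assert np.allclose(dA, num_dA, atol=1e-5)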

python/tvm/relay/op/nn/_nn.py

Lines changed: 14 additions & 10 deletions
@@ -1276,24 +1276,28 @@ def dense_pack_shape_func(attrs, inputs, _):
 
 
 @script
-def _batch_matmul_shape_func(data_shape, weight_shape):
-    out = output_tensor((data_shape.shape[0],), "int64")
-    for i in const_range(out.shape[0] - 1):
-        if i == 0:
-            out[i] = max(data_shape[i], weight_shape[i])
-        else:
-            out[i] = data_shape[i]
-    out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2]
+def _batch_matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b):
+    out = output_tensor((tensor_a_shape.shape[0],), "int64")
+    out[0] = max(tensor_a_shape[0], tensor_b_shape[0])
+    out[1] = tensor_a_shape[2] if transpose_a else tensor_a_shape[1]
+    out[2] = tensor_b_shape[1] if transpose_b else tensor_b_shape[2]
 
     return out
 
 
 @reg.register_shape_func("nn.batch_matmul", False)
 def batch_matmul_shape_func(attrs, inputs, _):
     """
-    Shape function for dense op.
+    Shape function for batch matmul op.
     """
-    ret = [_batch_matmul_shape_func(inputs[0], inputs[1])]
+    ret = [
+        _batch_matmul_shape_func(
+            inputs[0],
+            inputs[1],
+            expr.IntImm("bool", attrs.transpose_a),
+            expr.IntImm("bool", attrs.transpose_b),
+        )
+    ]
     return ret
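The hybrid-script function now assumes rank-3 inputs and picks the free dimensions from the flags. A plain-Python mirror of that logic (illustration only):

def batch_matmul_out_shape(a_shape, b_shape, transpose_a=False, transpose_b=True):
    batch = max(a_shape[0], b_shape[0])               # batch dim may broadcast from 1
    m = a_shape[2] if transpose_a else a_shape[1]     # rows of the (possibly transposed) first input
    n = b_shape[1] if transpose_b else b_shape[2]     # cols of the (possibly transposed) second input
    return (batch, m, n)

assert batch_matmul_out_shape((8, 32, 64), (8, 16, 64)) == (8, 32, 16)                     # NT (default)
assert batch_matmul_out_shape((1, 32, 64), (8, 64, 16), transpose_b=False) == (8, 32, 16)  # NN, broadcast batch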

python/tvm/relay/op/nn/nn.py

Lines changed: 17 additions & 9 deletions
@@ -2137,32 +2137,40 @@ def group_norm(data, gamma, beta, num_groups, axis=1, epsilon=1e-5, center=True,
     return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale)
 
 
-def batch_matmul(x, y, out_dtype=""):
+def batch_matmul(tensor_a, tensor_b, out_dtype="", transpose_a=False, transpose_b=True):
     r"""
-    Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
-    in batch.
+    Compute batch matrix multiplication of `tensor_a` and `tensor_b`.
+
+    Both `tensor_a` and `tensor_b` can be transposed. For legacy reasons, the NT format
+    (transpose_a=False, transpose_b=True) is used by default.
 
     .. math::
 
-        \mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T)
+        \mbox{batch_matmul}(A, B)[i, :, :] = \mbox{matmul}(A[i, :, :], B[i, :, :])
 
     Parameters
     ----------
-    x : tvm.relay.Expr
+    tensor_a : tvm.relay.Expr
         The first input.
 
-    y : tvm.relay.Expr
+    tensor_b : tvm.relay.Expr
         The second input.
 
-    out_dtype : str, optional
-        Specifies the output data type for mixed precision batch matmul
+    out_dtype : Optional[str]
+        Specifies the output data type for mixed precision batch matmul.
+
+    transpose_a : Optional[bool] = False
+        Whether the first tensor is in transposed format.
+
+    transpose_b : Optional[bool] = True
+        Whether the second tensor is in transposed format.
 
     Returns
     -------
     result: tvm.relay.Expr
         The computed result.
     """
-    return _make.batch_matmul(x, y, out_dtype)
+    return _make.batch_matmul(tensor_a, tensor_b, out_dtype, transpose_a, transpose_b)
 
 
 # pylint: disable=no-else-return,inconsistent-return-statements
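A usage sketch of the extended Python API; the shapes are illustrative:

import tvm
from tvm import relay

a = relay.var("a", shape=(4, 16, 32), dtype="float32")       # (batch, M, K)
b_nn = relay.var("b_nn", shape=(4, 32, 8), dtype="float32")  # (batch, K, N)
b_nt = relay.var("b_nt", shape=(4, 8, 32), dtype="float32")  # (batch, N, K)

out_nn = relay.nn.batch_matmul(a, b_nn, transpose_a=False, transpose_b=False)
out_nt = relay.nn.batch_matmul(a, b_nt)   # defaults: transpose_a=False, transpose_b=True

for out in (out_nn, out_nt):
    mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(out), out))
    mod = relay.transform.InferType()(mod)
    print(mod["main"].ret_type)   # Tensor[(4, 16, 8), float32] in both cases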

python/tvm/relay/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
@@ -74,6 +74,11 @@ class DenseAttrs(Attrs):
     """Attributes for nn.dense"""
 
 
+@tvm._ffi.register_object("relay.attrs.BatchMatmulAttrs")
+class BatchMatmulAttrs(Attrs):
+    """Attributes for nn.batch_matmul"""
+
+
 @tvm._ffi.register_object("relay.attrs.SoftmaxAttrs")
 class SoftmaxAttrs(Attrs):
     """Attributes for nn.softmax"""

python/tvm/relay/op/strategy/cuda.py

Lines changed: 13 additions & 2 deletions
@@ -819,7 +819,13 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     """batch_matmul cuda strategy"""
     strategy = _op.OpStrategy()
     x, y = inputs
-    if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32":
+    if (
+        x.dtype == "int8"
+        and y.dtype == "int8"
+        and out_type.dtype == "int32"
+        and not attrs["transpose_a"]
+        and attrs["transpose_b"]
+    ):
         strategy.add_implementation(
             wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True),
             wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8),
@@ -840,7 +846,12 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
-    if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+    if (
+        target.kind.name == "cuda"
+        and nvcc.have_tensorcore(target=target)
+        and not attrs["transpose_a"]
+        and attrs["transpose_b"]
+    ):
         x, y = inputs
         _, M, K = get_const_tuple(x.shape)
         _, N, K = get_const_tuple(y.shape)

python/tvm/relay/op/strategy/generic.py

Lines changed: 3 additions & 2 deletions
@@ -799,10 +799,11 @@ def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, ne
 
     def _compute_batch_matmul(attrs, inputs, out_type):
         args = [inputs[0], inputs[1], out_type.shape]
+        args.append(out_type.dtype if need_out_dtype else None)
+        args.append(attrs.transpose_a)
+        args.append(attrs.transpose_b)
         if need_auto_scheduler_layout:
            args.append(get_auto_scheduler_rewritten_layout(attrs))
-        if need_out_dtype:
-            args.append(out_type.dtype)
         return [topi_compute(*args)]
 
     return _compute_batch_matmul
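After this change the wrapper always passes a dtype slot (None when no explicit out_dtype is needed), then the two flags, then the optional rewritten layout. A mock compute, purely hypothetical, to show the resulting positional order:

def mock_topi_batch_matmul(tensor_a, tensor_b, out_shape, out_dtype,
                           transpose_a, transpose_b, auto_scheduler_layout=""):
    # Positional order matches the args list built by _compute_batch_matmul above.
    print(out_shape, out_dtype, transpose_a, transpose_b, auto_scheduler_layout)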
