Commit 65ae3c9

[Unity][Transform] Implement relax.transform.ReorderTakeAfterMatmul
If `R.matmul(x, R.take(weights, indices))` occurs, with `R.take` selecting along the output feature dimension, it can be rearranged to `R.take(R.matmul(x, weights), indices)`.
1 parent 030ca4a commit 65ae3c9
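
As a quick numerical illustration of this rewrite (a minimal sketch, not part of the commit; shapes and values are arbitrary), NumPy confirms that taking along the output-feature axis commutes with the matmul:

import numpy as np

# Arbitrary illustrative shapes.
x = np.random.rand(4, 16).astype("float32")          # [batch, infeatures]
weights = np.random.rand(16, 64).astype("float32")   # [infeatures, table_size]
indices = np.random.randint(0, 64, size=32)          # [outfeatures]

# Before: gather the output columns, then matmul.
before = x @ weights.take(indices, axis=1)
# After: matmul against the full table, then gather the output columns.
after = (x @ weights).take(indices, axis=1)

np.testing.assert_allclose(before, after, rtol=1e-5)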

4 files changed: +366, −0 lines changed

python/tvm/relax/transform/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@
     RemovePurityChecking,
     RemoveUnusedParameters,
     RemoveUnusedOutputs,
+    ReorderTakeAfterMatmul,
     RewriteCUDAGraph,
     RewriteDataflowReshape,
     RunCodegen,

python/tvm/relax/transform/transform.py

Lines changed: 15 additions & 0 deletions
@@ -1302,6 +1302,21 @@ def ExpandMatmulOfSum():
     return _ffi_api.ExpandMatmulOfSum()  # type: ignore
 
 
+def ReorderTakeAfterMatmul():
+    """Reorder `matmul(x, take(weights, indices))` to `take(matmul(x, weights), indices)`
+
+    Useful for optimizing LoRA computations, where several LoRAs may
+    be batched together.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The corresponding pass.
+    """
+
+    return _ffi_api.ReorderTakeAfterMatmul()  # type: ignore
+
+
 def CombineParallelMatmul(check=None):
     """Combine multiple matmul operators sharing the same LHS matrix into one,
     followed by slicing. When all matmul branches in a tree have the same set of fused ops,
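
For reference, a minimal usage sketch of the new pass (the module below is illustrative only, not taken from this commit):

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R

# Illustrative module: R.take selects columns (output features) of the
# weight table before the matmul.
@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor([1, 16], "float32"),
        weight_table: R.Tensor([16, 64], "float32"),
        routing_table: R.Tensor([32], "int64"),
    ) -> R.Tensor([1, 32], "float32"):
        with R.dataflow():
            weight = R.take(weight_table, routing_table, axis=1)
            out = R.matmul(x, weight)
            R.output(out)
        return out

# After the pass, the matmul runs against the full weight table and the
# take is applied to its output instead.
rewritten = relax.transform.ReorderTakeAfterMatmul()(Module)
rewritten.show()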
Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relax/transform/reorder_take_after_matmul.cc
 * \brief Reorder `matmul(x, take(weights, indices))` to `take(matmul(x, weights), indices)`
 */

#include <tvm/relax/analysis.h>
#include <tvm/relax/dataflow_matcher.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

#include <optional>
#include <unordered_set>
#include <vector>

#include "../op/tensor/index.h"
#include "../op/tensor/linear_algebra.h"
#include "../op/tensor/manipulate.h"

namespace tvm {
namespace relax {

namespace {
std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns() {
  auto pat_lhs = WildcardPattern();

  auto pat_weights = WildcardPattern();
  auto pat_indices = WildcardPattern();
  auto pat_rhs = IsOp("relax.take")(pat_weights, pat_indices);

  auto pat_matmul = IsOp("relax.matmul")(pat_lhs, pat_rhs);

  auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr {
    auto lhs = matches[pat_lhs];
    auto weights = matches[pat_weights];
    auto indices = matches[pat_indices];

    const auto* take_call = matches[pat_rhs].as<CallNode>();
    ICHECK(take_call) << "InternalError: "
                      << "Match of relax.take operator should produce Call, "
                      << "but instead produces " << matches[pat_rhs] << " with type "
                      << matches[pat_rhs]->GetTypeKey();
    const auto* attrs = take_call->attrs.as<TakeAttrs>();
    ICHECK(attrs) << "InternalError: "
                  << "Attributes for relax.take operator should be TakeAttrs, "
                  << "but were instead " << take_call->attrs << " with type "
                  << take_call->GetTypeKey();

    const auto* lhs_sinfo = lhs->struct_info_.as<TensorStructInfoNode>();
    if (!lhs_sinfo) return expr;

    const auto* weights_sinfo = weights->struct_info_.as<TensorStructInfoNode>();
    if (!weights_sinfo) return expr;

    const auto* indices_sinfo = indices->struct_info_.as<TensorStructInfoNode>();
    if (!indices_sinfo) return expr;

    const auto* matmul_sinfo = expr->struct_info_.as<TensorStructInfoNode>();
    if (!matmul_sinfo) return expr;

    if (!attrs->axis.defined()) return expr;
    auto axis = attrs->axis.value()->value;

    if (lhs_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim() ||
        matmul_sinfo->IsUnknownNdim() || weights_sinfo->IsUnknownNdim())
      return expr;

    if (indices_sinfo->ndim == 1 && axis + 1 == weights_sinfo->ndim) {
      // Simpler case.  The activations may have batch dimensions, but
      // the weights do not.

      // lhs.shape = [*batch, infeatures]
      // weights.shape = [infeatures, table_size]
      // indices.shape = [outfeatures]

      // out_table.shape = [*batch, table_size]
      auto out_table = matmul(lhs, weights, DataType::Void());
      // new_output.shape = [*batch, outfeatures]
      auto new_output = take(out_table, indices, Integer(matmul_sinfo->ndim - 1));

      return new_output;
    } else if (lhs_sinfo->ndim == 3 && weights_sinfo->ndim == 3 && indices_sinfo->ndim == 1 &&
               axis == 0 && weights_sinfo->GetShape().defined() &&
               lhs_sinfo->GetShape().defined()) {
      // More complicated case, used for batched LoRA.  The conditions
      // on the argument dimensions can probably be relaxed, but would
      // probably need to remove the use of the einsum operator.

      auto lhs_shape = lhs_sinfo->GetShape().value();
      auto weight_shape = weights_sinfo->GetShape().value();

      // lhs.shape = [batch1, batch2, infeatures]
      // weights.shape = [table_size, infeatures, outfeatures]
      // indices.shape = [batch1]

      // reordered_weight.shape = [infeatures, table_size, outfeatures]
      auto reordered_weight = permute_dims(weights, Array{Integer(1), Integer(0), Integer(2)});
      // fused_weight.shape = [infeatures, table_size * outfeatures]
      auto fused_weight = reshape(reordered_weight,
                                  ShapeExpr({weight_shape[1], weight_shape[0] * weight_shape[2]}));
      // fused_output.shape = [batch1, batch2, table_size * outfeatures]
      auto fused_output = matmul(lhs, fused_weight, DataType::Void());
      // indexed_output.shape = [batch1, batch2, table_size, outfeatures]
      auto indexed_output = reshape(
          fused_output, ShapeExpr({lhs_shape[0], lhs_shape[1], weight_shape[0], weight_shape[2]}));

      // TODO(Lunderberg): Find a better way to express these last two
      // steps.  For an output at [i,j,k], the value is
      // `indexed_output[i, j, indices[i], k]`, but there doesn't seem
      // to be a good way to express that in relax.  It could be
      // written using `call_te`, but that would prevent later
      // optimizations from recognizing the high-level relax
      // operations.

      // duplicated_output.shape = [batch1, batch2, batch1, outfeatures]
      auto duplicated_output = take(indexed_output, indices, Integer(2));
      // new_output.shape = [batch1, batch2, outfeatures]
      auto new_output = einsum(Tuple({duplicated_output}), "ijik->ijk");

      return new_output;
    } else {
      return expr;
    }
  };

  return {pat_matmul, rewriter};
}

}  // namespace

namespace transform {
Pass ReorderTakeAfterMatmul() {
  auto pass_func = [=](Function func, IRModule mod, PassContext pc) {
    auto [pattern, rewriter] = CreatePatterns();
    return RewriteCall(pattern, rewriter, func);
  };
  return CreateFunctionPass(pass_func, 1, "ReorderTakeAfterMatmul", {});
}

TVM_REGISTER_GLOBAL("relax.transform.ReorderTakeAfterMatmul")
    .set_body_typed(ReorderTakeAfterMatmul);

}  // namespace transform
}  // namespace relax
}  // namespace tvm
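
To make the batched-LoRA branch above concrete, the following NumPy sketch (illustrative shapes only, not part of the commit) checks that the permute_dims/reshape/matmul/reshape/take/einsum sequence used by the rewriter reproduces the original take-then-matmul result:

import numpy as np

# Illustrative shapes: batch1 sequences, each routed to one of table_size
# LoRA weight matrices.
batch1, batch2, infeatures, outfeatures, table_size = 8, 1, 16, 32, 4

x = np.random.rand(batch1, batch2, infeatures).astype("float32")
weights = np.random.rand(table_size, infeatures, outfeatures).astype("float32")
indices = np.random.randint(0, table_size, size=batch1)

# Original form: gather each sequence's weight matrix, then matmul.
before = np.matmul(x, weights[indices])                    # [batch1, batch2, outfeatures]

# Rewritten form, mirroring the rewriter above.
reordered_weight = weights.transpose(1, 0, 2)              # [infeatures, table_size, outfeatures]
fused_weight = reordered_weight.reshape(infeatures, table_size * outfeatures)
fused_output = np.matmul(x, fused_weight)                  # [batch1, batch2, table_size * outfeatures]
indexed_output = fused_output.reshape(batch1, batch2, table_size, outfeatures)
duplicated_output = indexed_output.take(indices, axis=2)   # [batch1, batch2, batch1, outfeatures]
after = np.einsum("ijik->ijk", duplicated_output)          # [batch1, batch2, outfeatures]

np.testing.assert_allclose(before, after, rtol=1e-4)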
Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import inspect

import pytest

import tvm.testing
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T


class Base:
    def test_compare(self):
        transform = relax.transform.ReorderTakeAfterMatmul()

        if inspect.isclass(self.Expected) and issubclass(self.Expected, Exception):
            with pytest.raises(self.Expected):
                transform(self.Before)
        else:
            after = transform(self.Before)
            tvm.ir.assert_structural_equal(self.Expected, after)


class TestSimple(Base):
    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor([1, 16], "float32"),
            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
            routing_table: R.Tensor([32], "int64"),
        ) -> R.Tensor([1, 32], "float32"):
            weight_table_size = T.int64()
            with R.dataflow():
                weight: R.Tensor([16, 32], "float32") = R.take(weight_table, routing_table, axis=1)
                out: R.Tensor([1, 32], "float32") = R.matmul(x, weight)
                R.output(out)
            return out

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor([1, 16], "float32"),
            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
            routing_table: R.Tensor([32], "int64"),
        ) -> R.Tensor([1, 32], "float32"):
            weight_table_size = T.int64()
            with R.dataflow():
                out_table: R.Tensor([1, weight_table_size], "float32") = R.matmul(x, weight_table)
                out: R.Tensor([1, 32], "float32") = R.take(out_table, routing_table, axis=1)
                R.output(out)
            return out


class TestBatchedActivations(Base):
    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor(["batch_size", 1, 16], "float32"),
            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
            routing_table: R.Tensor([32], "int64"),
        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
            batch_size = T.int64()
            weight_table_size = T.int64()
            with R.dataflow():
                weight: R.Tensor([16, 32], "float32") = R.take(weight_table, routing_table, axis=1)
                out: R.Tensor([batch_size, 1, 32], "float32") = R.matmul(x, weight)
                R.output(out)
            return out

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor(["batch_size", 1, 16], "float32"),
            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
            routing_table: R.Tensor([32], "int64"),
        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
            batch_size = T.int64()
            weight_table_size = T.int64()
            with R.dataflow():
                out_table: R.Tensor([batch_size, 1, weight_table_size], "float32") = R.matmul(
                    x, weight_table
                )
                out: R.Tensor([batch_size, 1, 32], "float32") = R.take(
                    out_table, routing_table, axis=2
                )
                R.output(out)
            return out


class TestStaticBatchedActivationsAndWeights(Base):
    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor([128, 1, 16], "float32"),
            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
            routing_table: R.Tensor([128], "int64"),
        ) -> R.Tensor([128, 1, 32], "float32"):
            batch_size = T.int64()
            routing_table_size = T.int64()
            with R.dataflow():
                weight = R.take(weight_table, routing_table, axis=0)
                out = R.matmul(x, weight)
                R.output(out)
            return out

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor([128, 1, 16], "float32"),
            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
            routing_table: R.Tensor([128], "int64"),
        ) -> R.Tensor([128, 1, 32], "float32"):
            batch_size = T.int64()
            routing_table_size = T.int64()
            with R.dataflow():
                reordered_weight = R.permute_dims(weight_table, [1, 0, 2])
                fused_weight = R.reshape(reordered_weight, [16, routing_table_size * 32])
                fused_output = R.matmul(x, fused_weight)
                reordered_output = R.reshape(fused_output, [128, 1, routing_table_size, 32])
                tabular_output = R.take(reordered_output, routing_table, axis=2)
                out = R.einsum([tabular_output], "ijik->ijk")
                R.output(out)
            return out


class TestDynamicBatchedActivationsAndWeights(Base):
    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor(["batch_size", 1, 16], "float32"),
            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
            routing_table: R.Tensor(["batch_size"], "int64"),
        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
            batch_size = T.int64()
            routing_table_size = T.int64()
            with R.dataflow():
                weight = R.take(weight_table, routing_table, axis=0)
                out = R.matmul(x, weight)
                R.output(out)
            return out

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor(["batch_size", 1, 16], "float32"),
            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
            routing_table: R.Tensor(["batch_size"], "int64"),
        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
            batch_size = T.int64()
            routing_table_size = T.int64()
            with R.dataflow():
                reordered_weight = R.permute_dims(weight_table, [1, 0, 2])
                fused_weight = R.reshape(reordered_weight, [16, routing_table_size * 32])
                fused_output = R.matmul(x, fused_weight)
                reordered_output = R.reshape(fused_output, [batch_size, 1, routing_table_size, 32])
                tabular_output = R.take(reordered_output, routing_table, axis=2)
                out = R.einsum([tabular_output], "ijik->ijk")
                R.output(out)
            return out


if __name__ == "__main__":
    tvm.testing.main()
