Commit b581575
[Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat (#16596)
* [Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat

This commit implements an optional optimization pass
`relax.transform.ReorderPermuteDimsAfterConcat`, which reorders expressions of
the form `R.concat(R.permute_dims(A), R.permute_dims(B))` into
`R.permute_dims(R.concat(A, B))`. This pass is intended to be used alongside
`CombineParallelMatmul`: after parallel matmuls are combined, it allows the
`R.permute_dims` to be lifted out of the `R.concat`, and optimized `nn.Linear`
kernels to find the `R.matmul(x, R.permute_dims(weights))` patterns they are
looking for.

```python
@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """Initial IRModule

    The `R.permute_dims` followed by `R.matmul` is the relax
    equivalent of `nn.Linear`, and will frequently have optimized
    kernels.
    """
    weight_query_T = R.permute_dims(weight_query)
    query = R.matmul(x, weight_query_T)

    weight_key_T = R.permute_dims(weight_key)
    key = R.matmul(x, weight_key_T)

    weight_value_T = R.permute_dims(weight_value)
    value = R.matmul(x, weight_value_T)


@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `CombineParallelMatmul`

    There's now only a single matmul to be performed, which is
    generally better than performing three small matmuls.  However,
    the optimized kernels for `nn.Linear` can no longer be applied,
    because the `R.concat` isn't part of the expected pattern.
    """
    weight_query_T = R.permute_dims(weight_query)
    weight_key_T = R.permute_dims(weight_key)
    weight_value_T = R.permute_dims(weight_value)
    fused_weight_T = R.concat([weight_query_T, weight_key_T, weight_value_T], axis=1)
    fused_qkv = R.matmul(x, fused_weight_T)
    query, key, value = R.split(fused_qkv)


@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `ReorderPermuteDimsAfterConcat`

    There's still only a single matmul, and the optimized kernels for
    `nn.Linear` can be applied again.
    """
    fused_weight = R.concat([weight_query, weight_key, weight_value], axis=0)
    fused_weight_T = R.permute_dims(fused_weight)
    fused_qkv = R.matmul(x, fused_weight_T)
    query, key, value = R.split(fused_qkv)
```

* Expand description of `max_concat` variable as a temporary solution
1 parent 84b3f69 commit b581575
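Since the pass is designed to run after `CombineParallelMatmul`, here is a minimal sketch of the intended pipeline. The pass ordering comes from the commit message; the surrounding code is illustrative, and `mod` is assumed to be an existing relax `IRModule`.

```python
# Minimal pipeline sketch: fuse parallel matmuls, then lift permute_dims
# back out of the concat so nn.Linear-style kernels can match again.
# Assumes `mod` is an existing relax IRModule.
import tvm
from tvm import relax

seq = tvm.ir.transform.Sequential(
    [
        relax.transform.CombineParallelMatmul(),
        relax.transform.ReorderPermuteDimsAfterConcat(),
    ]
)
mod = seq(mod)
```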

File tree

4 files changed
+472 −0 lines changed

python/tvm/relax/transform/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -63,6 +63,7 @@
     RemovePurityChecking,
     RemoveUnusedParameters,
     RemoveUnusedOutputs,
+    ReorderPermuteDimsAfterConcat,
     ReorderTakeAfterMatmul,
     RewriteCUDAGraph,
     RewriteDataflowReshape,
```
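With this export in place, the pass is importable alongside the other relax transforms; a minimal sketch:

```python
# Minimal sketch: the newly exported pass is available from the
# transform namespace like its siblings.
from tvm.relax.transform import ReorderPermuteDimsAfterConcat

reorder_pass = ReorderPermuteDimsAfterConcat()
print(reorder_pass.info.name)  # -> "ReorderPermuteDimsAfterConcat"
```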

python/tvm/relax/transform/transform.py

Lines changed: 20 additions & 0 deletions

```diff
@@ -1325,6 +1325,26 @@ def ExpandMatmulOfSum():
     return _ffi_api.ExpandMatmulOfSum()  # type: ignore


+def ReorderPermuteDimsAfterConcat():
+    """Reorder `concat(permute_dims(A), permute_dims(B))` into `permute_dims(concat(A,B))`
+
+    Useful for optimizing computations after `CombineParallelMatmul`.
+    The patterns for optimized `nn.Linear` implementations look for
+    `matmul(activations, permute_dims(weights))`.  After
+    `CombineParallelMatmul`, the `matmul(activations,
+    concat(permute_dims(A), permute_dims(B)))` no longer matches this
+    pattern.  Rearranging into `matmul(activations,
+    permute_dims(concat(A,B)))` restores the pattern match.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The corresponding pass.
+    """
+
+    return _ffi_api.ReorderPermuteDimsAfterConcat()  # type: ignore
+
+
 def ReorderTakeAfterMatmul():
     """Reorder `matmul(x, take(weights, indices))` to `take(matmul(x,weights),indices)`

```
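As a hedged sketch of the Python API added above, the example below builds a module containing a concat of two transposed tensors and runs the pass on it. The module and names are illustrative, not taken from the commit's tests.

```python
# Illustrative example: concat(permute_dims(a), permute_dims(b), axis=1)
# is rewritten to permute_dims(concat(a, b, axis=0)).
import tvm
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Before:
    @R.function
    def main(
        a: R.Tensor((16, 32), "float32"),
        b: R.Tensor((16, 32), "float32"),
    ) -> R.Tensor((32, 32), "float32"):
        with R.dataflow():
            a_t = R.permute_dims(a)
            b_t = R.permute_dims(b)
            out = R.concat((a_t, b_t), axis=1)
            R.output(out)
        return out


After = relax.transform.ReorderPermuteDimsAfterConcat()(Before)
After.show()  # single concat along axis=0, followed by one permute_dims
```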
src/relax/transform/reorder_permute_dims_after_concat.cc

Lines changed: 187 additions & 0 deletions (new file)
```cpp
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relax/transform/reorder_permute_dims_after_concat.cc
 * \brief Reorder concat(permute_dims(A), permute_dims(B)) into permute_dims(concat(A,B))
 */

#include <tvm/relax/analysis.h>
#include <tvm/relax/dataflow_matcher.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

#include <optional>
#include <unordered_set>
#include <vector>

#include "../op/tensor/index.h"
#include "../op/tensor/linear_algebra.h"
#include "../op/tensor/manipulate.h"

namespace tvm {
namespace relax {

namespace {
std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns() {
  // TODO(Lunderberg): Allow pattern-matching to handle a flexible
  // number of arguments, each of which matches the same type of
  // pattern.
  //
  // Because we instantiate one DFPattern for each value in
  // `min_concat <= i <= max_concat`, we don't want to set
  // `max_concat` to an extremely high value.  The current value of 12
  // was chosen to be significantly higher than the highest value
  // required so far (3, for query/key/value in attention layers), but
  // not so high that it requires an excessive number of `DFPattern`.
  //
  // This value is deliberately *NOT* exposed, as `max_concat` may be
  // increased at any point that it is required, and other use cases
  // should not depend on its value.  If there is a use case that
  // requires more matmuls to be handled, and pattern-matching does
  // not yet support a flexible number of `Tuple` elements,
  // `max_concat` should be increased.
  size_t min_concat = 2;
  size_t max_concat = 12;

  std::vector<DFPattern> pat_args;
  std::vector<DFPattern> pat_permute_dims;
  for (size_t i = 0; i < max_concat; i++) {
    auto arg = WildcardPattern();
    pat_args.push_back(arg);
    pat_permute_dims.push_back(IsOp("relax.permute_dims")(arg));
  }

  auto make_pattern_with_num_concat = [&](size_t num_concat) -> DFPattern {
    ICHECK_LT(num_concat, pat_permute_dims.size());
    auto concat_tuple = TuplePattern(
        Array<DFPattern>(pat_permute_dims.begin(), pat_permute_dims.begin() + num_concat));
    return IsOp("relax.concat")(concat_tuple);
  };

  DFPattern pat_concat = make_pattern_with_num_concat(min_concat);
  for (size_t i = min_concat + 1; i < max_concat; i++) {
    pat_concat = pat_concat | make_pattern_with_num_concat(i);
  }

  auto get_permute_dims_optional_axes = [](const Expr& expr) -> Optional<Array<Integer>> {
    auto call = expr.as<CallNode>();
    ICHECK(call);
    auto attrs = call->attrs.as<PermuteDimsAttrs>();
    ICHECK(attrs);

    return attrs->axes;
  };

  auto get_permute_dims_axes =
      [get_permute_dims_optional_axes](const Expr& expr) -> Array<Integer> {
    if (auto opt_axes = get_permute_dims_optional_axes(expr)) {
      return opt_axes.value();
    } else {
      auto call = Downcast<Call>(expr);
      Array<Integer> permutation;
      auto arg_sinfo = call->args[0]->struct_info_.as<TensorStructInfoNode>();
      CHECK(arg_sinfo) << "Expected permute_dims to have a single tensor argument, "
                       << "but argument " << call->args[0] << " has struct info "
                       << call->args[0]->struct_info_;
      CHECK_GE(arg_sinfo->ndim, 0);
      size_t ndim = arg_sinfo->ndim;
      for (size_t i = 0; i < ndim; i++) {
        permutation.push_back(Integer(ndim - i - 1));
      }
      return permutation;
    }
  };

  auto permute_dims_axes_are_compatible = [&](const Array<Expr>& permute_dims) -> bool {
    auto first_axes = get_permute_dims_axes(permute_dims[0]);
    for (size_t i_arg = 1; i_arg < permute_dims.size(); i_arg++) {
      auto i_axes = get_permute_dims_axes(permute_dims[i_arg]);
      if (i_axes.size() != first_axes.size()) {
        return false;
      }
      for (size_t i_axis = 0; i_axis < first_axes.size(); i_axis++) {
        if (i_axes[i_axis]->value != first_axes[i_axis]->value) {
          return false;
        }
      }
    }
    return true;
  };

  auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr {
    Array<Expr> args;
    Array<Expr> all_permute_dims;
    for (size_t i = 0; i < max_concat; i++) {
      if (auto permute_dim_expr = matches.Get(pat_permute_dims[i])) {
        all_permute_dims.push_back(permute_dim_expr.value());
        args.push_back(matches[pat_args[i]]);
      }
    }

    ICHECK_GE(all_permute_dims.size(), min_concat)
        << "InternalError: "
        << "Pattern match should return at least " << min_concat << " items, but only found "
        << all_permute_dims.size() << ": " << all_permute_dims;

    if (!permute_dims_axes_are_compatible(all_permute_dims)) {
      return expr;
    }
    Optional<Array<Integer>> permute_axes = get_permute_dims_optional_axes(all_permute_dims[0]);

    Call concat_call = Downcast<Call>(matches[pat_concat]);
    auto concat_attrs = concat_call->attrs.as<ConcatAttrs>();
    ICHECK(concat_attrs);

    auto old_concat_axis = [&]() -> size_t {
      if (concat_attrs->axis.defined()) {
        return concat_attrs->axis.value()->value;
      } else {
        return 0;
      }
    }();
    Integer new_concat_axis = get_permute_dims_axes(all_permute_dims[0])[old_concat_axis];

    auto new_concat = concat(Tuple(args), new_concat_axis);
    auto new_permute_dims = permute_dims(new_concat, permute_axes);

    return new_permute_dims;
  };

  return {pat_concat, rewriter};
}

}  // namespace

namespace transform {
Pass ReorderPermuteDimsAfterConcat() {
  auto pass_func = [=](Function func, IRModule mod, PassContext pc) {
    auto [pattern, rewriter] = CreatePatterns();
    return RewriteCall(pattern, rewriter, func);
  };
  return CreateFunctionPass(pass_func, 1, "ReorderPermuteDimsAfterConcat", {});
}

TVM_REGISTER_GLOBAL("relax.transform.ReorderPermuteDimsAfterConcat")
    .set_body_typed(ReorderPermuteDimsAfterConcat);

}  // namespace transform
}  // namespace relax
}  // namespace tvm
```
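The rewriter's axis remapping (`new_concat_axis = axes[old_concat_axis]`) works because dimension `i` of `permute_dims(A, axes)` is dimension `axes[i]` of `A`, so concatenating transposed tensors along axis `k` equals transposing the concatenation along axis `axes[k]`. A quick numpy check of this identity (illustrative, not part of the commit):

```python
# Sanity check of the identity the rewriter relies on:
#   concat(transpose(A, p), transpose(B, p), axis=k)
#     == transpose(concat(A, B, axis=p[k]), p)
import numpy as np

a = np.random.rand(4, 5, 6)
b = np.random.rand(4, 5, 6)
p = (2, 0, 1)                         # an arbitrary permutation
old_concat_axis = 1
new_concat_axis = p[old_concat_axis]  # the remapping used by the pass

lhs = np.concatenate([a.transpose(p), b.transpose(p)], axis=old_concat_axis)
rhs = np.concatenate([a, b], axis=new_concat_axis).transpose(p)
assert np.array_equal(lhs, rhs)
```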
