Commit 89cc09c

[Unity][Transform] Handle dynamic shapes in CombineParallelMatmul (#16591)
* [Unity][Transform] Handle dynamic shapes in CombineParallelMatmul

  Prior to this commit, if the weight of a matmul had a dynamic shape and that
  matmul was being combined by the `CombineParallelMatmul` pass, the pass could
  segfault when `dim.as<IntImmNode>()` returned a null pointer. This commit adds
  explicit test cases for these dynamic shapes and updates `CombineParallelMatmul`
  to handle them.

* Add Tuple constructor for PR-16589
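For context, a minimal sketch of the scenario this commit addresses, modeled on the test cases it adds (the shapes and variable names mirror those tests; `mod.show()` is only for inspecting the result and is not part of the commit):

```python
# Sketch modeled on the tests added in this commit: two parallel matmuls
# sharing the same LHS, one weight with a dynamic output width "M".
import tvm
from tvm.relax.transform import CombineParallelMatmul
from tvm.script import relax as R, tir as T


@R.function(private=True)
def before(
    x: R.Tensor((2, 1024, 640), "float32"),
    w0: R.Tensor((640, "M"), "float32"),  # dynamic output width
    w1: R.Tensor((640, 640), "float32"),  # static output width
):
    M = T.int64()
    with R.dataflow():
        lv0 = R.matmul(x, w0)
        lv1 = R.matmul(x, w1)
        out = (lv0, lv1)
        R.output(out)
    return out


# Before this commit, combining these branches could dereference a null
# IntImmNode pointer; afterwards the pass emits a single matmul + R.split.
mod = CombineParallelMatmul()(tvm.IRModule.from_expr(before))
mod.show()
```

Running the pass on a module like this previously hit the null `IntImmNode` dereference; with this commit the two branches are merged into one matmul followed by an `R.split` whose only dynamic chunk is placed last.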
1 parent 864fd5c commit 89cc09c

3 files changed: 240 additions, 61 deletions


include/tvm/relax/expr.h

Lines changed: 18 additions & 0 deletions
@@ -320,6 +320,24 @@ class Tuple : public Expr {
    */
   TVM_DLL explicit Tuple(tvm::Array<Expr> fields, Span span = Span());
 
+  /*!
+   * \brief Utility constructor to handle conversion to relax::Expr
+   *
+   * If the calling scope already has an array of a specific type of
+   * relax expression (e.g. `Array<relax::Var>`), it must be converted
+   * into an array of base type. This constructor handles the
+   * conversion to the base `Array<relax::Expr>`.
+   *
+   * \tparam RelaxExpr The type of relax expression passed in as an argument.
+   *
+   * \param fields The fields of a tuple.
+   *
+   * \param span The source span of the expression.
+   */
+  template <typename RelaxExpr, typename = std::enable_if_t<std::is_base_of_v<Expr, RelaxExpr>>>
+  TVM_DLL explicit Tuple(tvm::Array<RelaxExpr> fields, Span span = Span())
+      : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {}
+
   TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
 };

src/relax/transform/combine_parallel_matmul.cc

Lines changed: 100 additions & 60 deletions
@@ -71,7 +71,16 @@ struct Patterns {
   WildcardPattern input;
   std::vector<WildcardPattern> rhs;
   std::vector<WildcardPattern> bias;
-  std::vector<CallPattern> matmul, bias_add, activation;
+  std::vector<CallPattern> matmul;
+  std::vector<CallPattern> bias_add;
+  std::vector<CallPattern> activation;
+};
+
+struct SplitInfo {
+  Var rhs;
+  Optional<Var> bias;
+  PrimExpr split_size;
+  DFPattern pattern_to_replace;
 };
 
 Patterns CreatePatterns(const BranchInfo& branch_info) {
@@ -140,40 +149,68 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> Ge
     for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) {
       if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, rhs_shapes)) continue;
 
-      auto inp = matchings[patterns.input];
+      auto lhs = matchings[patterns.input];
+
+      const auto& patterns_to_replace = [&patterns, &branch_info]() {
+        if (branch_info.activation) return patterns.activation;
+        if (branch_info.bias_dim) return patterns.bias_add;
+        return patterns.matmul;
+      }();
 
-      Array<Var> rhs, bias;
-      for (auto ind : indices) {
-        rhs.push_back(matchings[patterns.rhs[ind]]);
-        if (branch_info.bias_dim) {
-          ICHECK(matchings.count(patterns.bias[ind]));
-          bias.push_back(matchings[patterns.bias[ind]]);
+      std::vector<SplitInfo> splits;
+      for (auto index : indices) {
+        Var rhs = matchings[patterns.rhs[index]];
+        Optional<Var> bias = NullOpt;
+        if (branch_info.bias_dim.has_value()) {
+          bias = matchings[patterns.bias[index]];
         }
+        PrimExpr split_size = GetTensorSInfo(rhs)->GetShape().value()[rhs_dim - 1];
+        DFPattern pattern_to_replace = patterns_to_replace[index];
+        splits.push_back(SplitInfo{rhs, bias, split_size, pattern_to_replace});
+      }
+      // At most one dynamic output shape can be part of the combined
+      // matmul, and it must be the last item in the split. Use
+      // `std::stable_sort` instead of `std::sort` to maintain a
+      // consistent order for all static shapes, and to consistently
+      // select the same dynamic weight to participate.
+      auto is_dynamic_split = [](const SplitInfo& split) -> bool {
+        return !split.split_size->IsInstance<IntImmNode>();
+      };
+      std::stable_sort(splits.begin(), splits.end(),
+                       [&is_dynamic_split](const auto& a, const auto& b) {
+                         return is_dynamic_split(a) < is_dynamic_split(b);
+                       });
+      // Remove anything after the first dynamic shape participating
+      // in the combined matmul.
+      if (auto it = std::find_if(splits.begin(), splits.end(), is_dynamic_split);
+          it != splits.end()) {
+        splits.erase(it + 1, splits.end());
       }
 
-      if (!check(inp, rhs, bias, bindings)) {
+      if (splits.size() == 1) {
        continue;
       }
 
-      auto make_tuple = [](const Array<Var>& var_array) {
-        Array<Expr> exp_array;
-        for (auto v : var_array) exp_array.push_back(v);
-        return Tuple(exp_array);
-      };
+      Array<Var> rhs;
+      Array<Var> bias;
+      for (const auto& split : splits) {
+        rhs.push_back(split.rhs);
+        if (split.bias) {
+          bias.push_back(split.bias.value());
+        }
+      }
 
-      auto concat_rhs = concat(make_tuple(rhs), Integer(rhs_dim - 1));
-      auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
-      auto matmul_combined = matmul(inp, concat_rhs, out_dtype);
+      if (!check(lhs, rhs, bias, bindings)) {
+        continue;
+      }
 
-      const auto& pattern_to_replace = [&patterns, &branch_info]() {
-        if (branch_info.activation) return patterns.activation;
-        if (branch_info.bias_dim) return patterns.bias_add;
-        return patterns.matmul;
-      }();
+      auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1));
+      auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
+      auto matmul_combined = matmul(lhs, concat_rhs, out_dtype);
 
       if (branch_info.bias_dim) {
         auto bias_dim = GetTensorSInfo(bias[0])->ndim;
-        auto concat_bias = concat(make_tuple(bias), Integer(bias_dim - 1));
+        auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1));
         matmul_combined = add(matmul_combined, concat_bias);
       }
 
@@ -191,20 +228,23 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> Ge
         }
       }
 
-      int ind = 0;
+      int split_index = 0;
       Array<IntImm> sections;
-      for (int i = 0; i < static_cast<int>(indices.size()) - 1; ++i) {
-        auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1].as<IntImmNode>();
-        ind += width->value;
-        sections.push_back(IntImm(DataType::Int(64), ind));
+      for (size_t i = 0; i + 1 < splits.size(); i++) {
+        auto width = splits[i].split_size.as<IntImmNode>();
+        ICHECK(width) << "InternalError: "
+                      << "All splits except the last one must have a static shape";
+        split_index += width->value;
+        sections.push_back(IntImm(DataType::Int(64), split_index));
       }
 
-      int lhs_dim = GetTensorSInfo(inp)->ndim;
+      int lhs_dim = GetTensorSInfo(lhs)->ndim;
       int split_axis = std::max<int>(lhs_dim, rhs_dim) - 1;
       auto chunks = split(matmul_combined, sections, split_axis);
 
-      for (size_t i = 0; i < indices.size(); ++i) {
-        auto bound_var = matchings[pattern_to_replace[indices[i]]];
+      for (size_t i = 0; i < splits.size(); i++) {
+        const auto& split = splits[i];
+        auto bound_var = matchings[split.pattern_to_replace];
         replacements.Set(bound_var, TupleGetItem(chunks, i));
       }
     }
@@ -244,43 +284,43 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
 
   PostOrderVisit(f, [&](const Expr& e) {
     if (!e->IsInstance<CallNode>()) return;
-    if (auto match = ExtractMatchedExpr(pat, e, bindings)) {
-      auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
-      auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
 
-      auto it = groups.find(matmul_lhs.get());
-      BranchInfo* branch = it != groups.end() ? &it->second : nullptr;
-      std::optional<int> bias_dim = std::nullopt;
-      std::optional<std::string> activation = std::nullopt;
+    auto match = ExtractMatchedExpr(pat, e, bindings);
+    if (!match) return;
 
-      if (match.value().count(bias_pat)) {
-        bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
-      }
+    auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
+    auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
 
-      for (size_t i = 0; i < activations.size(); ++i) {
-        if (match.value().count(activation_pat[i]) ||
-            match.value().count(bias_activation_pat[i])) {
-          activation = activations[i];
-        }
+    std::optional<int> bias_dim = std::nullopt;
+    std::optional<std::string> activation = std::nullopt;
+
+    if (match.value().count(bias_pat)) {
+      bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
+    }
+
+    for (size_t i = 0; i < activations.size(); ++i) {
+      if (match.value().count(activation_pat[i]) || match.value().count(bias_activation_pat[i])) {
+        activation = activations[i];
      }
+    }
 
-      if (!branch) {
-        // Create a new subgraph with one matmul
-        groups[matmul_lhs.get()] = {1, bias_dim, activation};
-      } else {
-        // Create a new branch in the existing parallel matmul subtree, and
-        // invalidate bias and activation information when needed.
-        branch->num_branches += 1;
+    if (auto it = groups.find(matmul_lhs.get()); it != groups.end()) {
+      // Create a new branch in the existing parallel matmul subtree, and
+      // invalidate bias and activation information when needed.
+      BranchInfo* branch = &it->second;
+
+      branch->num_branches += 1;
 
-        if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
-          branch->bias_dim = std::nullopt;
-        }
+      if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
+        branch->bias_dim = std::nullopt;
+      }
 
-        if (!activation || (branch->activation && *branch->activation != *activation)) {
-          branch->activation = std::nullopt;
-        }
+      if (!activation || (branch->activation && *branch->activation != *activation)) {
+        branch->activation = std::nullopt;
      }
-      return;
+    } else {
+      // Create a new subgraph with one matmul
+      groups[matmul_lhs.get()] = {1, bias_dim, activation};
    }
   });
 
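To make the ordering rule described in the comments above concrete, here is a small Python model of the same logic; it is an illustration only (the helpers `order_splits` and `section_indices` are invented for this example and do not exist in the codebase):

```python
# Illustrative model (not code from this commit) of the split ordering:
# static widths keep their relative order, at most one dynamic width
# survives, and it is pushed to the end so that every R.split section
# index stays a static integer.
from typing import List, Union

Width = Union[int, str]  # int = static width, str = symbolic (dynamic) width


def order_splits(widths: List[Width]) -> List[Width]:
    def is_dynamic(w: Width) -> bool:
        return not isinstance(w, int)

    # Stable sort: static widths first, preserving their original order.
    ordered = sorted(widths, key=is_dynamic)
    # Keep at most one dynamic width, and only as the final entry.
    for i, w in enumerate(ordered):
        if is_dynamic(w):
            return ordered[: i + 1]
    return ordered


def section_indices(widths: List[Width]) -> List[int]:
    # Cumulative widths of every split except the last; these become the
    # static `indices_or_sections` argument of R.split.
    indices, total = [], 0
    for w in widths[:-1]:
        assert isinstance(w, int), "only the last split may be dynamic"
        total += w
        indices.append(total)
    return indices


kept = order_splits(["M", 640, "N"])
print(kept, section_indices(kept))  # prints: [640, 'M'] [640]
```

With the widths from `test_limit_one_dynamic_shape_in_combined_matmul` below (`M`, 640, `N`), this model keeps `[640, M]` and yields the single static section index `[640]`, matching the `R.split(..., indices_or_sections=[640], axis=2)` in the expected output of that test.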

tests/python/relax/test_transform_combine_parallel_matmul.py

Lines changed: 122 additions & 1 deletion
@@ -525,7 +525,16 @@ def expected(
     tvm.ir.assert_structural_equal(after, expected)
 
 
-def test_dynamic_rhs():
+def test_combine_matmul_of_static_and_dynamic_shapes():
+    """Combine two matmuls, one with dynamic shape
+
+    The `R.split` operator must have a static list of integer indices
+    at which to split the matmul output, because these integer indices
+    are stored as operator attributes. However, the last output can
+    still have a dynamic shape.
+
+    """
+
     @R.function(private=True)
     def before(
         x: R.Tensor((2, 1024, 640), "float32"),
@@ -572,5 +581,117 @@ def expected(
     tvm.ir.assert_structural_equal(after, expected)
 
 
+def test_combine_matmul_of_dynamic_and_static_shapes():
+    """Combine two matmuls, one with dynamic shape
+
+    Like `test_combine_matmul_of_static_and_dynamic_shapes`, but the
+    dynamic-shaped matmul is encountered first. Due to the
+    requirements imposed by `R.split` storing the split indices as
+    static integers, the static-shaped weights must occur first in the
+    concatenated weights.
+    """
+
+    @R.function(private=True)
+    def before(
+        x: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, "M"), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv0 = R.matmul(x, w0)
+            lv1 = R.matmul(x, w1)
+            out = (lv0, lv1)
+            R.output(out)
+        return out
+
+    @R.function(private=True)
+    def expected(
+        x: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, "M"), dtype="float32"),
+        w1: R.Tensor((640, 640), dtype="float32"),
+    ) -> R.Tuple(
+        R.Tensor((2, 1024, "M"), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv: R.Tensor((640, 640 + M), dtype="float32") = R.concat((w1, w0), axis=1)
+            lv1: R.Tensor((2, 1024, 640 + M), dtype="float32") = R.matmul(
+                x, lv, out_dtype="float32"
+            )
+            lv2: R.Tuple(
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, M), dtype="float32"),
+            ) = R.split(lv1, indices_or_sections=[640], axis=2)
+            lv0: R.Tensor((2, 1024, M), dtype="float32") = lv2[1]
+            lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv2[0]
+            out: R.Tuple(
+                R.Tensor((2, 1024, M), dtype="float32"),
+                R.Tensor((2, 1024, 640), dtype="float32"),
+            ) = (lv0, lv1_1)
+            R.output(out)
+        return out
+
+    after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
+
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_limit_one_dynamic_shape_in_combined_matmul():
+    """Combine two matmuls, one with dynamic shape
+
+    Like `test_combine_matmul_of_static_and_dynamic_shapes`, but with
+    two dynamic weights that could, in principle, be merged together.
+    Because `R.split` must have integer indices at which to split,
+    only one of the dynamic outputs can be part of the combined
+    matmul.
+    """
+
+    @R.function(private=True)
+    def before(
+        x: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, "M"), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+        w2: R.Tensor((640, "N"), "float32"),
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv0 = R.matmul(x, w0)
+            lv1 = R.matmul(x, w1)
+            lv2 = R.matmul(x, w2)
+            out = (lv0, lv1, lv2)
+            R.output(out)
+        return out
+
+    @R.function(private=True)
+    def expected(
+        x: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, "M"), dtype="float32"),
+        w1: R.Tensor((640, 640), dtype="float32"),
+        w2: R.Tensor((640, "N"), "float32"),
+    ) -> R.Tuple(
+        R.Tensor((2, 1024, "M"), dtype="float32"),
+        R.Tensor((2, 1024, 640), dtype="float32"),
+        R.Tensor((2, 1024, "N"), dtype="float32"),
+    ):
+        M = T.int64()
+        with R.dataflow():
+            concat_weights = R.concat((w1, w0), axis=1)
+            concat_output = R.matmul(x, concat_weights, out_dtype="float32")
+            split_output: R.Tuple(
+                [R.Tensor([2, 1024, 640], dtype="float32"), R.Tensor([2, 1024, M], dtype="float32")]
+            ) = R.split(concat_output, indices_or_sections=[640], axis=2)
+            lv0 = split_output[1]
+            lv1 = split_output[0]
+            lv2 = R.matmul(x, w2)
+            out = (lv0, lv1, lv2)
+            R.output(out)
+        return out
+
+    after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
+
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
