@@ -71,7 +71,16 @@ struct Patterns {
   WildcardPattern input;
   std::vector<WildcardPattern> rhs;
   std::vector<WildcardPattern> bias;
-  std::vector<CallPattern> matmul, bias_add, activation;
+  std::vector<CallPattern> matmul;
+  std::vector<CallPattern> bias_add;
+  std::vector<CallPattern> activation;
+};
+
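+// Information collected for one branch of a parallel matmul group: its
+// weight, optional bias, output width along the concat/split axis, and
+// the matched pattern whose binding will be replaced.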
+struct SplitInfo {
+  Var rhs;
+  Optional<Var> bias;
+  PrimExpr split_size;
+  DFPattern pattern_to_replace;
 };
 
 Patterns CreatePatterns(const BranchInfo& branch_info) {
@@ -140,40 +149,68 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> Ge
     for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) {
       if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, rhs_shapes)) continue;
 
-      auto inp = matchings[patterns.input];
+      auto lhs = matchings[patterns.input];
+
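+      // The var to replace is the outermost op of each branch: the
+      // activation if present, else the bias add, else the matmul itself.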
+      const auto& patterns_to_replace = [&patterns, &branch_info]() {
+        if (branch_info.activation) return patterns.activation;
+        if (branch_info.bias_dim) return patterns.bias_add;
+        return patterns.matmul;
+      }();
 
-      Array<Var> rhs, bias;
-      for (auto ind : indices) {
-        rhs.push_back(matchings[patterns.rhs[ind]]);
-        if (branch_info.bias_dim) {
-          ICHECK(matchings.count(patterns.bias[ind]));
-          bias.push_back(matchings[patterns.bias[ind]]);
+      std::vector<SplitInfo> splits;
+      for (auto index : indices) {
+        Var rhs = matchings[patterns.rhs[index]];
+        Optional<Var> bias = NullOpt;
+        if (branch_info.bias_dim.has_value()) {
+          bias = matchings[patterns.bias[index]];
         }
+        PrimExpr split_size = GetTensorSInfo(rhs)->GetShape().value()[rhs_dim - 1];
+        DFPattern pattern_to_replace = patterns_to_replace[index];
+        splits.push_back(SplitInfo{rhs, bias, split_size, pattern_to_replace});
+      }
+      // At most one dynamic output shape can be part of the combined
+      // matmul, and it must be the last item in the split. Use
+      // `std::stable_sort` instead of `std::sort` to maintain a
+      // consistent order for all static shapes, and to consistently
+      // select the same dynamic weight to participate.
+      auto is_dynamic_split = [](const SplitInfo& split) -> bool {
+        return !split.split_size->IsInstance<IntImmNode>();
+      };
+      std::stable_sort(splits.begin(), splits.end(),
+                       [&is_dynamic_split](const auto& a, const auto& b) {
+                         return is_dynamic_split(a) < is_dynamic_split(b);
+                       });
+      // Remove anything after the first dynamic shape participating
+      // in the combined matmul.
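+      // For example, split sizes [128, n, 64, m] become [128, 64, n, m]
+      // after the stable partition, and are then truncated to [128, 64, n].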
+      if (auto it = std::find_if(splits.begin(), splits.end(), is_dynamic_split);
+          it != splits.end()) {
+        splits.erase(it + 1, splits.end());
       }
 
-      if (!check(inp, rhs, bias, bindings)) {
+      if (splits.size() == 1) {
         continue;
       }
 
-      auto make_tuple = [](const Array<Var>& var_array) {
-        Array<Expr> exp_array;
-        for (auto v : var_array) exp_array.push_back(v);
-        return Tuple(exp_array);
-      };
+      Array<Var> rhs;
+      Array<Var> bias;
+      for (const auto& split : splits) {
+        rhs.push_back(split.rhs);
+        if (split.bias) {
+          bias.push_back(split.bias.value());
+        }
+      }
 
-      auto concat_rhs = concat(make_tuple(rhs), Integer(rhs_dim - 1));
-      auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
-      auto matmul_combined = matmul(inp, concat_rhs, out_dtype);
+      if (!check(lhs, rhs, bias, bindings)) {
+        continue;
+      }
 
-      const auto& pattern_to_replace = [&patterns, &branch_info]() {
-        if (branch_info.activation) return patterns.activation;
-        if (branch_info.bias_dim) return patterns.bias_add;
-        return patterns.matmul;
-      }();
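+      // Concatenate all weights along their last axis and emit a single
+      // matmul whose output stacks every branch side by side.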
+      auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1));
+      auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
+      auto matmul_combined = matmul(lhs, concat_rhs, out_dtype);
 
       if (branch_info.bias_dim) {
         auto bias_dim = GetTensorSInfo(bias[0])->ndim;
-        auto concat_bias = concat(make_tuple(bias), Integer(bias_dim - 1));
+        auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1));
         matmul_combined = add(matmul_combined, concat_bias);
       }
 
@@ -191,20 +228,23 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> Ge
         }
       }
 
-      int ind = 0;
+      int split_index = 0;
       Array<IntImm> sections;
-      for (int i = 0; i < static_cast<int>(indices.size()) - 1; ++i) {
-        auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1].as<IntImmNode>();
-        ind += width->value;
-        sections.push_back(IntImm(DataType::Int(64), ind));
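+      // Accumulate cumulative split offsets; e.g. widths [128, 64, n]
+      // yield sections [128, 192], with the (possibly dynamic) remainder
+      // as the final chunk.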
+      for (size_t i = 0; i + 1 < splits.size(); i++) {
+        auto width = splits[i].split_size.as<IntImmNode>();
+        ICHECK(width) << "InternalError: "
+                      << "All splits except the last one must have a static shape";
+        split_index += width->value;
+        sections.push_back(IntImm(DataType::Int(64), split_index));
       }
 
-      int lhs_dim = GetTensorSInfo(inp)->ndim;
+      int lhs_dim = GetTensorSInfo(lhs)->ndim;
       int split_axis = std::max<int>(lhs_dim, rhs_dim) - 1;
       auto chunks = split(matmul_combined, sections, split_axis);
 
-      for (size_t i = 0; i < indices.size(); ++i) {
-        auto bound_var = matchings[pattern_to_replace[indices[i]]];
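+      // Rebind each original output var to its slice of the combined result.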
+      for (size_t i = 0; i < splits.size(); i++) {
+        const auto& split = splits[i];
+        auto bound_var = matchings[split.pattern_to_replace];
         replacements.Set(bound_var, TupleGetItem(chunks, i));
       }
     }
@@ -244,43 +284,43 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
 
   PostOrderVisit(f, [&](const Expr& e) {
     if (!e->IsInstance<CallNode>()) return;
-    if (auto match = ExtractMatchedExpr(pat, e, bindings)) {
-      auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
-      auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
 
-      auto it = groups.find(matmul_lhs.get());
-      BranchInfo* branch = it != groups.end() ? &it->second : nullptr;
-      std::optional<int> bias_dim = std::nullopt;
-      std::optional<std::string> activation = std::nullopt;
+    auto match = ExtractMatchedExpr(pat, e, bindings);
+    if (!match) return;
 
-      if (match.value().count(bias_pat)) {
-        bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
-      }
+    auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
+    auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
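+    // Matmuls that share the same LHS operand form one parallel group.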
 
-      for (size_t i = 0; i < activations.size(); ++i) {
-        if (match.value().count(activation_pat[i]) ||
-            match.value().count(bias_activation_pat[i])) {
-          activation = activations[i];
-        }
+    std::optional<int> bias_dim = std::nullopt;
+    std::optional<std::string> activation = std::nullopt;
+
+    if (match.value().count(bias_pat)) {
+      bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
+    }
+
+    for (size_t i = 0; i < activations.size(); ++i) {
+      if (match.value().count(activation_pat[i]) || match.value().count(bias_activation_pat[i])) {
+        activation = activations[i];
       }
+    }
 
-      if (!branch) {
-        // Create a new subgraph with one matmul
-        groups[matmul_lhs.get()] = {1, bias_dim, activation};
-      } else {
-        // Create a new branch in the existing parallel matmul subtree, and
-        // invalidate bias and activation information when needed.
-        branch->num_branches += 1;
+    if (auto it = groups.find(matmul_lhs.get()); it != groups.end()) {
+      // Create a new branch in the existing parallel matmul subtree, and
+      // invalidate bias and activation information when needed.
+      BranchInfo* branch = &it->second;
+
+      branch->num_branches += 1;
 
-        if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
-          branch->bias_dim = std::nullopt;
-        }
+      if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
+        branch->bias_dim = std::nullopt;
+      }
 
-        if (!activation || (branch->activation && *branch->activation != *activation)) {
-          branch->activation = std::nullopt;
-        }
+      if (!activation || (branch->activation && *branch->activation != *activation)) {
+        branch->activation = std::nullopt;
       }
-      return;
+    } else {
+      // Create a new subgraph with one matmul
+      groups[matmul_lhs.get()] = {1, bias_dim, activation};
     }
   });
 