|
34 | 34 | namespace tvm { |
35 | 35 | namespace relax { |
36 | 36 |
|
37 | | -std::vector<size_t> GetUsedArgsIndices(const tir::PrimFunc& fn, size_t num_args) { |
| 37 | +std::vector<size_t> GetUsedTensorArgIndices(const tir::PrimFunc& fn, size_t num_args) { |
38 | 38 | std::vector<size_t> indices; |
39 | 39 | for (size_t i = 0; i < num_args; ++i) { |
40 | | - auto buffer_var = fn->buffer_map[fn->params[i]]->data; |
41 | | - if (tir::UsesVar(fn->body, [=](const tir::VarNode* var) { return var == buffer_var.get(); })) { |
42 | | - indices.push_back(i); |
| 40 | + if (auto buffer = fn->buffer_map.Get(fn->params[i])) { |
| 41 | + auto buffer_var = buffer.value()->data; |
| 42 | + if (tir::UsesVar(fn->body, |
| 43 | + [=](const tir::VarNode* var) { return var == buffer_var.get(); })) { |
| 44 | + indices.push_back(i); |
| 45 | + } |
43 | 46 | } |
44 | 47 | } |
45 | 48 | return indices; |
@@ -83,17 +86,17 @@ class DataflowReshapeRewriter : public ExprMutator { |
83 | 86 |
|
84 | 87 | auto prim_fn = Downcast<tir::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0]))); |
85 | 88 | auto arg_tuple = Downcast<Tuple>(call->args[1])->fields; |
86 | | - auto used_arg_indices = GetUsedArgsIndices(prim_fn, arg_tuple.size()); |
| 89 | + auto used_tensor_arg_indices = GetUsedTensorArgIndices(prim_fn, arg_tuple.size()); |
87 | 90 |
|
88 | 91 | // The number of inputs to call_tir(reshape, (...)) might not be one, since FuseOps |
89 | 92 | // can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR |
90 | 93 | // then flattens the tuple input so that the fused TIR reshape function ends up having |
91 | 94 | // multiple input buffers. But only one of them should be accessed and reshaped. |
92 | | - if (used_arg_indices.size() != 1) { |
| 95 | + if (used_tensor_arg_indices.size() != 1) { |
93 | 96 | return GetRef<Call>(call); |
94 | 97 | } |
95 | 98 |
|
96 | | - auto arg = arg_tuple[used_arg_indices[0]]; |
| 99 | + auto arg = arg_tuple[used_tensor_arg_indices[0]]; |
97 | 100 |
|
98 | 101 | if (!IsCallingTIRReshape(call, arg)) { |
99 | 102 | return GetRef<Call>(call); |
|
0 commit comments