Skip to content

Commit b463cba

Browse files
committed
Fix breakage in unit tests
1 parent 9a8d239 commit b463cba

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

src/relax/transform/fuse_tir.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -659,21 +659,25 @@ class FusedTIRConstructor : public ExprVisitor {
659659
*/
660660
void MapInputBuffer(const tir::PrimFunc& func, const relax::Expr& args) {
661661
Array<Expr> arg_list;
662-
Array<tir::Buffer> buffer_list;
663662
if (const auto* arg_tuple = args.as<TupleNode>()) {
664663
arg_list = arg_tuple->fields;
665664
} else {
666665
arg_list = {args};
667666
}
668667

668+
Array<Expr> relax_tensors;
669+
Array<tir::Buffer> tir_buffers;
670+
669671
ICHECK_GE(func->params.size(), arg_list.size());
670672
for (size_t i = 0; i < arg_list.size(); ++i) {
671673
const tir::Var& param = func->params[i];
672-
const tir::Buffer& buffer = func->buffer_map.at(param);
673-
buffer_list.push_back(buffer);
674+
if (auto buffer = func->buffer_map.Get(param)) {
675+
relax_tensors.push_back(arg_list[i]);
676+
tir_buffers.push_back(buffer.value());
677+
}
674678
}
675679

676-
MapArgsToBuffer(arg_list, buffer_list);
680+
MapArgsToBuffer(relax_tensors, tir_buffers);
677681
}
678682

679683
static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,

src/relax/transform/rewrite_dataflow_reshape.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,15 @@
3434
namespace tvm {
3535
namespace relax {
3636

37-
std::vector<size_t> GetUsedArgsIndices(const tir::PrimFunc& fn, size_t num_args) {
37+
std::vector<size_t> GetUsedTensorArgIndices(const tir::PrimFunc& fn, size_t num_args) {
3838
std::vector<size_t> indices;
3939
for (size_t i = 0; i < num_args; ++i) {
40-
auto buffer_var = fn->buffer_map[fn->params[i]]->data;
41-
if (tir::UsesVar(fn->body, [=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
42-
indices.push_back(i);
40+
if (auto buffer = fn->buffer_map.Get(fn->params[i])) {
41+
auto buffer_var = buffer.value()->data;
42+
if (tir::UsesVar(fn->body,
43+
[=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
44+
indices.push_back(i);
45+
}
4346
}
4447
}
4548
return indices;
@@ -83,17 +86,17 @@ class DataflowReshapeRewriter : public ExprMutator {
8386

8487
auto prim_fn = Downcast<tir::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0])));
8588
auto arg_tuple = Downcast<Tuple>(call->args[1])->fields;
86-
auto used_arg_indices = GetUsedArgsIndices(prim_fn, arg_tuple.size());
89+
auto used_tensor_arg_indices = GetUsedTensorArgIndices(prim_fn, arg_tuple.size());
8790

8891
// The number of inputs to call_tir(reshape, (...)) might not be one, since FuseOps
8992
// can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR
9093
// then flattens the tuple input so that the fused TIR reshape function ends up having
9194
// multiple input buffers. But only one of them should be accessed and reshaped.
92-
if (used_arg_indices.size() != 1) {
95+
if (used_tensor_arg_indices.size() != 1) {
9396
return GetRef<Call>(call);
9497
}
9598

96-
auto arg = arg_tuple[used_arg_indices[0]];
99+
auto arg = arg_tuple[used_tensor_arg_indices[0]];
97100

98101
if (!IsCallingTIRReshape(call, arg)) {
99102
return GetRef<Call>(call);

0 commit comments

Comments
 (0)