Skip to content

Commit a156181

Browse files
authored
[Relax] Fix EliminateCommonSubexpr removing alloc tensor (#16852)
1 parent 3e802d1 commit a156181

File tree

3 files changed

+57
-5
lines changed

3 files changed

+57
-5
lines changed

src/relax/op/op.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,8 @@ RELAY_REGISTER_OP("relax.builtin.alloc_tensor")
851851
"The storage scope of the storage to allocate. Default is global.")
852852
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllocateTensor)
853853
// memory allocation isn't considered a "visible effect" as far as purity is concerned
854-
.set_attr<Bool>("FPurity", Bool(true));
854+
.set_attr<Bool>("FPurity", Bool(true))
855+
.set_attr<Bool>("TAllocator", Bool(true));
855856

856857
Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_index,
857858
StringImm storage_scope) {
@@ -875,7 +876,8 @@ RELAY_REGISTER_OP("relax.memory.alloc_storage")
875876
.add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.")
876877
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo)
877878
// memory allocation isn't considered a "visible effect" as far as purity is concerned
878-
.set_attr<Bool>("FPurity", Bool(true));
879+
.set_attr<Bool>("FPurity", Bool(true))
880+
.set_attr<Bool>("TAllocator", Bool(true));
879881

880882
Expr MakeAllocStorage(Expr size, PrimValue virtual_device_index, StringImm storage_scope,
881883
DataTypeImm dtype) {
@@ -906,7 +908,8 @@ RELAY_REGISTER_OP("relax.memory.alloc_tensor")
906908
.add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.")
907909
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMemAllocTensor)
908910
// memory allocation isn't considered a "visible effect" as far as purity is concerned
909-
.set_attr<Bool>("FPurity", Bool(true));
911+
.set_attr<Bool>("FPurity", Bool(true))
912+
.set_attr<Bool>("TAllocator", Bool(true));
910913

911914
Expr MakeMemAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype) {
912915
static const Op& op = Op::Get("relax.memory.alloc_tensor");
@@ -960,7 +963,8 @@ RELAY_REGISTER_OP("relax.vm.alloc_storage")
960963
"The storage scope of the storage to allocate. Default is global.")
961964
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo)
962965
// memory allocation isn't considered a "visible effect" as far as purity is concerned
963-
.set_attr<Bool>("FPurity", Bool(true));
966+
.set_attr<Bool>("FPurity", Bool(true))
967+
.set_attr<Bool>("TAllocator", Bool(true));
964968

965969
Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm dtype,
966970
StringImm storage_scope) {
@@ -998,7 +1002,8 @@ RELAY_REGISTER_OP("relax.vm.alloc_tensor")
9981002
.add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.")
9991003
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoVMAllocTensor)
10001004
// memory allocation isn't considered a "visible effect" as far as purity is concerned
1001-
.set_attr<Bool>("FPurity", Bool(true));
1005+
.set_attr<Bool>("FPurity", Bool(true))
1006+
.set_attr<Bool>("TAllocator", Bool(true));
10021007

10031008
Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype) {
10041009
static const Op& op = Op::Get("relax.vm.alloc_tensor");

src/relax/transform/eliminate_common_subexpr.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ class CommonSubexprEliminator : public ExprMutator {
126126
} else if (ContainsImpureCall(bound_value)) {
127127
VLOG(1) << "Since the expression is impure, cannot de-duplicate " << bound_value;
128128

129+
} else if (IsAllocatorCall(bound_value)) {
130+
VLOG(1) << "Skip allocator calls";
129131
} else if (auto it = expr_replacements_.find(lookup_key);
130132
it != expr_replacements_.end() && it->second.size()) {
131133
VLOG(1) << "Value " << bound_value << " has previously been bound as " << it->second[0]
@@ -186,6 +188,19 @@ class CommonSubexprEliminator : public ExprMutator {
186188
return clean_mutator.VisitExpr(expr);
187189
}
188190

191+
bool IsAllocatorCall(const Expr& expr) {
192+
static const auto& allocator_attr_map = Op::GetAttrMap<Bool>("TAllocator");
193+
if (const auto* call = expr.as<CallNode>()) {
194+
if (const auto* op = call->op.as<OpNode>()) {
195+
bool is_allocator = allocator_attr_map.get(GetRef<Op>(op), Bool(false))->value;
196+
if (is_allocator) {
197+
return true;
198+
}
199+
}
200+
}
201+
return false;
202+
}
203+
189204
bool call_only_{false};
190205
std::unordered_map<ReplacementKey, std::vector<Var>> expr_replacements_;
191206
};

tests/python/relax/test_transform_cse.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,5 +627,37 @@ def foo(
627627
verify(Before, Expected)
628628

629629

630+
def test_keep_alloc_tensor():
631+
@I.ir_module
632+
class Before:
633+
@R.function
634+
def foo(x: R.Tensor((2, 3), dtype="float32")):
635+
tmp_buf1 = R.builtin.alloc_tensor(R.shape([64]), R.dtype("int32"), R.prim_value(0))
636+
tmp_buf2 = R.builtin.alloc_tensor(R.shape([64]), R.dtype("int32"), R.prim_value(0))
637+
out = R.add(tmp_buf1, tmp_buf2)
638+
return out
639+
640+
Expected = Before
641+
642+
verify(Before, Expected)
643+
644+
645+
def test_keep_alloc_storage():
646+
@I.ir_module
647+
class Before:
648+
@R.function
649+
def foo(x: R.Tensor((2, 3), dtype="float32")):
650+
tmp_storage1 = R.vm.alloc_storage(R.shape([64]), runtime_device_index=0, dtype="uint8")
651+
tmp_buf1 = R.vm.alloc_tensor(tmp_storage1, offset=0, shape=R.shape([64]), dtype="int32")
652+
tmp_storage2 = R.vm.alloc_storage(R.shape([64]), runtime_device_index=0, dtype="uint8")
653+
tmp_buf2 = R.vm.alloc_tensor(tmp_storage2, offset=0, shape=R.shape([64]), dtype="int32")
654+
out = R.add(tmp_buf1, tmp_buf2)
655+
return out
656+
657+
Expected = Before
658+
659+
verify(Before, Expected)
660+
661+
630662
if __name__ == "__main__":
631663
tvm.testing.main()

0 commit comments

Comments
 (0)