Skip to content

Commit a07aa4f

Browse files
committed
apply eric's patch
1 parent 0abe71d commit a07aa4f

File tree

5 files changed

+7
-10
lines changed

5 files changed

+7
-10
lines changed

python/tvm/relax/transform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
TopologicalSort,
8080
UpdateParamStructInfo,
8181
UpdateVDevice,
82+
VMBuiltinLower,
8283
VMShapeLower,
8384
dataflowblock_pass,
8485
function_pass,

src/relax/op/memory/view.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,6 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
291291

292292
TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView);
293293

294-
Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
295-
// No-op. View is lowered during the LowerBuiltinView pass.
296-
return call;
297-
}
298-
299294
Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) {
300295
Expr data = call->args[0];
301296
Expr shape = call->args[1];
@@ -357,7 +352,6 @@ TVM_REGISTER_OP("relax.memory.view")
357352
"The view's byte offset, relative to the input tensor's byte offset.")
358353
.set_attr<Bool>("RequiresArgumentShapes", Bool(false))
359354
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoView)
360-
.set_attr<FLegalize>("FLegalize", LegalizeView)
361355
.set_attr<Bool>("FPurity", Bool(true))
362356
.set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinView);
363357

tests/python/relax/test_op_view.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,8 @@ def inferred_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")):
453453

454454

455455
def test_legalize_is_no_op():
456+
"""R.memory.view is not legalized until LowerRuntimeBuiltin"""
457+
456458
@I.ir_module
457459
class Before:
458460
@R.function
@@ -462,7 +464,7 @@ def main(A: R.Tensor([4096], "float32")):
462464

463465
Expected = Before
464466

465-
After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
467+
After = tvm.relax.transform.LegalizeOps()(Before)
466468
tvm.ir.assert_structural_equal(Expected, After)
467469

468470

tests/python/relax/test_transform_static_plan_block_memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32
185185
tvm.ir.assert_structural_equal(mod, Expected)
186186
mod = relax.transform.LowerAllocTensor()(mod)
187187
mod = relax.transform.KillAfterLastUse()(mod)
188-
mod = relax.transform.VMBuiltinLower()(mod)
188+
mod = relax.transform.LowerRuntimeBuiltin()(mod)
189189
tvm.ir.assert_structural_equal(mod, ExpectedLowered)
190190

191191

tests/python/relax/test_vm_builtin_lower.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
5757
gv0 = alloc
5858
return gv0
5959

60-
After = relax.transform.VMBuiltinLower()(Before)
60+
After = relax.transform.LowerRuntimeBuiltin()(Before)
6161
tvm.ir.assert_structural_equal(Expected, After)
6262

6363

@@ -79,7 +79,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
7979
return gv0
8080

8181
with pytest.raises(tvm.TVMError):
82-
relax.transform.VMBuiltinLower()(Before)
82+
relax.transform.LowerRuntimeBuiltin()(Before)
8383

8484

8585
if __name__ == "__main__":

0 commit comments

Comments
 (0)