Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 65 additions & 48 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -5222,17 +5222,17 @@ def forward(self, data):
class Expected:
@R.function
def main(
inp_0: R.Tensor((5,), dtype="float32"),
data: R.Tensor((5,), dtype="float32"),
) -> R.Tuple(R.Tensor((5,), dtype="float32")):
with R.dataflow():
lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0, dtype="void")
lv: R.Tensor((5,), dtype="float32") = R.zeros(R.shape([5]), dtype="float32")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with other tests for zero-creation operators like test_zeros, it would be better to use R.full here. torch.empty_like is decomposed to aten.zeros, and in other tests torch.zeros is decomposed to aten.full which is then translated to R.full. Using R.full directly would make the expected IR more canonical and consistent across these tests.

Suggested change
lv: R.Tensor((5,), dtype="float32") = R.zeros(R.shape([5]), dtype="float32")
lv: R.Tensor((5,), dtype="float32") = R.full(R.shape([5]), R.const(0.0, "float32"), dtype="float32")

gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(5, dtype=torch.float32),)

verify_model(EmptyLike(), example_args, {}, Expected)
verify_model(EmptyLike(), example_args, {}, Expected, run_ep_decomposition=True)


def test_one_hot():
Expand All @@ -5244,19 +5244,22 @@ def forward(self, indices):
class Expected:
@R.function
def main(
inp_0: R.Tensor((5,), dtype="int64"),
indices: R.Tensor((5,), dtype="int64"),
) -> R.Tuple(R.Tensor((5, 10), dtype="int64")):
with R.dataflow():
lv: R.Tensor((5, 10), dtype="int64") = R.one_hot(
inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1
lv: R.Tensor((10,), dtype="int64") = R.arange(
R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
)
gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv,)
lv1: R.Tensor((5, 1), dtype="int64") = R.expand_dims(indices, axis=[-1])
lv2: R.Tensor((5, 10), dtype="bool") = R.equal(lv1, lv)
lv3: R.Tensor((5, 10), dtype="int64") = R.astype(lv2, dtype="int64")
gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv3,)
R.output(gv)
return gv

example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),)

verify_model(OneHot(), example_args, {}, Expected)
verify_model(OneHot(), example_args, {}, Expected, run_ep_decomposition=True)


def test_ones_like():
Expand All @@ -5271,14 +5274,16 @@ def main(
input: R.Tensor((128, 128), dtype="float32")
) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
with R.dataflow():
lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(input, dtype="void")
lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
input, R.const(1, "int32"), dtype="void"
)
gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.rand(128, 128, dtype=torch.float32),)

verify_model(OnesLike(), example_args, {}, Expected)
verify_model(OnesLike(), example_args, {}, Expected, run_ep_decomposition=True)


def test_zero_inplace():
Expand All @@ -5291,16 +5296,23 @@ class Expected:
@R.function
def main(
input: R.Tensor((128, 128), dtype="float32")
) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
) -> R.Tuple(R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32")):
with R.dataflow():
lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void")
gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
input, R.const(0, "int32"), dtype="void"
)
gv: R.Tuple(
R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32")
) = (
lv,
lv,
)
R.output(gv)
return gv

example_args = (torch.rand(128, 128, dtype=torch.float32),)

verify_model(ZeroInplace(), example_args, {}, Expected)
verify_model(ZeroInplace(), example_args, {}, Expected, run_ep_decomposition=True)


def test_zeros():
Expand All @@ -5315,14 +5327,16 @@ def main(
input: R.Tensor((128, 128), dtype="float32")
) -> R.Tuple(R.Tensor((5, 2), dtype="float32")):
with R.dataflow():
lv: R.Tensor((5, 2), dtype="float32") = R.zeros(R.shape([5, 2]), dtype="float32")
lv: R.Tensor((5, 2), dtype="float32") = R.full(
R.shape([5, 2]), R.const(0.0, "float32"), dtype="float32"
)
gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.rand(128, 128, dtype=torch.float32),)

verify_model(Zeros(), example_args, {}, Expected)
verify_model(Zeros(), example_args, {}, Expected, run_ep_decomposition=True)


def test_zeros_like():
Expand All @@ -5337,13 +5351,15 @@ def main(
input: R.Tensor((128, 128), dtype="float32")
) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
with R.dataflow():
lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void")
lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
input, R.const(0, "int32"), dtype="void"
)
gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.rand(128, 128, dtype=torch.float32),)
verify_model(ZerosLike(), example_args, {}, Expected)
verify_model(ZerosLike(), example_args, {}, Expected, run_ep_decomposition=True)


def test_type_as():
Expand All @@ -5369,7 +5385,7 @@ def main(
torch.rand(128, 128, dtype=torch.float16),
)

verify_model(TypeAs(), example_args, {}, Expected)
verify_model(TypeAs(), example_args, {}, Expected, run_ep_decomposition=True)


def test_select():
Expand All @@ -5391,7 +5407,7 @@ def main(

example_args = (torch.randn(2, 3, dtype=torch.float32),)

verify_model(Select(), example_args, {}, Expected)
verify_model(Select(), example_args, {}, Expected, run_ep_decomposition=True)


def test_unflatten():
Expand All @@ -5417,8 +5433,8 @@ def main(

example_args = (torch.randn(2, 15, 7, dtype=torch.float32),)

verify_model(Unflatten(), example_args, {}, Expected)
verify_model(Unflatten1(), example_args, {}, Expected)
verify_model(Unflatten(), example_args, {}, Expected, run_ep_decomposition=True)
verify_model(Unflatten1(), example_args, {}, Expected, run_ep_decomposition=True)


def test_gather():
Expand Down Expand Up @@ -5495,10 +5511,10 @@ def main(
torch.randint(0, 3, (2, 3), dtype=torch.int64),
)

verify_model(Gather0(), example_args, {}, Expected0)
verify_model(Gather1(), example_args, {}, Expected1)
verify_model(Gather2(), example_args, {}, Expected2)
verify_model(Gather3(), example_args, {}, Expected3)
verify_model(Gather0(), example_args, {}, Expected0, run_ep_decomposition=True)
verify_model(Gather1(), example_args, {}, Expected1, run_ep_decomposition=True)
verify_model(Gather2(), example_args, {}, Expected2, run_ep_decomposition=True)
verify_model(Gather3(), example_args, {}, Expected3, run_ep_decomposition=True)


def test_index_put():
Expand Down Expand Up @@ -5669,11 +5685,11 @@ def main(
return gv

# Run verification for each case
verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
verify_model(IndexPut3D(), example_args_3d, {}, Expected3D)
verify_model(IndexPut4D(), example_args_4d, {}, Expected4D)
verify_model(IndexPut5D(), example_args_5d, {}, Expected5D)
verify_model(IndexPut1D(), example_args_1d, {}, Expected1D, run_ep_decomposition=True)
verify_model(IndexPut2D(), example_args_2d, {}, Expected2D, run_ep_decomposition=True)
verify_model(IndexPut3D(), example_args_3d, {}, Expected3D, run_ep_decomposition=True)
verify_model(IndexPut4D(), example_args_4d, {}, Expected4D, run_ep_decomposition=True)
verify_model(IndexPut5D(), example_args_5d, {}, Expected5D, run_ep_decomposition=True)


def test_flip():
Expand Down Expand Up @@ -5711,8 +5727,8 @@ def main(

example_args = (torch.randn(2, 2, dtype=torch.float32),)

verify_model(Flip0(), example_args, {}, Expected0)
verify_model(Flip1(), example_args, {}, Expected1)
verify_model(Flip0(), example_args, {}, Expected0, run_ep_decomposition=True)
verify_model(Flip1(), example_args, {}, Expected1, run_ep_decomposition=True)


def test_take():
Expand All @@ -5724,12 +5740,12 @@ def forward(self, data, indices):
class Expected:
@R.function
def main(
inp_0: R.Tensor((5,), dtype="float32"),
inp_1: R.Tensor((3,), dtype="int64"),
data: R.Tensor((5,), dtype="float32"),
indices: R.Tensor((3,), dtype="int64"),
) -> R.Tuple(R.Tensor((3,), dtype="float32")):
with R.dataflow():
lv: R.Tensor((3,), dtype="int32") = R.astype(inp_1, dtype="int32")
lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv, axis=None)
lv: R.Tensor((5,), dtype="float32") = R.reshape(data, R.shape([5]))
lv1: R.Tensor((3,), dtype="float32") = R.index_tensor(lv, (indices,))
gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv1,)
R.output(gv)
return gv
Expand All @@ -5739,7 +5755,7 @@ def main(
torch.randint(0, 5, (3,), dtype=torch.int64),
)

verify_model(Take(), example_args, {}, Expected)
verify_model(Take(), example_args, {}, Expected, run_ep_decomposition=True)


def test_std():
Expand All @@ -5751,16 +5767,17 @@ def forward(self, x):
class Expected:
@R.function
def main(
inp_0: R.Tensor((5, 3), dtype="float32"),
x: R.Tensor((5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((), dtype="float32") = R.std(inp_0, axis=None, keepdims=False)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False)
lv1: R.Tensor((), dtype="float32") = R.sqrt(lv)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
R.output(gv)
return gv

example_args = (torch.randn(5, 3, dtype=torch.float32),)
verify_model(Std(), example_args, {}, Expected)
verify_model(Std(), example_args, {}, Expected, run_ep_decomposition=True)


def test_var():
Expand All @@ -5772,16 +5789,16 @@ def forward(self, x):
class Expected:
@R.function
def main(
inp_0: R.Tensor((5, 3), dtype="float32"),
x: R.Tensor((5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((), dtype="float32") = R.variance(inp_0, axis=None, keepdims=False)
lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(5, 3, dtype=torch.float32),)
verify_model(Var(), example_args, {}, Expected)
verify_model(Var(), example_args, {}, Expected, run_ep_decomposition=True)


def test_prod():
Expand All @@ -5793,16 +5810,16 @@ def forward(self, x):
class Expected:
@R.function
def main(
inp_0: R.Tensor((5, 3), dtype="float32"),
x: R.Tensor((5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None, keepdims=False)
lv: R.Tensor((), dtype="float32") = R.prod(x, axis=None, keepdims=False)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(5, 3, dtype=torch.float32),)
verify_model(Prod(), example_args, {}, Expected)
verify_model(Prod(), example_args, {}, Expected, run_ep_decomposition=True)


def test_cumprod():
Expand Down
Loading