diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 4c9480b58748..a83d2692bddb 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1343,6 +1343,23 @@ def _index_select(self, node: fx.Node) -> relax.Var: index = self.env[node.args[2]] return self.block_builder.emit(relax.op.take(x, index, dim)) + def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + mask = self.env[node.args[1]] + value = node.args[2] + rx_value = relax.const(value) + values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + output = self.block_builder.emit(relax.op.where(mask, values, x)) + self.env[node.args[0]] = output + return output + + def _masked_fill(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + mask = self.env[node.args[1]] + rx_value = relax.const(node.args[2]) + values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + return self.block_builder.emit(relax.op.where(mask, values, x)) + def _new_ones(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) self_var = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index c82a5e2b1100..57570d9810e4 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -447,6 +447,7 @@ def create_convert_map( "full_like.default": self._full_like, "index_select.default": self._index_select, "lift_fresh_copy.default": self._to_copy, + "masked_fill.Scalar": self._masked_fill, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, "ones.default": self._ones, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 297529e8bf29..50a3dc4a208d 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -468,23 +468,6 @@ def _inplace_fill(self, node: fx.Node) -> relax.Var: self.env[node.args[0]] = filled return filled - def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - mask = self.env[node.args[1]] - value = node.args[2] - rx_value = relax.const(value) - values = self.block_builder.emit(relax.op.full_like(x, rx_value)) - output = self.block_builder.emit(relax.op.where(mask, values, x)) - self.env[node.args[0]] = output - return output - - def _masked_fill(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - mask = self.env[node.args[1]] - rx_value = relax.const(node.args[2]) - values = self.block_builder.emit(relax.op.full_like(x, rx_value)) - return self.block_builder.emit(relax.op.where(mask, values, x)) - def _masked_scatter(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 26d3d3f7bde2..43a7f7af5372 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3224,6 +3224,30 @@ def main( verify_model(Fill(), example_args, {}, Expected) +def test_masked_fill(): + class Masked_Fill(Module): + def forward(self, input: torch.Tensor, mask: torch.Tensor): + return torch.masked_fill(input, mask, 0) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool") + ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.full_like( + input, R.const(0, "int32"), dtype="void" + ) + lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input) + gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(128, 128, dtype=torch.float32), torch.rand(128, 128) < 0.5) + verify_model(Masked_Fill(), example_args, {}, Expected) + + def test_new_ones(): class NewOnes(Module): def forward(self, x):