Commit 32a6f01

[Relax][PyTorch] Improve ExportedProgram frontend by supporting unflatten.int, hardtanh_.default, dropout_.default, silu_.default, add_.Tensor and relu_.default (#17813)
* support `relu_.default`
* support `add_.Tensor`
* support `silu_.default`
* support `dropout_.default`
* support `hardswish_.default`
* support `hardtanh_.default`
* support `unflatten.int`
* fix lint error
1 parent 0d2eab2 commit 32a6f01

2 files changed: +75 −2 lines changed
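For context, a minimal usage sketch (not part of this commit) showing how the newly supported ops reach the translator, assuming the usual `torch.export.export` and `tvm.relax.frontend.torch.from_exported_program` entry points. The module below is illustrative only, and depending on export/decomposition settings the in-place ATen ops may already appear as their functional counterparts in the exported graph:

# Hedged sketch: exercise a few of the newly supported ATen ops end to end.
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class InplaceOps(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        y = x + 1.0                    # add.Tensor -> R.add
        torch.ops.aten.relu_(y)        # in-place relu; covered by the relu_/relu mapping
        torch.ops.aten.silu_(y)        # in-place silu; covered by the silu_/silu mapping
        return y.unflatten(1, (3, 5))  # unflatten.int -> R.reshape

example_args = (torch.randn(2, 15, 7, dtype=torch.float32),)
exported_program = export(InplaceOps(), example_args)
mod = from_exported_program(exported_program)
mod.show()  # prints the translated Relax IRModule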

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 22 additions & 2 deletions
@@ -39,8 +39,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
    def _hardtanh(self, node: fx.Node) -> relax.Expr:
        args = self.retrieve_args(node)
        x = args[0]
-        min_val = node.args[1] if len(args) > 1 else node.kwargs("min_val", -1.0)
-        max_val = node.args[2] if len(args) > 2 else node.kwargs("max_val", 1.0)
+        min_val = node.args[1] if len(args) > 1 else node.kwargs.get("min_val", -1.0)
+        max_val = node.args[2] if len(args) > 2 else node.kwargs.get("max_val", 1.0)
        return self.block_builder.emit(relax.op.clip(x, min_val, max_val))

    def _log2(self, node: fx.Node) -> relax.Var:
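The `_hardtanh` fix above replaces a call on `node.kwargs` with a dictionary lookup: `torch.fx` stores keyword arguments as a plain dict, so `node.kwargs(...)` would raise a `TypeError`. A standalone illustration (not frontend code) of the intended default handling:

# fx.Node.kwargs is a dict, so .get() supplies the clip bounds
# when hardtanh was called without explicit min_val/max_val.
kwargs = {}                            # hardtanh invoked with defaults
min_val = kwargs.get("min_val", -1.0)  # -> -1.0
max_val = kwargs.get("max_val", 1.0)   # ->  1.0
assert (min_val, max_val) == (-1.0, 1.0)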
@@ -216,6 +216,19 @@ def _slice(self, node: fx.Node) -> relax.Var:
        stride = [node.args[4] if len(node.args) > 4 else 1]
        return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))

+    def _unflatten(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = node.args[1]
+        sizes = node.args[2]
+
+        x_shape = list(self.shape_of(x))
+        if dim < 0:
+            dim += len(x_shape)
+
+        new_shape = x_shape[:dim] + sizes + x_shape[dim + 1 :]
+        return self.block_builder.emit(relax.op.reshape(x, new_shape))
+
    ########## Creation ##########

    def _one_hot(self, node: fx.Node) -> relax.Var:
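The new `_unflatten` converter lowers `unflatten.int` to a single reshape by splicing the requested sizes into the static input shape. A small standalone sketch of that shape computation, using the shapes from the test added below:

# Standalone illustration of the shape splice performed by _unflatten.
x_shape = [2, 15, 7]            # static input shape
dim, sizes = -2, [3, 5]         # unflatten dim -2 (i.e. dim 1) into (3, 5)
if dim < 0:
    dim += len(x_shape)         # normalize: -2 -> 1
new_shape = x_shape[:dim] + sizes + x_shape[dim + 1:]
assert new_shape == [2, 3, 5, 7]   # emitted as relax.op.reshape(x, new_shape)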
@@ -258,14 +271,17 @@ def create_convert_map(
            "cos.default": self._unary_op(relax.op.cos),
            "cosh.default": self._unary_op(relax.op.cosh),
            "dropout.default": lambda node: self.env[node.args[0]],
+            "dropout_.default": lambda node: self.env[node.args[0]],
            "elu.default": self._elu,
            "erf.default": self._unary_op(relax.op.erf),
            "exp.default": self._unary_op(relax.op.exp),
            "floor.default": self._unary_op(relax.op.floor),
            "gelu.default": self._gelu,
            "hardsigmoid.default": self._hardsigmoid,
            "hardswish.default": self._hardswish,
+            "hardswish_.default": self._hardswish,
            "hardtanh.default": self._hardtanh,
+            "hardtanh_.default": self._hardtanh,
            "isfinite.default": self._unary_op(relax.op.isfinite),
            "isinf.default": self._unary_op(relax.op.isinf),
            "isnan.default": self._unary_op(relax.op.isnan),
@@ -278,12 +294,14 @@ def create_convert_map(
            "neg.default": self._unary_op(relax.op.negative),
            "reciprocal.default": self._reciprocal,
            "relu.default": self._unary_op(relax.op.nn.relu),
+            "relu_.default": self._unary_op(relax.op.nn.relu),
            "round.default": self._round,
            "rsqrt.default": self._unary_op(relax.op.rsqrt),
            "selu.default": self._unary_op(relax.op.nn.selu),
            "sigmoid.default": self._unary_op(relax.op.sigmoid),
            "sign.default": self._unary_op(relax.op.sign),
            "silu.default": self._unary_op(relax.op.nn.silu),
+            "silu_.default": self._unary_op(relax.op.nn.silu),
            "sin.default": self._unary_op(relax.op.sin),
            "sinh.default": self._unary_op(relax.op.sinh),
            "softmax.int": self._softmax,
@@ -296,6 +314,7 @@ def create_convert_map(
            "triu.default": self._tril_triu(relax.op.triu),
            # binary
            "add.Tensor": self._binary_op(relax.op.add, operator.add),
+            "add_.Tensor": self._binary_op(relax.op.add, operator.add),
            "div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
            "eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
            "eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
@@ -393,6 +412,7 @@ def create_convert_map(
            "tile.default": self._tile,
            "topk.default": self._topk,
            "transpose.int": self._transpose,
+            "unflatten.int": self._unflatten,
            "unsqueeze.default": lambda node: self.block_builder.emit(
                relax.op.expand_dims(self.env[node.args[0]], node.args[1])
            ),

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 53 additions & 0 deletions
@@ -254,6 +254,10 @@ class Dropout2(Module):
        def forward(self, input):
            return torch.dropout(input, 0.5, train=True)

+    class Dropout3(Module):
+        def forward(self, input):
+            return torch.ops.aten.dropout_(input, 0.5, train=True)
+
    @tvm.script.ir_module
    class expected_dropout:
        @R.function
@@ -268,6 +272,7 @@ def main(

    verify_model(Dropout1(), example_args, {}, expected_dropout)
    verify_model(Dropout2(), example_args, {}, expected_dropout)
+    verify_model(Dropout3(), example_args, {}, expected_dropout)

    # elu
    class Elu(Module):
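Like the existing `dropout.default` entry, the new `dropout_.default` converter simply returns the Relax expression already bound to the op's input, i.e. dropout is treated as the identity, consistent with inference-time semantics. A quick standalone check of that behaviour in eager PyTorch (illustration only, not test code):

# Dropout is the identity when not training, which is the behaviour the
# importer assumes when it forwards the input expression unchanged.
import torch
x = torch.randn(4)
assert torch.equal(torch.dropout(x, 0.5, train=False), x)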
@@ -383,6 +388,10 @@ class Hardswish2(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.hardswish(input)

+    class Hardswish3(torch.nn.Module):
+        def forward(self, input):
+            return torch.ops.aten.hardswish_(input)
+
    @tvm.script.ir_module
    class expected1:
        @R.function
@@ -402,6 +411,7 @@ def main(

    verify_model(Hardswish(), example_args, {}, expected1)
    verify_model(Hardswish2(), example_args, {}, expected1)
+    verify_model(Hardswish3(), example_args, {}, expected1)

    # hardtanh
    test_hardtanh()
@@ -511,6 +521,10 @@ class ReLU1(Module):
        def forward(self, input):
            return torch.nn.functional.relu(input)

+    class ReLU2(Module):
+        def forward(self, input):
+            return torch.ops.aten.relu_(input)
+
    @tvm.script.ir_module
    class expected_relu:
        @R.function
@@ -526,6 +540,7 @@ def main(

    verify_model(ReLU0(), example_args, {}, expected_relu)
    verify_model(ReLU1(), example_args, {}, expected_relu)
+    verify_model(ReLU2(), example_args, {}, expected_relu)

    # selu
    class Selu1(Module):
@@ -597,6 +612,10 @@ class SiLU2(Module):
        def forward(self, input):
            return torch.nn.functional.silu(input)

+    class SiLU3(Module):
+        def forward(self, input):
+            return torch.ops.aten.silu_(input)
+
    @tvm.script.ir_module
    class expected_silu:
        @R.function
@@ -612,6 +631,7 @@ def main(

    verify_model(SiLU(), example_args, {}, expected_silu)
    verify_model(SiLU2(), example_args, {}, expected_silu)
+    verify_model(SiLU3(), example_args, {}, expected_silu)

    # softmax
    test_softmax()
@@ -636,6 +656,10 @@ class Hardtanh2(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.hardtanh(input)

+    class Hardtanh3(torch.nn.Module):
+        def forward(self, input):
+            return torch.ops.aten.hardtanh_(input)
+
    @tvm.script.ir_module
    class expected1:
        @R.function
@@ -653,6 +677,7 @@ def main(
    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Hardtanh(), example_args, {}, expected1)
    verify_model(Hardtanh2(), example_args, {}, expected1)
+    verify_model(Hardtanh3(), example_args, {}, expected1)


def test_leakyrelu():
@@ -845,6 +870,7 @@ def main(

    operator_binary_1 = [
        (operator.add, R.add),
+        (torch.ops.aten.add_, R.add),
        (operator.sub, R.subtract),
        (operator.mul, R.multiply),
        (operator.truediv, R.divide),
@@ -3603,6 +3629,33 @@ def main(
    verify_model(Select(), example_args, {}, Expected)


+def test_unflatten():
+    class Unflatten(Module):
+        def forward(self, input):
+            return torch.ops.aten.unflatten(input, 1, (3, 5))
+
+    class Unflatten1(Module):
+        def forward(self, input):
+            return torch.ops.aten.unflatten(input, -2, (3, 5))
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 15, 7), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 3, 5, 7), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 5, 7), dtype="float32") = R.reshape(inp_0, [2, 3, 5, 7])
+                gv: R.Tuple(R.Tensor((2, 3, 5, 7), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 15, 7, dtype=torch.float32),)
+
+    verify_model(Unflatten(), example_args, {}, Expected)
+    verify_model(Unflatten1(), example_args, {}, Expected)
+
+
def test_gather():
    class Gather0(Module):
        def forward(self, data, indices):
