Skip to content

Commit 19b66bf

Browse files
authored
[Relax][PyTorch] Add support for torchvision.ops.stochastic_depth (#17300)
* add a test for stochastic_depth * add support for torchvision.ops.stochastic_depth
1 parent e19541d commit 19b66bf

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,6 +1672,7 @@ def create_convert_map(self):
16721672
"softmax": self._softmax,
16731673
"log_softmax": self._log_softmax,
16741674
"dropout": lambda node: self.env[node.args[0]],
1675+
"stochastic_depth": lambda node: self.env[node.args[0]],
16751676
"clamp": self._clamp,
16761677
"relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
16771678
"leaky_relu": self._leakyrelu,

tests/python/relax/test_frontend_from_fx.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch.nn.functional as F
2020
from torch import fx
2121
from torch.nn import Module
22+
import torchvision
2223

2324
import tvm
2425
from tvm import relax
@@ -1212,6 +1213,37 @@ def main(
12121213
verify_model(Dropout2(), input_info, {}, expected1)
12131214

12141215

1216+
def test_stochastic_depth():
1217+
input_info = [([1, 3, 10, 10], "float32")]
1218+
1219+
class StochasticDepth1(Module):
1220+
def __init__(self):
1221+
super().__init__()
1222+
self.stochastic_depth = torchvision.ops.StochasticDepth(0.5, mode="row")
1223+
1224+
def forward(self, x):
1225+
return self.stochastic_depth(x)
1226+
1227+
class StochasticDepth2(Module):
1228+
def forward(self, x):
1229+
return torchvision.ops.stochastic_depth(x, 0.5, mode="row", training=False)
1230+
1231+
@tvm.script.ir_module
1232+
class expected1:
1233+
@R.function
1234+
def main(
1235+
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
1236+
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
1237+
# block 0
1238+
with R.dataflow():
1239+
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = input_1
1240+
R.output(gv)
1241+
return gv
1242+
1243+
verify_model(StochasticDepth1(), input_info, {}, expected1)
1244+
verify_model(StochasticDepth2(), input_info, {}, expected1)
1245+
1246+
12151247
def test_layernorm():
12161248
input_info = [([1, 3, 10, 10], "float32")]
12171249

0 commit comments

Comments
 (0)