
Commit bc896a5

Authored by Deivanayaki-S and deivanayakisankaralingam
[Relax][PyTorch] Add Adaptive AvgPool 1D and 3D Op Support for Exported Program and FX graph (#17922)
* add mappings for adaptive avgpool 1d and 3d ops and their test scripts
* fix lint issues
* add edge case handling in fx translator and fix test scripts
* remove line space to fix the lint issue
* fix test script issues

Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
1 parent 1f8103e commit bc896a5

File tree: 6 files changed, +561 / -2 lines changed

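For context, a minimal sketch (assuming a TVM build that includes this commit; the module and variable names are illustrative, not from the patch) of how the newly mapped ops surface through the exported-program frontend:

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class Pool1d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.AdaptiveAvgPool1d(5)

    def forward(self, x):
        return self.pool(x)

# Export and translate; adaptive_avg_pool1d now lowers to
# R.nn.adaptive_avg_pool1d instead of failing as an unsupported op.
ep = export(Pool1d(), (torch.randn(1, 3, 10),))
mod = from_exported_program(ep)
mod.show()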

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 42 additions & 1 deletion
@@ -510,12 +510,53 @@ def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
 
     ########## Neural Network ##########
 
+    def _adaptive_avg_pool1d(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        output_size = node.args[1] if len(node.args) > 1 else node.kwargs["output_size"]
+        # Expand to 3D by adding batch dim if input is 2D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 2:
+            x = relax.op.expand_dims(x, axis=0)
+
+        result = self.block_builder.emit(
+            relax.op.nn.adaptive_avg_pool1d(x, output_size, layout="NCW")
+        )
+        # Remove added batch dim from result
+        if x_ndim == 2:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
+
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         output_size = node.args[1]
-        return self.block_builder.emit(
+        # Expand to 4D by adding batch dim if input is 3D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 3:
+            x = relax.op.expand_dims(x, axis=0)
+
+        result = self.block_builder.emit(
             relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
         )
+        # Remove added batch dim from result
+        if x_ndim == 3:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
+
+    def _adaptive_avg_pool3d(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        output_size = node.args[1]
+        # Expand to 5D by adding batch dim if input is 4D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 4:
+            x = relax.op.expand_dims(x, axis=0)
+
+        result = self.block_builder.emit(
+            relax.op.nn.adaptive_avg_pool3d(x, output_size, layout="NCDHW")
+        )
+        # Remove added batch dim from result
+        if x_ndim == 4:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
 
     def _addmm(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
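The expand_dims/squeeze pair above handles PyTorch's unbatched-input edge case: the ATen op accepts a 2D (C, L) tensor, while relax.op.nn.adaptive_avg_pool1d assumes a batched "NCW" layout. A small illustration of the PyTorch behavior being matched:

import torch
import torch.nn.functional as F

batched = F.adaptive_avg_pool1d(torch.randn(2, 3, 10), 5)
unbatched = F.adaptive_avg_pool1d(torch.randn(3, 10), 5)
print(batched.shape)    # torch.Size([2, 3, 5])
print(unbatched.shape)  # torch.Size([3, 5]) -- no batch dim, hence the squeeze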

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 2 additions & 0 deletions
@@ -397,7 +397,9 @@ def create_convert_map(
             "_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional,
             "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
             "batch_norm.default": self._batch_norm_legit_no_training,
+            "adaptive_avg_pool1d.default": self._adaptive_avg_pool1d,
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
+            "adaptive_avg_pool3d.default": self._adaptive_avg_pool3d,
             "addmm.default": self._addmm,
             "avg_pool2d.default": self._avg_pool2d,
             "baddbmm.default": self._baddbmm,

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 45 additions & 1 deletion
@@ -182,13 +182,53 @@ def call_binary_op(op, lhs, rhs):
 
     ########## Neural Network ##########
 
+    def _adaptive_avg_pool1d_module(self, node: fx.Node) -> relax.Var:
+        module = self.named_modules[node.target]
+        x = self.env[node.args[0]]
+        output_size = module.output_size
+        # Expand to 3D by adding batch dim if input is 2D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 2:
+            x = relax.op.expand_dims(x, axis=0)
+        result = self.block_builder.emit(
+            relax.op.nn.adaptive_avg_pool1d(x, output_size, layout="NCW")  # (N, C, L)
+        )
+        # Remove added batch dim from result
+        if x_ndim == 2:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
+
     def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
         module = self.named_modules[node.target]
         x = self.env[node.args[0]]
         output_size = module.output_size
-        return self.block_builder.emit(
+        # Expand to 4D by adding batch dim if input is 3D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 3:
+            x = relax.op.expand_dims(x, axis=0)
+        result = self.block_builder.emit(
             relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
         )
+        # Remove added batch dim from result
+        if x_ndim == 3:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
+
+    def _adaptive_avg_pool3d_module(self, node: fx.Node) -> relax.Var:
+        module = self.named_modules[node.target]
+        x = self.env[node.args[0]]
+        output_size = module.output_size
+        # Expand to 5D by adding batch dim if input is 4D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 4:
+            x = relax.op.expand_dims(x, axis=0)
+        result = self.block_builder.emit(
+            relax.op.nn.adaptive_avg_pool3d(x, output_size, layout="NCDHW")  # (N, C, D, H, W)
+        )
+        # Remove added batch dim from result
+        if x_ndim == 4:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
 
     def _avg_pool2d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]

@@ -668,7 +708,9 @@ def create_convert_map(
             nn.Softplus: self._softplus_module,
             nn.Tanh: self._unary_op(relax.op.tanh),
             # neural network
+            nn.AdaptiveAvgPool1d: self._adaptive_avg_pool1d_module,
             nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module,
+            nn.AdaptiveAvgPool3d: self._adaptive_avg_pool3d_module,
             nn.AvgPool2d: self._avg_pool2d_module,
             nn.BatchNorm2d: self._batch_norm_2d_module,
             nn.Conv1d: self._conv1d_module,

@@ -778,7 +820,9 @@ def create_convert_map(
             "truediv": self._binary_op(relax.op.divide, operator.truediv),
             "xor": self._binary_op(relax.op.bitwise_xor, operator.xor),
             # neural network
+            "adaptive_avg_pool1d": self._adaptive_avg_pool1d,
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d,
+            "adaptive_avg_pool3d": self._adaptive_avg_pool3d,
             "addmm": self._addmm,
             "avg_pool2d": self._avg_pool2d,
             "baddbmm": self._baddbmm,

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 64 additions & 0 deletions
@@ -1211,6 +1211,38 @@ def main(
     verify_model(model, example_args, binding, expected1)
 
 
+def test_adaptive_avgpool1d():
+    class AdaptiveAvgPool1d0(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AdaptiveAvgPool1d(output_size=5)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class AdaptiveAvgPool1d1(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.adaptive_avg_pool1d(input, output_size=5)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.adaptive_avg_pool1d(
+                    input_1, output_size=[5], layout="NCW"
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
+    verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1)
+    verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1)
+
+
 def test_adaptive_avgpool2d():
     class AdaptiveAvgPool2d0(Module):
         def __init__(self):

@@ -1244,6 +1276,38 @@ def main(
     verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)
 
 
+def test_adaptive_avgpool3d():
+    class AdaptiveAvgPool3d0(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AdaptiveAvgPool3d([4, 4, 4])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class AdaptiveAvgPool3d1(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.adaptive_avg_pool3d(input, [4, 4, 4])
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.adaptive_avg_pool3d(
+                    input_1, output_size=[4, 4, 4], layout="NCDHW", out_layout="NCDHW"
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
+    verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1)
+    verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1)
+
+
 def test_addmm():
     class Addmm1(Module):
         def __init__(self):
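For intuition, a sketch of the kind of end-to-end numerical check such a test implies (not the actual verify_model implementation; target, tolerances, and the legalize step are assumptions):

import numpy as np
import torch
import tvm
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

model = torch.nn.AdaptiveAvgPool3d([4, 4, 4])
args = (torch.randn(1, 3, 8, 8, 8),)

mod = from_exported_program(export(model, args))
mod = tvm.relax.transform.LegalizeOps()(mod)  # lower R.nn.* ops to TIR
ex = tvm.relax.build(mod, target="llvm")
vm = tvm.relax.VirtualMachine(ex, tvm.cpu())
out = vm["main"](tvm.nd.array(args[0].numpy()))  # main returns a tuple
np.testing.assert_allclose(out[0].numpy(), model(*args).numpy(), rtol=1e-5, atol=1e-5)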

tests/python/relax/test_frontend_from_fx.py

Lines changed: 66 additions & 0 deletions
@@ -1381,6 +1381,39 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
     verify_model(AvgPool2d4(), input_info, {}, expected3)
 
 
+def test_adaptive_avgpool1d():
+    input_info = [([1, 3, 16], "float32")]
+
+    class AdaptiveAvgPool1d0(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AdaptiveAvgPool1d(8)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class AdaptiveAvgPool1d1(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.adaptive_avg_pool1d(input, 8)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 16), dtype="float32")
+        ) -> R.Tensor((1, 3, 8), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 8), dtype="float32") = R.nn.adaptive_avg_pool1d(
+                    input_1, output_size=[8], layout="NCW", out_layout="NCW"
+                )
+                gv: R.Tensor((1, 3, 8), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(AdaptiveAvgPool1d0(), input_info, {}, expected1)
+    verify_model(AdaptiveAvgPool1d1(), input_info, {}, expected1)
+
+
 def test_adaptive_avgpool2d():
     input_info = [([1, 3, 10, 10], "float32")]

@@ -1415,6 +1448,39 @@ def main(
     verify_model(AdaptiveAvgPool2d1(), input_info, {}, expected1)
 
 
+def test_adaptive_avgpool3d():
+    input_info = [([1, 3, 16, 16, 16], "float32")]
+
+    class AdaptiveAvgPool3d0(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AdaptiveAvgPool3d((8, 8, 8))
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class AdaptiveAvgPool3d1(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.adaptive_avg_pool3d(input, (8, 8, 8))
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 16, 16, 16), dtype="float32")
+        ) -> R.Tensor((1, 3, 8, 8, 8), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = R.nn.adaptive_avg_pool3d(
+                    input_1, output_size=[8, 8, 8], layout="NCDHW", out_layout="NCDHW"
+                )
+                gv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(AdaptiveAvgPool3d0(), input_info, {}, expected1)
+    verify_model(AdaptiveAvgPool3d1(), input_info, {}, expected1)
+
+
 def test_flatten():
     input_info = [([1, 3, 10, 10], "float32")]
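Unlike the exported-program tests above, the FX-path tests go through symbolic tracing. A minimal sketch of that entry point (assuming this commit; class name M is illustrative):

import torch
import torch.fx as fx
from tvm.relax.frontend.torch import from_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.AdaptiveAvgPool1d(8)

    def forward(self, x):
        return self.pool(x)

mod = from_fx(fx.symbolic_trace(M()), [((1, 3, 16), "float32")])
mod.show()  # main contains R.nn.adaptive_avg_pool1d(..., layout="NCW", out_layout="NCW")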
