43 changes: 42 additions & 1 deletion python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -449,12 +449,53 @@ def _isin(self, node: fx.Node) -> relax.Var:

########## Neural Network ##########

    def _adaptive_avg_pool1d(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        output_size = node.args[1] if len(node.args) > 1 else node.kwargs["output_size"]
        # Expand to 3D by adding batch dim if input is 2D
        x_ndim = x.struct_info.ndim
        if x_ndim == 2:
            x = relax.op.expand_dims(x, axis=0)

        result = self.block_builder.emit(
            relax.op.nn.adaptive_avg_pool1d(x, output_size, layout="NCW")
        )
        # Remove added batch dim from result
        if x_ndim == 2:
            result = relax.op.squeeze(result, axis=[0])
        return result

    def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        output_size = node.args[1]
        # Expand to 4D by adding batch dim if input is 3D
        x_ndim = x.struct_info.ndim
        if x_ndim == 3:
            x = relax.op.expand_dims(x, axis=0)

        result = self.block_builder.emit(
            relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
        )
        # Remove added batch dim from result
        if x_ndim == 3:
            result = relax.op.squeeze(result, axis=[0])
        return result

    def _adaptive_avg_pool3d(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        output_size = node.args[1]
        # Expand to 5D by adding batch dim if input is 4D
        x_ndim = x.struct_info.ndim
        if x_ndim == 4:
            x = relax.op.expand_dims(x, axis=0)

        result = self.block_builder.emit(
            relax.op.nn.adaptive_avg_pool3d(x, output_size, layout="NCDHW")
        )
        # Remove added batch dim from result
        if x_ndim == 4:
            result = relax.op.squeeze(result, axis=[0])
        return result

    def _addmm(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
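The expand/squeeze pattern in these converters exists because PyTorch's adaptive pooling ops accept unbatched inputs ((C, L), (C, H, W), (C, D, H, W)) while the Relax adaptive_avg_pool operators expect the batched NCW/NCHW/NCDHW layouts. A minimal PyTorch-only illustration of the behavior being matched (a sketch, assuming a recent PyTorch version with unbatched pooling support):

import torch
import torch.nn.functional as F

# Unbatched (C, L) input: PyTorch pools over L and keeps rank 2.
print(F.adaptive_avg_pool1d(torch.randn(3, 10), 5).shape)     # torch.Size([3, 5])
# Batched (N, C, L) input: rank 3 in, rank 3 out.
print(F.adaptive_avg_pool1d(torch.randn(2, 3, 10), 5).shape)  # torch.Size([2, 3, 5])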
@@ -387,7 +387,9 @@ def create_convert_map(
"_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional,
"_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
"batch_norm.default": self._batch_norm_legit_no_training,
"adaptive_avg_pool1d.default": self._adaptive_avg_pool1d,
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
"adaptive_avg_pool3d.default": self._adaptive_avg_pool3d,
"addmm.default": self._addmm,
"avg_pool2d.default": self._avg_pool2d,
"baddbmm.default": self._baddbmm,
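These keys are ATen overload names, so the new converters fire whenever torch.export lowers nn.AdaptiveAvgPool1d/3d or their functional forms. A sketch of exercising this path end to end (hedged: from_exported_program is the public entry point in recent TVM releases; the model and shapes are illustrative):

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class Pool1d(torch.nn.Module):
    def forward(self, x):
        # Lowers to the "adaptive_avg_pool1d.default" ATen op handled above.
        return torch.nn.functional.adaptive_avg_pool1d(x, 5)

exported = export(Pool1d(), (torch.randn(1, 3, 10),))
mod = from_exported_program(exported)
mod.show()  # expect R.nn.adaptive_avg_pool1d(..., layout="NCW") in the output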
46 changes: 45 additions & 1 deletion python/tvm/relax/frontend/torch/fx_translator.py
@@ -182,13 +182,53 @@ def call_binary_op(op, lhs, rhs):

########## Neural Network ##########

    def _adaptive_avg_pool1d_module(self, node: fx.Node) -> relax.Var:
        module = self.named_modules[node.target]
        x = self.env[node.args[0]]
        output_size = module.output_size
        # Expand to 3D by adding batch dim if input is 2D
        x_ndim = x.struct_info.ndim
        if x_ndim == 2:
            x = relax.op.expand_dims(x, axis=0)
        result = self.block_builder.emit(
            relax.op.nn.adaptive_avg_pool1d(x, output_size, layout="NCW")  # (N, C, L)
        )
        # Remove added batch dim from result
        if x_ndim == 2:
            result = relax.op.squeeze(result, axis=[0])
        return result

    def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
        module = self.named_modules[node.target]
        x = self.env[node.args[0]]
        output_size = module.output_size
        # Expand to 4D by adding batch dim if input is 3D
        x_ndim = x.struct_info.ndim
        if x_ndim == 3:
            x = relax.op.expand_dims(x, axis=0)
        result = self.block_builder.emit(
            relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
        )
        # Remove added batch dim from result
        if x_ndim == 3:
            result = relax.op.squeeze(result, axis=[0])
        return result

    def _adaptive_avg_pool3d_module(self, node: fx.Node) -> relax.Var:
        module = self.named_modules[node.target]
        x = self.env[node.args[0]]
        output_size = module.output_size
        # Expand to 5D by adding batch dim if input is 4D
        x_ndim = x.struct_info.ndim
        if x_ndim == 4:
            x = relax.op.expand_dims(x, axis=0)
        result = self.block_builder.emit(
            relax.op.nn.adaptive_avg_pool3d(x, output_size, layout="NCDHW")  # (N, C, D, H, W)
        )
        # Remove added batch dim from result
        if x_ndim == 4:
            result = relax.op.squeeze(result, axis=[0])
        return result

    def _avg_pool2d_module(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]

@@ -649,7 +689,9 @@ def create_convert_map(
        nn.Softplus: self._softplus_module,
        nn.Tanh: self._unary_op(relax.op.tanh),
        # neural network
        nn.AdaptiveAvgPool1d: self._adaptive_avg_pool1d_module,
        nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module,
        nn.AdaptiveAvgPool3d: self._adaptive_avg_pool3d_module,
        nn.AvgPool2d: self._avg_pool2d_module,
        nn.BatchNorm2d: self._batch_norm_2d_module,
        nn.Conv1d: self._conv1d_module,
@@ -755,7 +797,9 @@ def create_convert_map(
"truediv": self._binary_op(relax.op.divide, operator.truediv),
"xor": self._binary_op(relax.op.bitwise_xor, operator.xor),
# neural network
"adaptive_avg_pool1d": self._adaptive_avg_pool1d,
"adaptive_avg_pool2d": self._adaptive_avg_pool2d,
"adaptive_avg_pool3d": self._adaptive_avg_pool3d,
"addmm": self._addmm,
"avg_pool2d": self._avg_pool2d,
"baddbmm": self._baddbmm,
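In the FX path, the `_module` converters read `output_size` from the `nn.Module` instance registered in `named_modules`, while the functional keys take it from the call arguments; both land on the same Relax op. A sketch of driving this translator directly (hedged: `from_fx` and its `input_info` format are assumed from the test harness below):

import torch
from tvm.relax.frontend.torch import from_fx

class Pool3d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.AdaptiveAvgPool3d((8, 8, 8))  # module converter path

    def forward(self, x):
        return self.pool(x)

gm = torch.fx.symbolic_trace(Pool3d())
mod = from_fx(gm, [([1, 3, 16, 16, 16], "float32")])
mod.show()  # expect R.nn.adaptive_avg_pool3d(..., layout="NCDHW")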
64 changes: 64 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -1145,6 +1145,38 @@ def main(
    verify_model(model, example_args, binding, expected1)


def test_adaptive_avgpool1d():
    class AdaptiveAvgPool1d0(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool1d(output_size=5)

        def forward(self, input):
            return self.pool(input)

    class AdaptiveAvgPool1d1(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.adaptive_avg_pool1d(input, output_size=5)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.adaptive_avg_pool1d(
                    input_1, output_size=[5], layout="NCW"
                )
                gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
    verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1)
    verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1)


def test_adaptive_avgpool2d():
    class AdaptiveAvgPool2d0(Module):
        def __init__(self):

@@ -1178,6 +1210,38 @@ def main(
    verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)


def test_adaptive_avgpool3d():
    class AdaptiveAvgPool3d0(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool3d([4, 4, 4])

        def forward(self, input):
            return self.pool(input)

    class AdaptiveAvgPool3d1(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.adaptive_avg_pool3d(input, [4, 4, 4])

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.adaptive_avg_pool3d(
                    input_1, output_size=[4, 4, 4], layout="NCDHW", out_layout="NCDHW"
                )
                gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
    verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1)
    verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1)


def test_addmm():
    class Addmm1(Module):
        def __init__(self):
66 changes: 66 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
@@ -1181,6 +1181,39 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
    verify_model(AvgPool2d4(), input_info, {}, expected3)


def test_adaptive_avgpool1d():
    input_info = [([1, 3, 16], "float32")]

    class AdaptiveAvgPool1d0(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool1d(8)

        def forward(self, input):
            return self.pool(input)

    class AdaptiveAvgPool1d1(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.adaptive_avg_pool1d(input, 8)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 16), dtype="float32")
        ) -> R.Tensor((1, 3, 8), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 3, 8), dtype="float32") = R.nn.adaptive_avg_pool1d(
                    input_1, output_size=[8], layout="NCW", out_layout="NCW"
                )
                gv: R.Tensor((1, 3, 8), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(AdaptiveAvgPool1d0(), input_info, {}, expected1)
    verify_model(AdaptiveAvgPool1d1(), input_info, {}, expected1)


def test_adaptive_avgpool2d():
    input_info = [([1, 3, 10, 10], "float32")]

@@ -1215,6 +1248,39 @@ def main(
    verify_model(AdaptiveAvgPool2d1(), input_info, {}, expected1)


def test_adaptive_avgpool3d():
    input_info = [([1, 3, 16, 16, 16], "float32")]

    class AdaptiveAvgPool3d0(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool3d((8, 8, 8))

        def forward(self, input):
            return self.pool(input)

    class AdaptiveAvgPool3d1(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.adaptive_avg_pool3d(input, (8, 8, 8))

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 16, 16, 16), dtype="float32")
        ) -> R.Tensor((1, 3, 8, 8, 8), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = R.nn.adaptive_avg_pool3d(
                    input_1, output_size=[8, 8, 8], layout="NCDHW", out_layout="NCDHW"
                )
                gv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(AdaptiveAvgPool3d0(), input_info, {}, expected1)
    verify_model(AdaptiveAvgPool3d1(), input_info, {}, expected1)


def test_flatten():
    input_info = [([1, 3, 10, 10], "float32")]

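The tests in this diff cover batched inputs only, so the expand_dims/squeeze branch in the converters is not exercised here. A hypothetical extra case (not part of this PR) that would hit it, sketched against the same from_fx entry point:

import torch
from tvm.relax.frontend.torch import from_fx

class UnbatchedPool1d(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.adaptive_avg_pool1d(x, 8)

# Rank-2 (C, L) input: the converter should wrap R.nn.adaptive_avg_pool1d
# in expand_dims/squeeze so the result stays rank 2.
gm = torch.fx.symbolic_trace(UnbatchedPool1d())
mod = from_fx(gm, [([3, 16], "float32")])
mod.show()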