107 changes: 95 additions & 12 deletions python/tvm/relax/frontend/nn/modules.py
@@ -218,38 +218,40 @@ def __init__( # pylint: disable=too-many-arguments
self,
in_channels: int,
out_channels: int,
kernel_size: int,
kernel_size: Union[List[int], int],
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
dtype: Optional[str] = None,
data_layout: str = "NCHW",
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.data_layout = data_layout

# Allow dynamic input channels.
if isinstance(self.in_channels, int):
in_channels = int(self.in_channels / self.groups)
else:
in_channels = tir.floordiv(self.in_channels, self.groups)

self.weight = Parameter(
(
self.out_channels,
in_channels,
self.kernel_size,
self.kernel_size,
),
dtype,
)
# Expand kernel size if provided an integer.
if isinstance(kernel_size, int):
self.kernel_size = [kernel_size] * 2
else:
self.kernel_size = kernel_size

kernel_shape = [self.out_channels, in_channels] + list(self.kernel_size)

self.weight = Parameter(kernel_shape, dtype)

if bias:
self.bias = Parameter((self.out_channels,), dtype)
else:
@@ -270,7 +272,88 @@ def forward(self, x: Tensor) -> Tensor: # pylint: disable=invalid-name
The output tensor for the conv2d layer.
"""
return op.conv2d(
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
x,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
self.data_layout,
)


class Conv3D(Module):
"""
Module for conv3d layer.
"""

def __init__( # pylint: disable=too-many-arguments
self,
in_channels: int,
out_channels: int,
kernel_size: Union[List[int], int],
stride: Union[List[int], int] = 1,
padding: Union[List[int], int] = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
dtype: Optional[str] = None,
data_layout: str = "NCDHW",
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.data_layout = data_layout

# Allow dynamic input channels.
if isinstance(self.in_channels, int):
in_channels = int(self.in_channels / self.groups)
else:
in_channels = tir.floordiv(self.in_channels, self.groups)

# Expand kernel size if given an integer.
if isinstance(kernel_size, int):
self.kernel_size = [kernel_size] * 3
else:
self.kernel_size = kernel_size

kernel_shape = [self.out_channels, in_channels] + list(self.kernel_size)

self.weight = Parameter(kernel_shape, dtype)

if bias:
self.bias = Parameter((self.out_channels,), dtype)
else:
self.bias = None

def forward(self, x: Tensor) -> Tensor: # pylint: disable=invalid-name
"""
Forward method for conv3d layer.

Parameters
----------
x : Tensor
The input tensor.

Returns
-------
ret : Tensor
The output tensor for the conv3d layer.
"""
return op.conv3d(
x,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
self.data_layout,
)


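For quick reference, a minimal sketch of how the new Conv3D module is meant to be used, mirroring the unit test added at the end of this PR (shapes and dtype are illustrative):

from tvm.relax.frontend.nn import modules, spec

# Build the layer; kernel_size may be an int (expanded to three dims) or a list.
layer = modules.Conv3D(in_channels=3, out_channels=32, kernel_size=3, bias=True)

# Export to a Relax IRModule; the weight parameter gets shape [32, 3, 3, 3, 3].
tvm_mod, _ = layer.export_tvm(
    spec={"forward": {"x": spec.Tensor([1, 3, 32, 32, 32], "float32")}},
    debug=True,
)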
104 changes: 97 additions & 7 deletions python/tvm/relax/frontend/nn/op.py
@@ -371,6 +371,7 @@ def conv2d(
padding: Optional[Union[int, Tuple, str]] = 0,
dilation: Optional[Union[int, Tuple]] = 1,
groups: Optional[int] = 1,
data_layout: Optional[str] = "NCHW",
name: str = "conv2d",
) -> Tensor:
"""Applies a 2D convolution over an input image composed of sevaral input planes
@@ -399,6 +400,9 @@
groups : Optional[int]
Split input into a number of groups.

data_layout : Optional[str]
Layout of input and output data.

name : str
Name hint.

@@ -408,15 +412,89 @@
The computed result with shape [B, O, oH, oW].
"""
conv_out = _op.nn.conv2d(
data=x._expr,
weight=weight._expr,
strides=stride,
padding=padding,
dilation=dilation,
data_layout=data_layout,
groups=groups,
)
if bias is not None:
if data_layout == "NCHW":
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1]))
elif data_layout == "NHWC":
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, -1]))
else:
raise NotImplementedError(f"Dont know how to handle layout {data_layout}.")

return wrap_nested(conv_out, name)
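The layout-dependent reshape above exists only so that the 1-D bias broadcasts along the channel axis of the convolution output; a small NumPy sketch of the same broadcasting rule (illustrative shapes, not TVM code):

import numpy as np

out_nchw = np.zeros((1, 8, 16, 16), dtype="float32")   # conv output, NCHW
out_nhwc = np.zeros((1, 16, 16, 8), dtype="float32")   # conv output, NHWC
bias = np.arange(8, dtype="float32")                    # one value per output channel

# NCHW: channels are axis 1, so the bias becomes [1, C, 1, 1].
nchw = out_nchw + bias.reshape(1, -1, 1, 1)
# NHWC: channels are the last axis, so the bias becomes [1, 1, 1, C].
nhwc = out_nhwc + bias.reshape(1, 1, 1, -1)

assert nchw.shape == (1, 8, 16, 16) and nhwc.shape == (1, 16, 16, 8)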


def conv3d(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Optional[Union[int, Tuple]] = 1,
padding: Optional[Union[int, Tuple, str]] = 0,
dilation: Optional[Union[int, Tuple]] = 1,
groups: Optional[int] = 1,
data_layout: Optional[str] = "NCDHW",
name: str = "conv3d",
) -> Tensor:
"""Applies a 3D convolution over an input image composed of sevaral input planes

Parameters
----------
x : Tensor
Input tensor of shape [B, N, D, H, W]

weight : Tensor
Filters of shape [O, N/groups, kD, kH, kW]

bias : Optional[Tensor]
Optional bias tensor of shape [O].

stride : Optional[Union[int, Tuple]]
The stride of the convolving kernel. Can be a single number
or tuple of (sD, sH, sW).

padding : Optional[Union[int, Tuple]]
Implicit paddings on both sides of the input.

dilation : Optional[Union[int, Tuple]]
The spacing between kernel elements. Can be a single number or a tuple (dD, dH, dW).

groups : Optional[int]
Split input into a number of groups.

data_layout : Optional[str]
Layout of the input and output data.

name : str
Name hint.

Returns
-------
result : Tensor
The computed result with shape [B, O, oD, oH, oW].
"""
conv_out = _op.nn.conv3d(
data=x._expr,
weight=weight._expr,
strides=stride,
padding=padding,
dilation=dilation,
groups=groups,
data_layout=data_layout,
)
if bias is not None:
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1]))
if data_layout == "NCDHW":
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1, 1]))
elif data_layout == "NDHWC":
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, 1, -1]))
else:
raise NotImplementedError(f"Dont know how to handle layout {data_layout}.")

return wrap_nested(conv_out, name)
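Beyond the Conv3D module, the op-level helper can be called directly from a hand-written module. A minimal sketch under the assumption that the usual nn exports (Module, Parameter, Tensor, op) are used; the module name and shapes are made up for illustration:

from tvm.relax.frontend.nn import Module, Parameter, Tensor, op

class TinyConv3D(Module):
    """Hypothetical module calling op.conv3d directly (default NCDHW layout)."""

    def __init__(self):
        super().__init__()
        # Weight layout is [O, I/groups, kD, kH, kW]; bias has one value per output channel.
        self.weight = Parameter((4, 3, 3, 3, 3), "float32")
        self.bias = Parameter((4,), "float32")

    def forward(self, x: Tensor) -> Tensor:
        return op.conv3d(x, self.weight, self.bias, stride=1, padding=1)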

@@ -1427,6 +1505,7 @@ def interpolate(
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: Optional[bool] = None,
data_layout: Optional[str] = "NCHW",
name: str = "interpolate",
):
"""Resize a tensor using the specified mode.
Expand All @@ -1448,6 +1527,8 @@ def interpolate(
Recompute the scale_factor for use in interpolation.
antialias : Optional[bool]
Apply antialiasing to output.
data_layout : Optional[str]
Layout of the input and output data.
name : str
Name hint for this operation.

@@ -1460,11 +1541,14 @@
assert antialias is None, "antialias is not supported."

if size is None:
shape = x.shape
if isinstance(scale_factor, (list, tuple)):
size = tuple(int(shape[i] * scale_factor[i]) for i in range(2, len(shape)))
else:
size = tuple(int(shape[i] * scale_factor) for i in range(2, len(shape)))
size = []
for i, dim in enumerate(data_layout):
# Only upscale spatial dimensions.
if dim not in ["N", "C"]:
if isinstance(scale_factor, (list, tuple)):
size.append(int(x.shape[i] * scale_factor[len(size)]))
else:
size.append(int(x.shape[i] * scale_factor))

if mode.startswith("nearest"):
mode = "nearest_neighbor"
@@ -1480,7 +1564,11 @@

return wrap_nested(
_op.image.resize2d(
x._expr, size, layout="NCHW", method=mode, coordinate_transformation_mode=coord_trans
x._expr,
size,
layout=data_layout,
method=mode,
coordinate_transformation_mode=coord_trans,
),
name,
)
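The new size computation walks the layout string and rescales only the spatial axes (everything that is not N or C); a standalone Python sketch of that logic with made-up shapes:

def spatial_sizes(shape, scale_factor, data_layout="NCHW"):
    """Scale every axis that is not batch (N) or channel (C)."""
    sizes = []
    for i, dim in enumerate(data_layout):
        if dim in ("N", "C"):
            continue
        if isinstance(scale_factor, (list, tuple)):
            factor = scale_factor[len(sizes)]
        else:
            factor = scale_factor
        sizes.append(int(shape[i] * factor))
    return sizes

assert spatial_sizes([1, 3, 16, 24], 2) == [32, 48]               # NCHW, single factor
assert spatial_sizes([1, 16, 24, 3], (2, 3), "NHWC") == [32, 72]  # NHWC, per-axis factors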
@@ -1991,6 +2079,8 @@ def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Tensor:
result : Tensor
The result tensor.
"""
# Cast condition to boolean.
condition = astype(condition, "bool")
return wrap_nested(_op.where(condition._expr, x1._expr, x2._expr), name)
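The new cast lets callers hand where a non-boolean mask (e.g. an int tensor of zeros and ones); a NumPy sketch of the equivalent behaviour, purely for illustration:

import numpy as np

mask = np.array([1, 0, 1], dtype="int8")             # non-boolean condition
x1 = np.array([10.0, 20.0, 30.0], dtype="float32")
x2 = np.array([-1.0, -2.0, -3.0], dtype="float32")

# Normalizing the mask to bool first mirrors the astype(condition, "bool")
# added above before calling _op.where.
result = np.where(mask.astype(bool), x1, x2)
assert result.tolist() == [10.0, -2.0, 30.0]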


5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
@@ -59,6 +59,11 @@ class Conv2DAttrs(Attrs):
"""Attributes for nn.conv2d"""


@tvm._ffi.register_object("relax.attrs.Conv3DAttrs")
class Conv3DAttrs(Attrs):
"""Attributes for nn.conv3d"""


@tvm._ffi.register_object("relax.attrs.Conv2DTransposeAttrs")
class Conv2DTransposeAttrs(Attrs):
"""Attributes for nn.conv2d_transpose"""
20 changes: 16 additions & 4 deletions src/relax/op/image/resize.cc
@@ -105,14 +105,26 @@ StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
InferLayoutOutput InferLayoutResize2d(const Call& call,
const Map<String, Array<String>>& desired_layouts,
const VarLayoutMap& var_layout_map) {
ICHECK(NoDesiredLayout(call, desired_layouts));
const auto& it = desired_layouts.find("relax.image.resize2d");
const auto* attrs = call->attrs.as<Resize2DAttrs>();
ICHECK(attrs) << "Invalid Call";

LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
LayoutDecision data_layout;
ObjectPtr<Resize2DAttrs> new_attrs = make_object<Resize2DAttrs>(*attrs);
new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name();
return InferLayoutOutput({layout, InitialNLayout(call->args[1])}, {layout}, Attrs(new_attrs));

if (it != desired_layouts.end()) {
// We have a desired layout for resize2d.
Layout desired_data_layout = (*it).second[0];
ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only";
data_layout = TransposeLike(InitialLayout(4), attrs->layout, desired_data_layout);
new_attrs->layout = (*it).second[0];
} else {
// We don't have a desired layout for resize2d; propagate from the input instead.
data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), data_layout->layout).name();
}
return InferLayoutOutput({data_layout, InitialNLayout(call->args[1])}, {data_layout},
Attrs(new_attrs));
}

TVM_REGISTER_OP("relax.image.resize2d")
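With this change, resize2d participates in layout planning the same way the conv ops do: an entry in the desired-layouts map wins, otherwise the layout is propagated from the input. A hedged sketch of requesting NHWC through the ConvertLayout pass (the op-name keys mirror the lookup string in the C++ above; mod is assumed to be an existing Relax IRModule containing these ops):

from tvm import relax

# Desired layouts: conv2d takes data and kernel layouts, resize2d takes only a data layout.
convert = relax.transform.ConvertLayout(
    {
        "relax.nn.conv2d": ["NHWC", "OHWI"],
        "relax.image.resize2d": ["NHWC"],
    }
)
# mod = convert(mod)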
33 changes: 33 additions & 0 deletions tests/python/relax/test_frontend_nn_modules.py
@@ -246,6 +246,39 @@ def forward(
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_conv3d():
@R.function
def forward(
x: R.Tensor((1, 3, 32, 32, 32), dtype="float32"),
_io: R.Object,
weight: R.Tensor((32, 3, 3, 3, 3), dtype="float32"),
bias: R.Tensor((32,), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 32, 30, 30, 30), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((1, 32, 30, 30, 30), dtype="float32") = R.nn.conv3d(x, weight)
lv2: R.Tensor((1, 32, 1, 1, 1), dtype="float32") = R.reshape(
bias, R.shape([1, 32, 1, 1, 1])
)
conv3d: R.Tensor((1, 32, 30, 30, 30), dtype="float32") = R.add(lv1, lv2)
gv1: R.Tuple(
R.Tensor((1, 32, 30, 30, 30), dtype="float32"), R.Tuple(R.Object)
) = conv3d, (_io,)
R.output(gv1)
return gv1

mod = modules.Conv3D(3, 32, 3, bias=True)
tvm_mod, _ = mod.export_tvm(
spec={
"forward": {
"x": spec.Tensor([1, 3, 32, 32, 32], "float32"),
}
},
debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_conv2d_dynamic():
@R.function
def forward(