
Commit d284cf4

Author: Josh Fromm
[Relax][Frontend][NN] Add support for Conv3D (#16654)
1 parent ad3722f · commit d284cf4

6 files changed: +306, -25 lines changed

python/tvm/relax/frontend/nn/modules.py

Lines changed: 95 additions & 12 deletions
@@ -218,38 +218,40 @@ def __init__( # pylint: disable=too-many-arguments
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: int,
+        kernel_size: Union[List[int], int],
         stride: int = 1,
         padding: int = 0,
         dilation: int = 1,
         groups: int = 1,
         bias: bool = True,
         dtype: Optional[str] = None,
+        data_layout: str = "NCHW",
     ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
         self.dilation = dilation
         self.groups = groups
+        self.data_layout = data_layout

         # Allow dynamic input channels.
         if isinstance(self.in_channels, int):
             in_channels = int(self.in_channels / self.groups)
         else:
             in_channels = tir.floordiv(self.in_channels, self.groups)

-        self.weight = Parameter(
-            (
-                self.out_channels,
-                in_channels,
-                self.kernel_size,
-                self.kernel_size,
-            ),
-            dtype,
-        )
+        # Expand kernel size if provided an integer.
+        if isinstance(kernel_size, int):
+            self.kernel_size = [kernel_size] * 2
+        else:
+            self.kernel_size = kernel_size
+
+        kernel_shape = [self.out_channels, in_channels] + list(self.kernel_size)
+
+        self.weight = Parameter(kernel_shape, dtype)
+
         if bias:
             self.bias = Parameter((self.out_channels,), dtype)
         else:

@@ -270,7 +272,88 @@ def forward(self, x: Tensor) -> Tensor: # pylint: disable=invalid-name
         The output tensor for the conv2d layer.
         """
         return op.conv2d(
-            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.data_layout,
+        )
+
+
+class Conv3D(Module):
+    """
+    Module for conv3d layer.
+    """
+
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[List[int], int],
+        stride: Union[List[int], int] = 1,
+        padding: Union[List[int], int] = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        dtype: Optional[str] = None,
+        data_layout: str = "NCDHW",
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.data_layout = data_layout
+
+        # Allow dynamic input channels.
+        if isinstance(self.in_channels, int):
+            in_channels = int(self.in_channels / self.groups)
+        else:
+            in_channels = tir.floordiv(self.in_channels, self.groups)
+
+        # Expand kernel size if given an integer.
+        if isinstance(kernel_size, int):
+            self.kernel_size = [kernel_size] * 3
+        else:
+            self.kernel_size = kernel_size
+
+        kernel_shape = [self.out_channels, self.in_channels] + list(self.kernel_size)
+
+        self.weight = Parameter(kernel_shape, dtype)
+
+        if bias:
+            self.bias = Parameter((self.out_channels,), dtype)
+        else:
+            self.bias = None
+
+    def forward(self, x: Tensor) -> Tensor:  # pylint: disable=invalid-name
+        """
+        Forward method for conv3d layer.
+
+        Parameters
+        ----------
+        x : Tensor
+            The input tensor.
+
+        Returns
+        -------
+        ret : Tensor
+            The output tensor for the conv3d layer.
+        """
+        return op.conv3d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.data_layout,
         )

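A minimal usage sketch of the new module (editorial addition, not part of the commit; the list-form kernel_size and the input shape are illustrative), mirroring the test added at the bottom of this change:

from tvm.relax.frontend.nn import modules, spec

# Hypothetical shapes chosen for illustration only.
conv = modules.Conv3D(in_channels=3, out_channels=8, kernel_size=[3, 3, 3], bias=True)
tvm_mod, params = conv.export_tvm(
    spec={"forward": {"x": spec.Tensor([1, 3, 16, 16, 16], "float32")}}
)
tvm_mod.show()  # Relax IR containing R.nn.conv3d plus the bias add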

python/tvm/relax/frontend/nn/op.py

Lines changed: 97 additions & 7 deletions
@@ -371,6 +371,7 @@ def conv2d(
     padding: Optional[Union[int, Tuple, str]] = 0,
     dilation: Optional[Union[int, Tuple]] = 1,
     groups: Optional[int] = 1,
+    data_layout: Optional[str] = "NCHW",
     name: str = "conv2d",
 ) -> Tensor:
     """Applies a 2D convolution over an input image composed of sevaral input planes

@@ -399,6 +400,9 @@ def conv2d(
     groups : Optional[int]
         Split input into a number of groups.

+    data_layout : Optional[str]
+        Layout of input and output data.
+
     name : str
         Name hint.

@@ -408,15 +412,89 @@ def conv2d(
         The computed result with shape [B, O, oH, oW].
     """
     conv_out = _op.nn.conv2d(
+        data=x._expr,
+        weight=weight._expr,
+        strides=stride,
+        padding=padding,
+        dilation=dilation,
+        data_layout=data_layout,
+        groups=groups,
+    )
+    if bias is not None:
+        if data_layout == "NCHW":
+            conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1]))
+        elif data_layout == "NHWC":
+            conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, -1]))
+        else:
+            raise NotImplementedError(f"Dont know how to handle layout {data_layout}.")
+
+    return wrap_nested(conv_out, name)
+
+
+def conv3d(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: Optional[Union[int, Tuple]] = 1,
+    padding: Optional[Union[int, Tuple, str]] = 0,
+    dilation: Optional[Union[int, Tuple]] = 1,
+    groups: Optional[int] = 1,
+    data_layout: Optional[str] = "NCDHW",
+    name: str = "conv3d",
+) -> Tensor:
+    """Applies a 3D convolution over an input image composed of sevaral input planes
+
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor of shape [B, N, D, H, W]
+
+    weight : Tensor
+        Filters of shape [O, N/groups, kD, kH, kW]
+
+    bias : Optional[Tensor]
+        Optional bias tensor of shape [O].
+
+    stride : Optional[Union[int, Tuple]]
+        The stride of the convolving kernel. Can be a single number
+        or tuple of (sD, sH, sW).
+
+    padding : Optional[[Union[int, Tuple]]]
+        Implicit paddings on both sides of the input.
+
+    dilation : Optional[Union[int, Tuple]]
+        The spacing between kernel elements. Can be a single number of tuple (dD, dH, dW).
+
+    groups : Optional[int]
+        Split input into a number of groups.
+
+    data_layout : Optional[str]
+        Optional layout of the input and output data.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result with shape [B, O, oD, oH, oW].
+    """
+    conv_out = _op.nn.conv3d(
         data=x._expr,
         weight=weight._expr,
         strides=stride,
         padding=padding,
         dilation=dilation,
         groups=groups,
+        data_layout=data_layout,
     )
     if bias is not None:
-        conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1]))
+        if data_layout == "NCDHW":
+            conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1, 1]))
+        elif data_layout == "NDHWC":
+            conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, 1, -1]))
+        else:
+            raise NotImplementedError(f"Dont know how to handle layout {data_layout}.")

     return wrap_nested(conv_out, name)
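The functional form can also be driven with a channels-last layout, which is what the new bias branch handles. A hedged sketch (editorial, not from the commit; the module name, shapes, and the assumption that the default OIDHW kernel layout pairs acceptably with NDHWC data are all illustrative):

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op, spec


class Conv3DChannelsLast(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        # Weight stays in the default OIDHW layout; only the data layout changes.
        self.weight = nn.Parameter((8, 3, 3, 3, 3), "float32")
        self.bias = nn.Parameter((8,), "float32")

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        # With data_layout="NDHWC" the bias is reshaped to [1, 1, 1, 1, -1] above.
        return op.conv3d(x, self.weight, self.bias, stride=1, padding=1, data_layout="NDHWC")


mod, params = Conv3DChannelsLast().export_tvm(
    spec={"forward": {"x": spec.Tensor([1, 16, 16, 16, 3], "float32")}}
)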

@@ -1427,6 +1505,7 @@ def interpolate(
     align_corners: Optional[bool] = None,
     recompute_scale_factor: Optional[bool] = None,
     antialias: Optional[bool] = None,
+    data_layout: Optional[str] = "NCHW",
     name: str = "interpolate",
 ):
     """Resize a tensor using the specified mode.

@@ -1448,6 +1527,8 @@ def interpolate(
         Recompute the scale_factor for use in interpolation.
     antialias : Optional[bool]
         Apply antialiasing to output.
+    data_layout : Optional[str]
+        Layout of the input and output data.
     name : str
         Name hint for this operation.

@@ -1460,11 +1541,14 @@ def interpolate(
     assert antialias is None, "antialias is not supported."

     if size is None:
-        shape = x.shape
-        if isinstance(scale_factor, (list, tuple)):
-            size = tuple(int(shape[i] * scale_factor[i]) for i in range(2, len(shape)))
-        else:
-            size = tuple(int(shape[i] * scale_factor) for i in range(2, len(shape)))
+        size = []
+        for i, dim in enumerate(data_layout):
+            # Only upscale spatial dimensions.
+            if dim not in ["N", "C"]:
+                if isinstance(scale_factor, (list, tuple)):
+                    size.append(int(x.shape[i] * scale_factor[len(size)]))
+                else:
+                    size.append(int(x.shape[i] * scale_factor))

     if mode.startswith("nearest"):
         mode = "nearest_neighbor"

@@ -1480,7 +1564,11 @@ def interpolate(

     return wrap_nested(
         _op.image.resize2d(
-            x._expr, size, layout="NCHW", method=mode, coordinate_transformation_mode=coord_trans
+            x._expr,
+            size,
+            layout=data_layout,
+            method=mode,
+            coordinate_transformation_mode=coord_trans,
         ),
         name,
     )
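The rewritten size computation keys off the layout string instead of assuming the last two axes are spatial. A plain-Python sketch of the same selection logic (editorial addition, not TVM code; the helper name is made up):

def spatial_sizes(shape, data_layout, scale_factor):
    # Scale every axis whose layout letter is not batch ("N") or channel ("C").
    size = []
    for i, dim in enumerate(data_layout):
        if dim not in ["N", "C"]:
            if isinstance(scale_factor, (list, tuple)):
                size.append(int(shape[i] * scale_factor[len(size)]))
            else:
                size.append(int(shape[i] * scale_factor))
    return size


print(spatial_sizes([1, 16, 16, 3], "NHWC", 2))       # [32, 32]
print(spatial_sizes([1, 3, 16, 16], "NCHW", (2, 4)))  # [32, 64]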
@@ -1991,6 +2079,8 @@ def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Tensor:
     result : Tensor
         The result tensor.
     """
+    # Cast condition to boolean.
+    condition = astype(condition, "bool")
     return wrap_nested(_op.where(condition._expr, x1._expr, x2._expr), name)

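A hedged usage sketch for the where() change (editorial, not from the commit; module name and shapes are illustrative): with the added cast, an integer 0/1 mask can be fed straight to op.where inside a frontend module.

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op, spec


class MaskedSelect(nn.Module):  # hypothetical example module
    def forward(self, mask: nn.Tensor, a: nn.Tensor, b: nn.Tensor) -> nn.Tensor:
        # mask arrives as int32; op.where now casts it to bool before calling R.where.
        return op.where(mask, a, b)


mod, params = MaskedSelect().export_tvm(
    spec={
        "forward": {
            "mask": spec.Tensor([4], "int32"),
            "a": spec.Tensor([4], "float32"),
            "b": spec.Tensor([4], "float32"),
        }
    }
)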

python/tvm/relax/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
@@ -59,6 +59,11 @@ class Conv2DAttrs(Attrs):
     """Attributes for nn.conv2d"""


+@tvm._ffi.register_object("relax.attrs.Conv3DAttrs")
+class Conv3DAttrs(Attrs):
+    """Attributes for nn.conv3d"""
+
+
 @tvm._ffi.register_object("relax.attrs.Conv2DTransposeAttrs")
 class Conv2DTransposeAttrs(Attrs):
     """Attributes for nn.conv2d_transpose"""

src/relax/op/image/resize.cc

Lines changed: 16 additions & 4 deletions
@@ -105,14 +105,26 @@ StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
 InferLayoutOutput InferLayoutResize2d(const Call& call,
                                       const Map<String, Array<String>>& desired_layouts,
                                       const VarLayoutMap& var_layout_map) {
-  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto& it = desired_layouts.find("relax.image.resize2d");
   const auto* attrs = call->attrs.as<Resize2DAttrs>();
   ICHECK(attrs) << "Invalid Call";

-  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision data_layout;
   ObjectPtr<Resize2DAttrs> new_attrs = make_object<Resize2DAttrs>(*attrs);
-  new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name();
-  return InferLayoutOutput({layout, InitialNLayout(call->args[1])}, {layout}, Attrs(new_attrs));
+
+  if (it != desired_layouts.end()) {
+    // We have a desired layout for resize2d.
+    Layout desired_data_layout = (*it).second[0];
+    ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only";
+    data_layout = TransposeLike(InitialLayout(4), attrs->layout, desired_data_layout);
+    new_attrs->layout = (*it).second[0];
+  } else {
+    // We dont have a desired layout for resize2d, propagate from the input instead.
+    data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+    new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), data_layout->layout).name();
+  }
+  return InferLayoutOutput({data_layout, InitialNLayout(call->args[1])}, {data_layout},
+                           Attrs(new_attrs));
 }

 TVM_REGISTER_OP("relax.image.resize2d")
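The practical effect is that relax.transform.ConvertLayout can now be handed a desired layout for relax.image.resize2d rather than only propagating whatever layout reaches it. A hedged sketch (editorial, not from the commit; the NHWC choice and the shapes are illustrative):

from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Resize:
    @R.function
    def main(x: R.Tensor((1, 3, 32, 32), "float32")) -> R.Tensor((1, 3, 64, 64), "float32"):
        with R.dataflow():
            y = R.image.resize2d(x, R.shape([64, 64]), layout="NCHW")
            R.output(y)
        return y


# Key the desired layout by op name; only the data layout is needed for resize2d.
mod = relax.transform.ConvertLayout({"relax.image.resize2d": ["NHWC"]})(Resize)
mod.show()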

tests/python/relax/test_frontend_nn_modules.py

Lines changed: 33 additions & 0 deletions
@@ -246,6 +246,39 @@ def forward(
     assert_structural_equal(tvm_mod["forward"], forward, True)


+def test_conv3d():
+    @R.function
+    def forward(
+        x: R.Tensor((1, 3, 32, 32, 32), dtype="float32"),
+        _io: R.Object,
+        weight: R.Tensor((32, 3, 3, 3, 3), dtype="float32"),
+        bias: R.Tensor((32,), dtype="float32"),
+    ) -> R.Tuple(R.Tensor((1, 32, 30, 30, 30), dtype="float32"), R.Tuple(R.Object)):
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            lv1: R.Tensor((1, 32, 30, 30, 30), dtype="float32") = R.nn.conv3d(x, weight)
+            lv2: R.Tensor((1, 32, 1, 1, 1), dtype="float32") = R.reshape(
+                bias, R.shape([1, 32, 1, 1, 1])
+            )
+            conv3d: R.Tensor((1, 32, 30, 30, 30), dtype="float32") = R.add(lv1, lv2)
+            gv1: R.Tuple(
+                R.Tensor((1, 32, 30, 30, 30), dtype="float32"), R.Tuple(R.Object)
+            ) = conv3d, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.Conv3D(3, 32, 3, bias=True)
+    tvm_mod, _ = mod.export_tvm(
+        spec={
+            "forward": {
+                "x": spec.Tensor([1, 3, 32, 32, 32], "float32"),
+            }
+        },
+        debug=True,
+    )
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 def test_conv2d_dynamic():
     @R.function
     def forward(
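To run just the new case, the standard pytest node id works (path taken from the file header above):

pytest tests/python/relax/test_frontend_nn_modules.py::test_conv3d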
