@@ -371,6 +371,7 @@ def conv2d(
371371 padding : Optional [Union [int , Tuple , str ]] = 0 ,
372372 dilation : Optional [Union [int , Tuple ]] = 1 ,
373373 groups : Optional [int ] = 1 ,
374+ data_layout : Optional [str ] = "NCHW" ,
374375 name : str = "conv2d" ,
375376) -> Tensor :
376377 """Applies a 2D convolution over an input image composed of sevaral input planes
@@ -399,6 +400,9 @@ def conv2d(
399400 groups : Optional[int]
400401 Split input into a number of groups.
401402
403+ data_layout : Optional[str]
404+ Layout of input and output data.
405+
402406 name : str
403407 Name hint.
404408
@@ -408,15 +412,89 @@ def conv2d(
408412 The computed result with shape [B, O, oH, oW].
409413 """
410414 conv_out = _op .nn .conv2d (
415+ data = x ._expr ,
416+ weight = weight ._expr ,
417+ strides = stride ,
418+ padding = padding ,
419+ dilation = dilation ,
420+ data_layout = data_layout ,
421+ groups = groups ,
422+ )
423+ if bias is not None :
424+ if data_layout == "NCHW" :
425+ conv_out = _op .add (conv_out , _op .reshape (bias ._expr , [1 , - 1 , 1 , 1 ]))
426+ elif data_layout == "NHWC" :
427+ conv_out = _op .add (conv_out , _op .reshape (bias ._expr , [1 , 1 , 1 , - 1 ]))
428+ else :
429+ raise NotImplementedError (f"Dont know how to handle layout { data_layout } ." )
430+
431+ return wrap_nested (conv_out , name )
432+
433+
434+ def conv3d (
435+ x : Tensor ,
436+ weight : Tensor ,
437+ bias : Optional [Tensor ] = None ,
438+ stride : Optional [Union [int , Tuple ]] = 1 ,
439+ padding : Optional [Union [int , Tuple , str ]] = 0 ,
440+ dilation : Optional [Union [int , Tuple ]] = 1 ,
441+ groups : Optional [int ] = 1 ,
442+ data_layout : Optional [str ] = "NCDHW" ,
443+ name : str = "conv3d" ,
444+ ) -> Tensor :
445+ """Applies a 3D convolution over an input image composed of sevaral input planes
446+
447+ Parameters
448+ ----------
449+ x : Tensor
450+ Input tensor of shape [B, N, D, H, W]
451+
452+ weight : Tensor
453+ Filters of shape [O, N/groups, kD, kH, kW]
454+
455+ bias : Optional[Tensor]
456+ Optional bias tensor of shape [O].
457+
458+ stride : Optional[Union[int, Tuple]]
459+ The stride of the convolving kernel. Can be a single number
460+ or tuple of (sD, sH, sW).
461+
462+ padding : Optional[[Union[int, Tuple]]]
463+ Implicit paddings on both sides of the input.
464+
465+ dilation : Optional[Union[int, Tuple]]
466+ The spacing between kernel elements. Can be a single number of tuple (dD, dH, dW).
467+
468+ groups : Optional[int]
469+ Split input into a number of groups.
470+
471+ data_layout : Optional[str]
472+ Optional layout of the input and output data.
473+
474+ name : str
475+ Name hint.
476+
477+ Returns
478+ -------
479+ result : Tensor
480+ The computed result with shape [B, O, oD, oH, oW].
481+ """
482+ conv_out = _op .nn .conv3d (
411483 data = x ._expr ,
412484 weight = weight ._expr ,
413485 strides = stride ,
414486 padding = padding ,
415487 dilation = dilation ,
416488 groups = groups ,
489+ data_layout = data_layout ,
417490 )
418491 if bias is not None :
419- conv_out = _op .add (conv_out , _op .reshape (bias ._expr , [1 , - 1 , 1 , 1 ]))
492+ if data_layout == "NCDHW" :
493+ conv_out = _op .add (conv_out , _op .reshape (bias ._expr , [1 , - 1 , 1 , 1 , 1 ]))
494+ elif data_layout == "NDHWC" :
495+ conv_out = _op .add (conv_out , _op .reshape (bias ._expr , [1 , 1 , 1 , 1 , - 1 ]))
496+ else :
497+ raise NotImplementedError (f"Dont know how to handle layout { data_layout } ." )
420498
421499 return wrap_nested (conv_out , name )
422500
@@ -1427,6 +1505,7 @@ def interpolate(
14271505 align_corners : Optional [bool ] = None ,
14281506 recompute_scale_factor : Optional [bool ] = None ,
14291507 antialias : Optional [bool ] = None ,
1508+ data_layout : Optional [str ] = "NCHW" ,
14301509 name : str = "interpolate" ,
14311510):
14321511 """Resize a tensor using the specified mode.
@@ -1448,6 +1527,8 @@ def interpolate(
14481527 Recompute the scale_factor for use in interpolation.
14491528 antialias : Optional[bool]
14501529 Apply antialiasing to output.
1530+ data_layout : Optional[str]
1531+ Layout of the input and output data.
14511532 name : str
14521533 Name hint for this operation.
14531534
@@ -1460,11 +1541,14 @@ def interpolate(
14601541 assert antialias is None , "antialias is not supported."
14611542
14621543 if size is None :
1463- shape = x .shape
1464- if isinstance (scale_factor , (list , tuple )):
1465- size = tuple (int (shape [i ] * scale_factor [i ]) for i in range (2 , len (shape )))
1466- else :
1467- size = tuple (int (shape [i ] * scale_factor ) for i in range (2 , len (shape )))
1544+ size = []
1545+ for i , dim in enumerate (data_layout ):
1546+ # Only upscale spatial dimensions.
1547+ if dim not in ["N" , "C" ]:
1548+ if isinstance (scale_factor , (list , tuple )):
1549+ size .append (int (x .shape [i ] * scale_factor [len (size )]))
1550+ else :
1551+ size .append (int (x .shape [i ] * scale_factor ))
14681552
14691553 if mode .startswith ("nearest" ):
14701554 mode = "nearest_neighbor"
@@ -1480,7 +1564,11 @@ def interpolate(
14801564
14811565 return wrap_nested (
14821566 _op .image .resize2d (
1483- x ._expr , size , layout = "NCHW" , method = mode , coordinate_transformation_mode = coord_trans
1567+ x ._expr ,
1568+ size ,
1569+ layout = data_layout ,
1570+ method = mode ,
1571+ coordinate_transformation_mode = coord_trans ,
14841572 ),
14851573 name ,
14861574 )
@@ -1991,6 +2079,8 @@ def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Ten
19912079 result : Tensor
19922080 The result tensor.
19932081 """
2082+ # Cast condition to boolean.
2083+ condition = astype (condition , "bool" )
19942084 return wrap_nested (_op .where (condition ._expr , x1 ._expr , x2 ._expr ), name )
19952085
19962086
0 commit comments