2020
2121
2222# Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
23- def conv2d_backward_weight_nchw_python (dy_np , x_np , kernel_size , stride , padding ):
23+ def conv2d_backward_weight_nchw_python (
24+ dy_np , x_np , kernel_size , stride , padding , groups = 1 , channels = None
25+ ):
2426 """Gradient of the conv2d op with respect to weight, in NCHW layout.
2527
2628 Parameters
@@ -51,17 +53,34 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
5153 R , S = kernel_size
5254 pad_h , pad_w = padding
5355 stride_h , stride_w = stride
54- dw = np .zeros ((K , C , R , S )).astype (dy_np .dtype )
56+ is_depth_wise = C == K and C == groups
57+
58+ if is_depth_wise :
59+ assert channels == groups , "Only channel_mult == 1 supported for now."
60+ dw = np .zeros ((K , 1 , R , S )).astype (dy_np .dtype )
61+ else :
62+ assert groups == 1 , "General grouped conv2d not supported for now."
63+ dw = np .zeros ((K , C , R , S )).astype (dy_np .dtype )
5564
5665 for k in range (K ):
5766 for r in range (R ):
5867 for s in range (S ):
59- for c in range (C ):
68+ for c in range (dw . shape [ 1 ] ):
6069 acc = 0
6170 for n in range (N ):
6271 for p in range (P ):
6372 for q in range (Q ):
64- coord = (n , c , p * stride_h - pad_h + r , q * stride_w - pad_w + s )
73+ if not is_depth_wise :
74+ in_c = c
75+ else :
76+ in_c = k
77+
78+ coord = (
79+ n ,
80+ in_c ,
81+ p * stride_h - pad_h + r ,
82+ q * stride_w - pad_w + s ,
83+ )
6584
6685 if (
6786 coord [2 ] < H
@@ -76,7 +95,9 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
7695 return dw
7796
7897
79- def conv2d_backward_weight_python (dy_np , x_np , kernel_size , stride , padding , layout = "NCHW" ):
98+ def conv2d_backward_weight_python (
99+ dy_np , x_np , kernel_size , stride , padding , layout = "NCHW" , groups = 1 , channels = None
100+ ):
80101 """Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout.
81102
82103 Parameters
@@ -99,20 +120,30 @@ def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, lay
99120 layout: string
100121 Layout of dy_np and x_np
101122
123+ groups: int
124+ Number of groups for grouped convolution.
125+
126+ channels : int
127+ Number of output channels of this convolution.
128+
102129 Returns
103130 -------
104131 dw_np : np.ndarray
105132 Tensor of shape [num_filter, in_channel, filter_height, filter_width] for NCHW layout,
106133 [num_filter, filter_height, filter_width, in_channel] for NHWC layout.
107134 """
108135 if layout == "NCHW" :
109- return conv2d_backward_weight_nchw_python (dy_np , x_np , kernel_size , stride , padding )
136+ return conv2d_backward_weight_nchw_python (
137+ dy_np , x_np , kernel_size , stride , padding , groups , channels
138+ )
110139
111140 dw_np_oihw = conv2d_backward_weight_nchw_python (
112141 np .transpose (dy_np , [0 , 3 , 1 , 2 ]),
113142 np .transpose (x_np , [0 , 3 , 1 , 2 ]),
114143 kernel_size ,
115144 stride ,
116145 padding ,
146+ groups ,
147+ channels ,
117148 )
118149 return np .transpose (dw_np_oihw , [0 , 2 , 3 , 1 ])
0 commit comments