@@ -168,81 +168,7 @@ def _impl(inputs, attr, params):
168168 custom_check = _dimension_constraint ())(inputs , attr )
169169 return _impl
170170
171- def _conv ():
172- def _impl (inputs , attr , params ):
173- attr ['data_format' ] = attr ['data_format' ].decode ("utf-8" )
174-
175- # Extract kernel shape from params
176- conv_param_weights = params [inputs [1 ].list_output_names ()[0 ]]
177-
178- if attr ['data_format' ] == 'NHWC' :
179- attr ['kernel_shape' ] = (conv_param_weights .shape [0 ], conv_param_weights .shape [1 ])
180- attr ['channels' ] = conv_param_weights .shape [3 ]
181- if 'dilations' in attr :
182- attr ['dilations' ] = (attr ['dilations' ][0 ], attr ['dilations' ][1 ])
183- elif attr ['data_format' ] == 'NCHW' :
184- attr ['kernel_shape' ] = (conv_param_weights .shape [2 ], conv_param_weights .shape [3 ])
185- attr ['channels' ] = conv_param_weights .shape [1 ]
186- if 'dilations' in attr :
187- attr ['dilations' ] = (attr ['dilations' ][2 ], attr ['dilations' ][3 ])
188- else :
189- raise TypeError ("Unsupported data format type : {}" .format (attr ['data_format' ]))
190-
191- # Fix strides
192- attr ['strides' ] = (attr ['strides' ][1 ], attr ['strides' ][2 ])
193-
194- # Fix padding
195- input_shapes = attr ['_input_shapes' ][inputs [0 ]]
196- attr ['padding' ] = attr ['padding' ].decode ("utf-8" )
197-
198- if attr ['padding' ] == 'VALID' :
199- attr ['padding' ] = [0 , 0 ]
200- elif attr ['padding' ] == 'SAME' :
201- stride_h , stride_w = attr ['strides' ]
202- kernel_h , kernel_w = attr ['kernel_shape' ]
203- if attr ['data_format' ] == 'NHWC' :
204- in_h = input_shapes [0 ][1 ]
205- in_w = input_shapes [0 ][2 ]
206- else :
207- in_h = input_shapes [0 ][2 ]
208- in_w = input_shapes [0 ][3 ]
209-
210- pad_v = _get_pad_pair (in_h , kernel_h , stride_h )
211- pad_h = _get_pad_pair (in_w , kernel_w , stride_w )
212-
213- if attr ['data_format' ] == 'NHWC' :
214- inputs [0 ] = _sym .pad (data = inputs [0 ],
215- pad_width = ((0 , 0 ),
216- (pad_v [0 ], pad_v [1 ]),
217- (pad_h [0 ], pad_h [1 ]),
218- (0 , 0 )))
219- else :
220- inputs [0 ] = _sym .pad (data = inputs [0 ],
221- pad_width = ((0 , 0 ),
222- (0 , 0 ),
223- (pad_v [0 ], pad_v [1 ]),
224- (pad_h [0 ], pad_h [1 ])))
225-
226- attr ['padding' ] = [0 , 0 ]
227-
228- else :
229- raise TypeError ("Unsupported padding type : {}" .format (attr ['padding' ]))
230-
231- if 'kernel_layout' not in attr :
232- attr ['kernel_layout' ] = 'HWIO' if attr ['data_format' ] == 'NHWC' else 'OIHW'
233-
234- return AttrCvt (
235- op_name = _dimension_picker ('conv' ),
236- transforms = {
237- 'kernel_shape' : 'kernel_size' ,
238- 'data_format' : 'layout' ,
239- 'dilations' : ('dilation' , (0 , 0 )),
240- 'group' : ('groups' , 1 )},
241- extras = {'use_bias' : len (inputs ) == 3 },
242- custom_check = _dimension_constraint ())(inputs , attr )
243- return _impl
244-
245- def _depthwise_conv ():
171+ def _conv (opname ):
246172 def _impl (inputs , attr , params ):
247173 attr ['data_format' ] = attr ['data_format' ].decode ("utf-8" )
248174 input_shapes = attr ['_input_shapes' ][inputs [0 ]]
@@ -253,24 +179,33 @@ def _impl(inputs, attr, params):
253179 if attr ['data_format' ] == 'NHWC' :
254180 kernel_h , kernel_w , _ , depth_mult = conv_param_weights .shape
255181 attr ['kernel_shape' ] = (conv_param_weights .shape [0 ], conv_param_weights .shape [1 ])
256- attr ['channels' ] = input_shapes [0 ][3 ] * depth_mult
182+ if opname == 'conv' :
183+ attr ['channels' ] = conv_param_weights .shape [3 ]
184+ else :
185+ attr ['channels' ] = input_shapes [0 ][3 ] * depth_mult
186+
257187 if 'dilations' in attr :
258188 attr ['dilations' ] = (attr ['dilations' ][0 ], attr ['dilations' ][1 ])
259189 elif attr ['data_format' ] == 'NCHW' :
260190 depth_mult , _ , kernel_h , kernel_w = conv_param_weights .shape
261191 attr ['kernel_shape' ] = (conv_param_weights .shape [2 ], conv_param_weights .shape [3 ])
262- attr ['channels' ] = input_shapes [0 ][1 ] * depth_mult
192+ if opname == 'conv' :
193+ attr ['channels' ] = conv_param_weights .shape [1 ]
194+ else :
195+ attr ['channels' ] = input_shapes [0 ][1 ] * depth_mult
196+
263197 if 'dilations' in attr :
264198 attr ['dilations' ] = (attr ['dilations' ][2 ], attr ['dilations' ][3 ])
265199 else :
266200 raise TypeError ("Unsupported data format type : {}" .format (attr ['data_format' ]))
267201
202+
203+ if opname == 'depthwise' :
204+ attr ['groups' ] = attr ['channels' ]
205+
268206 # Fix strides
269207 attr ['strides' ] = (attr ['strides' ][1 ], attr ['strides' ][2 ])
270208
271- # Fix groups
272- attr ['groups' ] = attr ['channels' ]
273-
274209 # Fix padding
275210 attr ['padding' ] = attr ['padding' ].decode ("utf-8" )
276211
@@ -308,7 +243,10 @@ def _impl(inputs, attr, params):
308243 raise TypeError ("Unsupported padding type : {}" .format (attr ['padding' ]))
309244
310245 if 'kernel_layout' not in attr :
311- attr ['kernel_layout' ] = 'HWOI' if attr ['data_format' ] == 'NHWC' else 'OIHW'
246+ if opname == 'conv' :
247+ attr ['kernel_layout' ] = 'HWIO' if attr ['data_format' ] == 'NHWC' else 'OIHW'
248+ else :
249+ attr ['kernel_layout' ] = 'HWOI' if attr ['data_format' ] == 'NHWC' else 'OIHW'
312250
313251 return AttrCvt (
314252 op_name = _dimension_picker ('conv' ),
@@ -687,7 +625,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
687625 'CheckNumerics' : _check_numerics (),
688626 'Concat' : _concat (),
689627 'ConcatV2' : _concatV2 (),
690- 'Conv2D' : _conv (),
628+ 'Conv2D' : _conv ('conv' ),
691629 'DecodeJpeg' : _decode_image (),
692630 'ExpandDims' : _expand_dims (),
693631 'Identity' : _identity (),
@@ -704,7 +642,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
704642 'Squeeze' : _squeeze (),
705643 'FusedBatchNorm' : _fused_batch_norm (),
706644 'Relu6' : _relu6 (),
707- 'DepthwiseConv2dNative' : _depthwise_conv ( ),
645+ 'DepthwiseConv2dNative' : _conv ( 'depthwise' ),
708646 'Shape' : _shape (),
709647 'Sigmoid' : AttrCvt ('sigmoid' ),
710648 'Fill' : _fill (),
0 commit comments