Skip to content

Commit 6d68a18

Browse files
committed
[NNVM][TENSORFLOW] Some cleanup by combining depthwise with convolution.
1 parent 7ee9cca commit 6d68a18

File tree

1 file changed

+21
-83
lines changed

1 file changed

+21
-83
lines changed

nnvm/python/nnvm/frontend/tensorflow.py

Lines changed: 21 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -168,81 +168,7 @@ def _impl(inputs, attr, params):
168168
custom_check=_dimension_constraint())(inputs, attr)
169169
return _impl
170170

171-
def _conv():
172-
def _impl(inputs, attr, params):
173-
attr['data_format'] = attr['data_format'].decode("utf-8")
174-
175-
# Extract kernel shape from params
176-
conv_param_weights = params[inputs[1].list_output_names()[0]]
177-
178-
if attr['data_format'] == 'NHWC':
179-
attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
180-
attr['channels'] = conv_param_weights.shape[3]
181-
if 'dilations' in attr:
182-
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
183-
elif attr['data_format'] == 'NCHW':
184-
attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
185-
attr['channels'] = conv_param_weights.shape[1]
186-
if 'dilations' in attr:
187-
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
188-
else:
189-
raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
190-
191-
# Fix strides
192-
attr['strides'] = (attr['strides'][1], attr['strides'][2])
193-
194-
# Fix padding
195-
input_shapes = attr['_input_shapes'][inputs[0]]
196-
attr['padding'] = attr['padding'].decode("utf-8")
197-
198-
if attr['padding'] == 'VALID':
199-
attr['padding'] = [0, 0]
200-
elif attr['padding'] == 'SAME':
201-
stride_h, stride_w = attr['strides']
202-
kernel_h, kernel_w = attr['kernel_shape']
203-
if attr['data_format'] == 'NHWC':
204-
in_h = input_shapes[0][1]
205-
in_w = input_shapes[0][2]
206-
else:
207-
in_h = input_shapes[0][2]
208-
in_w = input_shapes[0][3]
209-
210-
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
211-
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
212-
213-
if attr['data_format'] == 'NHWC':
214-
inputs[0] = _sym.pad(data=inputs[0],
215-
pad_width=((0, 0),
216-
(pad_v[0], pad_v[1]),
217-
(pad_h[0], pad_h[1]),
218-
(0, 0)))
219-
else:
220-
inputs[0] = _sym.pad(data=inputs[0],
221-
pad_width=((0, 0),
222-
(0, 0),
223-
(pad_v[0], pad_v[1]),
224-
(pad_h[0], pad_h[1])))
225-
226-
attr['padding'] = [0, 0]
227-
228-
else:
229-
raise TypeError("Unsupported padding type : {}".format(attr['padding']))
230-
231-
if 'kernel_layout' not in attr:
232-
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
233-
234-
return AttrCvt(
235-
op_name=_dimension_picker('conv'),
236-
transforms={
237-
'kernel_shape': 'kernel_size',
238-
'data_format': 'layout',
239-
'dilations': ('dilation', (0, 0)),
240-
'group': ('groups', 1)},
241-
extras={'use_bias': len(inputs) == 3},
242-
custom_check=_dimension_constraint())(inputs, attr)
243-
return _impl
244-
245-
def _depthwise_conv():
171+
def _conv(opname):
246172
def _impl(inputs, attr, params):
247173
attr['data_format'] = attr['data_format'].decode("utf-8")
248174
input_shapes = attr['_input_shapes'][inputs[0]]
@@ -253,24 +179,33 @@ def _impl(inputs, attr, params):
253179
if attr['data_format'] == 'NHWC':
254180
kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
255181
attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
256-
attr['channels'] = input_shapes[0][3] * depth_mult
182+
if opname == 'conv':
183+
attr['channels'] = conv_param_weights.shape[3]
184+
else:
185+
attr['channels'] = input_shapes[0][3] * depth_mult
186+
257187
if 'dilations' in attr:
258188
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
259189
elif attr['data_format'] == 'NCHW':
260190
depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
261191
attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
262-
attr['channels'] = input_shapes[0][1] * depth_mult
192+
if opname == 'conv':
193+
attr['channels'] = conv_param_weights.shape[1]
194+
else:
195+
attr['channels'] = input_shapes[0][1] * depth_mult
196+
263197
if 'dilations' in attr:
264198
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
265199
else:
266200
raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
267201

202+
203+
if opname == 'depthwise':
204+
attr['groups'] = attr['channels']
205+
268206
# Fix strides
269207
attr['strides'] = (attr['strides'][1], attr['strides'][2])
270208

271-
# Fix groups
272-
attr['groups'] = attr['channels']
273-
274209
# Fix padding
275210
attr['padding'] = attr['padding'].decode("utf-8")
276211

@@ -308,7 +243,10 @@ def _impl(inputs, attr, params):
308243
raise TypeError("Unsupported padding type : {}".format(attr['padding']))
309244

310245
if 'kernel_layout' not in attr:
311-
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
246+
if opname == 'conv':
247+
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
248+
else:
249+
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
312250

313251
return AttrCvt(
314252
op_name=_dimension_picker('conv'),
@@ -687,7 +625,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
687625
'CheckNumerics' : _check_numerics(),
688626
'Concat' : _concat(),
689627
'ConcatV2' : _concatV2(),
690-
'Conv2D' : _conv(),
628+
'Conv2D' : _conv('conv'),
691629
'DecodeJpeg' : _decode_image(),
692630
'ExpandDims' : _expand_dims(),
693631
'Identity' : _identity(),
@@ -704,7 +642,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
704642
'Squeeze' : _squeeze(),
705643
'FusedBatchNorm' : _fused_batch_norm(),
706644
'Relu6' : _relu6(),
707-
'DepthwiseConv2dNative' : _depthwise_conv(),
645+
'DepthwiseConv2dNative' : _conv('depthwise'),
708646
'Shape' : _shape(),
709647
'Sigmoid' : AttrCvt('sigmoid'),
710648
'Fill' : _fill(),

0 commit comments

Comments
 (0)