Skip to content

Commit

Permalink
added global avg pool support per PR: ethereon#123 based on Issue: et…
Browse files Browse the repository at this point in the history
  • Loading branch information
sjain-stanford committed Jan 6, 2019
1 parent 4171311 commit af2967b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
15 changes: 13 additions & 2 deletions kaffe/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class LayerAdapter(object):
def __init__(self, layer, kind):
self.layer = layer
self.kind = kind
self._input_shape = None

@property
def parameters(self):
Expand Down Expand Up @@ -130,12 +131,22 @@ def get_kernel_value(scalar, repeated, idx, default=None):
raise ValueError('Unable to determine kernel parameter!')
return default

def set_input_shape(self, input_shape):
self._input_shape = input_shape

@property
def kernel_parameters(self):
assert self.kind in (NodeKind.Convolution, NodeKind.Pooling)
params = self.parameters
k_h = self.get_kernel_value(params.kernel_h, params.kernel_size, 0)
k_w = self.get_kernel_value(params.kernel_w, params.kernel_size, 1)
global_pool = hasattr(params, 'global_pooling')
if params.kernel_size:
k_h = self.get_kernel_value(params.kernel_h, params.kernel_size, 0)
k_w = self.get_kernel_value(params.kernel_w, params.kernel_size, 1)
elif self._input_shape:
k_h, k_w = [self._input_shape.height, self._input_shape.width]
else: #errors out in get_kernel_value function
k_h = self.get_kernel_value(params.kernel_h, params.kernel_size, 0)
k_w = self.get_kernel_value(params.kernel_w, params.kernel_size, 1)
s_h = self.get_kernel_value(params.stride_h, params.stride, 0, default=1)
s_w = self.get_kernel_value(params.stride_w, params.stride, 1, default=1)
p_h = self.get_kernel_value(params.pad_h, params.pad, 0, default=0)
Expand Down
1 change: 1 addition & 0 deletions kaffe/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def get_filter_output_shape(i_h, i_w, params, round_func):
def get_strided_kernel_output_shape(node, round_func):
assert node.layer is not None
input_shape = node.get_only_parent().output_shape
node.layer.set_input_shape(input_shape)
o_h, o_w = get_filter_output_shape(input_shape.height, input_shape.width,
node.layer.kernel_parameters, round_func)
params = node.layer.parameters
Expand Down

0 comments on commit af2967b

Please sign in to comment.