diff --git a/kaffe/layers.py b/kaffe/layers.py index c3c5955..62b2a68 100644 --- a/kaffe/layers.py +++ b/kaffe/layers.py @@ -103,6 +103,7 @@ class LayerAdapter(object): def __init__(self, layer, kind): self.layer = layer self.kind = kind + self._input_shape = None @property def parameters(self): @@ -130,12 +131,22 @@ def get_kernel_value(scalar, repeated, idx, default=None): raise ValueError('Unable to determine kernel parameter!') return default + def set_input_shape(self, input_shape): + self._input_shape = input_shape + @property def kernel_parameters(self): assert self.kind in (NodeKind.Convolution, NodeKind.Pooling) params = self.parameters - k_h = self.get_kernel_value(params.kernel_h, params.kernel_size, 0) - k_w = self.get_kernel_value(params.kernel_w, params.kernel_size, 1) + global_pool = hasattr(params, 'global_pooling') + if params.kernel_size: + k_h = self.get_kernel_value(params.kernel_h, params.kernel_size, 0) + k_w = self.get_kernel_value(params.kernel_w, params.kernel_size, 1) + elif self._input_shape: + k_h, k_w = [self._input_shape.height, self._input_shape.width] + else: #errors out in get_kernel_value function + k_h = self.get_kernel_value(params.kernel_h, params.kernel_size, 0) + k_w = self.get_kernel_value(params.kernel_w, params.kernel_size, 1) s_h = self.get_kernel_value(params.stride_h, params.stride, 0, default=1) s_w = self.get_kernel_value(params.stride_w, params.stride, 1, default=1) p_h = self.get_kernel_value(params.pad_h, params.pad, 0, default=0) diff --git a/kaffe/shapes.py b/kaffe/shapes.py index a70ff14..c7eb640 100644 --- a/kaffe/shapes.py +++ b/kaffe/shapes.py @@ -15,6 +15,7 @@ def get_filter_output_shape(i_h, i_w, params, round_func): def get_strided_kernel_output_shape(node, round_func): assert node.layer is not None input_shape = node.get_only_parent().output_shape + node.layer.set_input_shape(input_shape) o_h, o_w = get_filter_output_shape(input_shape.height, input_shape.width, node.layer.kernel_parameters, round_func) params = node.layer.parameters