-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
bilinear upsampling and resizing #9303
Changes from all commits
0855745
4b03e6a
b7c10bd
104f209
73f0311
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1122,17 +1122,20 @@ def permute_dimensions(x, pattern): | |
return C.transpose(x, axis) | ||
|
||
|
||
def resize_images(x, height_factor, width_factor, data_format): | ||
if data_format == 'channels_first': | ||
output = repeat_elements(x, height_factor, axis=2) | ||
output = repeat_elements(output, width_factor, axis=3) | ||
return output | ||
elif data_format == 'channels_last': | ||
output = repeat_elements(x, height_factor, axis=1) | ||
output = repeat_elements(output, width_factor, axis=2) | ||
return output | ||
def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'): | ||
if interpolation == 'nearest': | ||
if data_format == 'channels_first': | ||
output = repeat_elements(x, height_factor, axis=2) | ||
output = repeat_elements(output, width_factor, axis=3) | ||
return output | ||
elif data_format == 'channels_last': | ||
output = repeat_elements(x, height_factor, axis=1) | ||
output = repeat_elements(output, width_factor, axis=2) | ||
return output | ||
else: | ||
raise ValueError('CNTK Backend: Invalid data_format:', data_format) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in the tests k.backend() == |
||
else: | ||
raise ValueError('CNTK Backend: Invalid data_format:', data_format) | ||
raise NotImplementedError | ||
|
||
|
||
def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1905,14 +1905,15 @@ def permute_dimensions(x, pattern): | |
return tf.transpose(x, perm=pattern) | ||
|
||
|
||
def resize_images(x, height_factor, width_factor, data_format): | ||
def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'): | ||
"""Resizes the images contained in a 4D tensor. | ||
|
||
# Arguments | ||
x: Tensor or variable to resize. | ||
height_factor: Positive integer. | ||
width_factor: Positive integer. | ||
data_format: string, `"channels_last"` or `"channels_first"`. | ||
interpolation: A string, one of `nearest` or `bilinear`. | ||
|
||
# Returns | ||
A tensor. | ||
|
@@ -1925,7 +1926,12 @@ def resize_images(x, height_factor, width_factor, data_format): | |
new_shape = tf.shape(x)[2:] | ||
new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32')) | ||
x = permute_dimensions(x, [0, 2, 3, 1]) | ||
x = tf.image.resize_nearest_neighbor(x, new_shape) | ||
if interpolation == 'nearest': | ||
x = tf.image.resize_nearest_neighbor(x, new_shape) | ||
elif interpolation == 'bilinear': | ||
x = tf.image.resize_bilinear(x, new_shape) | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this value error is good but do in other places |
||
raise ValueError('interpolation should be one of "nearest" or "bilinear".') | ||
x = permute_dimensions(x, [0, 3, 1, 2]) | ||
x.set_shape((None, None, original_shape[2] * height_factor if original_shape[2] is not None else None, | ||
original_shape[3] * width_factor if original_shape[3] is not None else None)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -884,24 +884,42 @@ def repeat_elements(x, rep, axis): | |
return y | ||
|
||
|
||
def resize_images(x, height_factor, width_factor, data_format): | ||
def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'): | ||
"""Resize the images contained in a 4D tensor of shape | ||
- [batch, channels, height, width] (for 'channels_first' data_format) | ||
- [batch, height, width, channels] (for 'channels_last' data_format) | ||
by a factor of (height_factor, width_factor). Both factors should be | ||
positive integers. | ||
""" | ||
if data_format == 'channels_first': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add comments to make easier to read |
||
output = repeat_elements(x, height_factor, axis=2) | ||
output = repeat_elements(output, width_factor, axis=3) | ||
return output | ||
axis_1 = 2 | ||
axis_2 = 3 | ||
elif data_format == 'channels_last': | ||
output = repeat_elements(x, height_factor, axis=1) | ||
output = repeat_elements(output, width_factor, axis=2) | ||
return output | ||
axis_1 = 1 | ||
axis_2 = 2 | ||
else: | ||
raise ValueError('Invalid data_format:', data_format) | ||
|
||
if interpolation == 'nearest': | ||
output = repeat_elements(x, height_factor, axis=axis_1) | ||
output = repeat_elements(output, width_factor, axis=axis_2) | ||
elif interpolation == 'bilinear': | ||
ratio = height_factor / width_factor | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For theano, ratio = height_factor // width_factor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how should this be done? Am I correct that the values might be tensors and not something that can be cast to |
||
th_padding = _preprocess_padding('same') | ||
output = theano.tensor.nnet.abstract_conv.bilinear_upsampling( | ||
x, ratio=ratio) | ||
if hasattr(x, '_keras_shape'): | ||
output._keras_shape = list(x._keras_shape) | ||
repeat_dim_1 = x._keras_shape[axis_1] | ||
repeat_dim_2 = x._keras_shape[axis_2] | ||
output._keras_shape[axis_1] = repeat_dim_1 * rep | ||
output._keras_shape[axis_2] = repeat_dim_2 * rep | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unresolved reference There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm... it has been a while since I looked at this... what was rep supposed to be? |
||
output._keras_shape = tuple(output._keras_shape) | ||
else: | ||
raise ValueError('interpolation should be one of "nearest" or "bilinear".') | ||
|
||
return output | ||
|
||
|
||
def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): | ||
"""Resize the volume contained in a 5D tensor of shape | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1565,6 +1565,7 @@ class UpSampling2D(Layer): | |
It defaults to the `image_data_format` value found in your | ||
Keras config file at `~/.keras/keras.json`. | ||
If you never set it, then it will be "channels_last". | ||
interpolation: A string, one of `nearest` or `bilinear`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. other doc string stuff has default marked |
||
|
||
# Input shape | ||
4D tensor with shape: | ||
|
@@ -1582,11 +1583,14 @@ class UpSampling2D(Layer): | |
""" | ||
|
||
@interfaces.legacy_upsampling2d_support | ||
def __init__(self, size=(2, 2), data_format=None, **kwargs): | ||
def __init__(self, size=(2, 2), data_format=None, interpolation='nearest', **kwargs): | ||
super(UpSampling2D, self).__init__(**kwargs) | ||
self.data_format = conv_utils.normalize_data_format(data_format) | ||
self.size = conv_utils.normalize_tuple(size, 2, 'size') | ||
self.input_spec = InputSpec(ndim=4) | ||
if interpolation not in ['nearest', 'bilinear']: | ||
raise ValueError('interpolation should be one of "nearest" or "bilinear".') | ||
self.interpolation = interpolation | ||
|
||
def compute_output_shape(self, input_shape): | ||
if self.data_format == 'channels_first': | ||
|
@@ -1606,7 +1610,7 @@ def compute_output_shape(self, input_shape): | |
|
||
def call(self, inputs): | ||
return K.resize_images(inputs, self.size[0], self.size[1], | ||
self.data_format) | ||
self.data_format, self.interpolation) | ||
|
||
def get_config(self): | ||
config = {'size': self.size, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -734,6 +734,43 @@ def test_upsampling_2d(): | |
assert_allclose(np_output, expected_out) | ||
|
||
|
||
@pytest.mark.skipif((K.backend() == 'cntk'), | ||
reason="cntk does not support it yet") | ||
def test_upsampling_2d_bilinear(): | ||
num_samples = 2 | ||
stack_size = 2 | ||
input_num_row = 11 | ||
input_num_col = 12 | ||
|
||
for data_format in ['channels_first', 'channels_last']: | ||
if data_format == 'channels_first': | ||
inputs = np.random.rand(num_samples, stack_size, input_num_row, | ||
input_num_col) | ||
else: # tf | ||
inputs = np.random.rand(num_samples, input_num_row, input_num_col, | ||
stack_size) | ||
|
||
# basic test | ||
layer_test(convolutional.UpSampling2D, | ||
kwargs={'size': (2, 2), 'data_format': data_format, 'interpolation': 'bilinear'}, | ||
input_shape=inputs.shape) | ||
|
||
for length_row in [2]: | ||
for length_col in [2, 3]: | ||
layer = convolutional.UpSampling2D( | ||
size=(length_row, length_col), | ||
data_format=data_format) | ||
layer.build(inputs.shape) | ||
outputs = layer(K.variable(inputs)) | ||
np_output = K.eval(outputs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this tests need to be more rigorous. only checking shape |
||
if data_format == 'channels_first': | ||
assert np_output.shape[2] == length_row * input_num_row | ||
assert np_output.shape[3] == length_col * input_num_col | ||
else: # tf | ||
assert np_output.shape[1] == length_row * input_num_row | ||
assert np_output.shape[2] == length_col * input_num_col | ||
|
||
|
||
@pytest.mark.skipif((K.backend() == 'cntk'), | ||
reason="cntk does not support it yet") | ||
def test_upsampling_3d(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe tell the user the valid data_formats