
Squeeze op (#9700)
* Add squeeze op

* Add unit test

* Fix lint

* Use IdentityCompute directly
reminisce authored and piiswrong committed Feb 7, 2018
1 parent a80ea3f commit c19f506
Showing 4 changed files with 147 additions and 0 deletions.
67 changes: 67 additions & 0 deletions src/operator/tensor/matrix_op-inl.h
@@ -1834,6 +1834,73 @@ void StackOpBackward(const nnvm::NodeAttrs& attrs,
})
}

struct SqueezeParam : public dmlc::Parameter<SqueezeParam> {
  dmlc::optional<TShape> axis;
  DMLC_DECLARE_PARAMETER(SqueezeParam) {
    DMLC_DECLARE_FIELD(axis)
    .set_default(dmlc::optional<TShape>())
    .describe("Selects a subset of the single-dimensional entries in the shape."
              " If an axis is selected with shape entry greater than one, an error is raised.");
  }
};

// Given a shape that may contain dim sizes equal to 0,
// move all the zeros to the end of the shape array while
// keeping the relative order of the non-zero values, e.g.
// (5, 0, 3, 0, 1) becomes (5, 3, 1, 0, 0).
// Returns the number of remaining non-zero dims, i.e. the new shape size.
inline uint32_t SqueezeShapeHelper(TShape* shape) {
  CHECK(shape != nullptr);
  uint32_t count = 0;
  for (uint32_t i = 0; i < shape->ndim(); ++i) {
    if ((*shape)[i] == 0) {
      ++count;
    } else {
      // swap the non-zero dim past the zeros counted so far
      std::swap((*shape)[i], (*shape)[i-count]);
    }
  }
  return shape->ndim() - count;
}

inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
                         std::vector<TShape> *in_attrs,
                         std::vector<TShape> *out_attrs) {
  const SqueezeParam& param = nnvm::get<SqueezeParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
  CHECK_EQ(out_attrs->size(), 1U);
  const TShape& dshape = in_attrs->at(0);
  const int dndim = dshape.ndim();
  if (shape_is_none(dshape)) return false;
  TShape oshape = dshape;
  if (param.axis.has_value()) {
    // preprocess axis: normalize negative values and validate bounds
    TShape axes = param.axis.value();
    for (uint32_t i = 0; i < axes.ndim(); ++i) {
      if (axes[i] < 0) {
        axes[i] += dndim;
        CHECK_GE(axes[i], 0)
          << "axis " << axes[i] - dndim << " is out of bounds for array of dimension " << dndim;
      }
      CHECK_LT(axes[i], dndim)
        << "axis " << axes[i] << " is out of bounds for array of dimension " << dndim;
      CHECK_EQ(dshape[axes[i]], 1)
        << "cannot select an axis to squeeze out which has size="
        << dshape[axes[i]] << " not equal to one";
      CHECK_NE(oshape[axes[i]], 0) << "duplicate value in axis";
      oshape[axes[i]] = 0;  // mark this dim for removal
    }
  } else {
    for (uint32_t i = 0; i < oshape.ndim(); ++i) {
      if (oshape[i] == 1) oshape[i] = 0;  // mark every size-1 dim for removal
    }
  }
  uint32_t oshape_size = SqueezeShapeHelper(&oshape);
  if (oshape_size == 0) {  // corner case when dshape is (1, 1, 1, 1)
    oshape[0] = 1;
    oshape_size = 1;
  }
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape.data(), oshape.data()+oshape_size));
  return true;
}

} // namespace op
} // namespace mxnet
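
As a cross-check, the shape-inference logic above can be sketched in a few lines of Python (the name squeeze_shape is hypothetical; this is an illustrative sketch, not part of the commit):

def squeeze_shape(dshape, axis=None):
    """Mirror SqueezeShape: mark squeezed dims as 0, compact, keep >= 1 dim."""
    oshape = list(dshape)
    ndim = len(dshape)
    if axis is not None:
        axes = axis if isinstance(axis, tuple) else (axis,)
        for a in axes:
            if a < 0:
                a += ndim
            assert 0 <= a < ndim, "axis is out of bounds"
            assert dshape[a] == 1, "cannot squeeze an axis whose size is not 1"
            assert oshape[a] != 0, "duplicate value in axis"
            oshape[a] = 0  # mark for removal
    else:
        oshape = [0 if d == 1 else d for d in oshape]
    # SqueezeShapeHelper: drop the zeros, keep the relative order of the rest
    out = [d for d in oshape if d != 0]
    # corner case: an all-ones shape keeps one dimension instead of
    # collapsing to a scalar, unlike numpy.squeeze
    return tuple(out) if out else (1,)

print(squeeze_shape((1, 5, 1, 3, 1)))          # (5, 3)
print(squeeze_shape((1, 5, 1, 3, 1), (0, 2)))  # (5, 3, 1)
print(squeeze_shape((1, 1, 1, 1)))             # (1,)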
39 changes: 39 additions & 0 deletions src/operator/tensor/matrix_op.cc
@@ -97,6 +97,7 @@ DMLC_REGISTER_PARAMETER(RepeatParam);
DMLC_REGISTER_PARAMETER(TileParam);
DMLC_REGISTER_PARAMETER(ReverseParam);
DMLC_REGISTER_PARAMETER(StackParam);
DMLC_REGISTER_PARAMETER(SqueezeParam);

NNVM_REGISTER_OP(Reshape)
.add_alias("reshape")
@@ -739,5 +740,43 @@ NNVM_REGISTER_OP(_backward_stack)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", StackOpBackward<cpu>);

NNVM_REGISTER_OP(squeeze)
.describe(R"code(Remove single-dimensional entries from the shape of an array.
Same behavior of defining the output tensor shape as numpy.squeeze for the most of cases.
See the following note for exception.
Examples::
data = [[[0], [1], [2]]]
squeeze(data) = [0, 1, 2]
squeeze(data, axis=0) = [[0], [1], [2]]
squeeze(data, axis=2) = [[0, 1, 2]]
squeeze(data, axis=(0, 2)) = [0, 1, 2]
.. Note::
The output of this operator will keep at least one dimension not removed. For example,
squeeze([[[4]]]) = [4], while in numpy.squeeze, the output will become a scalar.
)code")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SqueezeParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<nnvm::FInferShape>("FInferShape", SqueezeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_squeeze"})
.add_argument("data", "NDArray-or-Symbol[]", "data to squeeze")
.add_arguments(StackParam::__FIELDS__());

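// Note: squeeze only changes shape metadata, so the backward pass is just the
// incoming gradient reshaped to the input shape; an identity copy suffices.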
NNVM_REGISTER_OP(_backward_squeeze)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SqueezeParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>);

} // namespace op
} // namespace mxnet
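
The registration above routes the forward pass through an identity copy, so from Python the operator is a pure shape transformation. A quick usage sketch (assuming the Python bindings generated from this registration):

import mxnet as mx
import numpy as np

a = mx.nd.ones((1, 3, 1))
print(mx.nd.squeeze(a).shape)          # (3,)
print(mx.nd.squeeze(a, axis=0).shape)  # (3, 1)

# Unlike numpy.squeeze, at least one dimension is kept:
b = mx.nd.ones((1, 1, 1))
print(mx.nd.squeeze(b).shape)          # (1,)
print(np.squeeze(b.asnumpy()).shape)   # () -- numpy yields a scalar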
7 changes: 7 additions & 0 deletions src/operator/tensor/matrix_op.cu
@@ -198,5 +198,12 @@ NNVM_REGISTER_OP(stack)

NNVM_REGISTER_OP(_backward_stack)
.set_attr<FCompute>("FCompute<gpu>", StackOpBackward<gpu>);

NNVM_REGISTER_OP(squeeze)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);

NNVM_REGISTER_OP(_backward_squeeze)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);

} // namespace op
} // namespace mxnet
34 changes: 34 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -4881,6 +4881,40 @@ def test_float16_min_max():
assert np.finfo('float16').max == mx.nd.max(a).asscalar()


def test_squeeze_op():
    def check_squeeze_op(shape, axis=None):
        data = mx.nd.random.uniform(low=-10.0, high=10.0, shape=shape)
        if axis is None:
            out = mx.nd.squeeze(data).asnumpy()
            out_expected = np.squeeze(data.asnumpy())
        else:
            out = mx.nd.squeeze(data, axis=axis).asnumpy()
            out_expected = np.squeeze(data.asnumpy(), axis=axis)
        if out.shape == (1,):  # exception: an all-ones shape such as (1, 1, 1, 1) squeezes to (1,)
            out_expected = np.squeeze(data.asnumpy(), axis=tuple(range(1, len(shape))))
        assert same(out, out_expected)

    # check forward
    check_squeeze_op((1, 5, 1, 3, 1), 0)
    check_squeeze_op((1, 5, 1, 3, 1), 2)
    check_squeeze_op((1, 5, 1, 3, 1), 4)
    check_squeeze_op((1, 5, 1, 3, 1), (0, 4))
    check_squeeze_op((1, 5, 1, 3, 1), (0, 2, 4))
    check_squeeze_op((1, 5, 1, 3, 1))
    check_squeeze_op((1, 1, 1, 1))

    # check gradient
    data = mx.symbol.Variable('data')
    shape = (1, 2, 1, 3, 1)
    data_tmp = np.ones(shape)
    test = mx.sym.squeeze(data)
    check_numeric_gradient(test, [data_tmp])
    test = mx.sym.squeeze(data, axis=2)
    check_numeric_gradient(test, [data_tmp])
    test = mx.sym.squeeze(data, axis=(2, 4))
    check_numeric_gradient(test, [data_tmp])


if __name__ == '__main__':
    import nose
    nose.runmodule()
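
Beyond the numeric-gradient checks, the inferred shapes can also be verified through the symbolic API; a short sketch under the same assumptions as above:

import mxnet as mx

data = mx.sym.Variable('data')
sym = mx.sym.squeeze(data, axis=(0, 2))
# infer_shape returns (arg_shapes, out_shapes, aux_shapes)
_, out_shapes, _ = sym.infer_shape(data=(1, 5, 1, 3, 1))
print(out_shapes)  # [(5, 3, 1)]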
