forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MXNET-1215] Allow dynamic shape exists in imperative mode (apache#13283
) * Add NDArray with lazy shape, Add boolean mask operator * Fix lints * Address comments and refactor forward pass * Add backward pass for boolean index * Fix tests.... * Add unittests * Make lint happy * Address comments * Address comments, and complete docstring * Update docstring
- Loading branch information
Showing
6 changed files
with
320 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file boolean_mask-inl.h | ||
*/ | ||
|
||
#ifndef MXNET_OPERATOR_CONTRIB_BOOLEAN_MASK_INL_H_ | ||
#define MXNET_OPERATOR_CONTRIB_BOOLEAN_MASK_INL_H_ | ||
|
||
#include <dmlc/logging.h> | ||
#include <dmlc/parameter.h> | ||
#include <mxnet/operator.h> | ||
#include <mxnet/ndarray.h> | ||
#include <map> | ||
#include <vector> | ||
#include <string> | ||
#include <utility> | ||
#include <algorithm> | ||
#include "../operator_common.h" | ||
#include "../mxnet_op.h" | ||
#include "../tensor/init_op.h" | ||
#include "../mshadow_op.h" | ||
#include "../elemwise_op_common.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
struct BooleanMaskParam : public dmlc::Parameter<BooleanMaskParam> { | ||
int axis; | ||
DMLC_DECLARE_PARAMETER(BooleanMaskParam) { | ||
DMLC_DECLARE_FIELD(axis).set_default(0) | ||
.describe("An integer that represents the axis in NDArray to mask from."); | ||
} | ||
}; | ||
|
||
template<typename xpu> | ||
inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, | ||
const OpContext &ctx, | ||
const std::vector<NDArray> &inputs, | ||
const std::vector<OpReqType> &req, | ||
const std::vector<NDArray> &outputs) { | ||
// TODO(@junrushao1994): This implementation is a proof-of-concept, | ||
// hence very slow actually. Performance should be improved in the future. | ||
CHECK_EQ(inputs.size(), 2U); | ||
CHECK_EQ(outputs.size(), 1U); | ||
const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed); | ||
const int axis = param.axis; | ||
const NDArray &data = inputs[0]; | ||
const NDArray &idx = inputs[1]; | ||
const NDArray &out = outputs[0]; | ||
CHECK_EQ(axis, 0) << "Not supported yet"; | ||
CHECK_EQ(data.shape()[axis], idx.shape()[0]); | ||
CHECK_EQ(idx.shape().ndim(), 1U); | ||
// count the number of 1s in `idx`, so that we could know the output dimension | ||
size_t valid_num = 0; | ||
MSHADOW_TYPE_SWITCH(idx.dtype(), DType, { | ||
DType* idx_dptr = idx.data().dptr<DType>(); | ||
int length = idx.shape()[0]; | ||
for (int i = 0; i < length; i++) { | ||
if (idx_dptr[i]) { | ||
++valid_num; | ||
} | ||
} | ||
}); | ||
// set the output shape forcefully | ||
TShape s = data.shape(); | ||
s[axis] = valid_num; | ||
const_cast<NDArray &>(out).Init(s); | ||
// do the copy | ||
MSHADOW_TYPE_SWITCH(idx.dtype(), DType, { | ||
DType* idx_dptr = idx.data().dptr<DType>(); | ||
int length = idx.shape()[0]; | ||
mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>(); | ||
for (int i = 0, j = 0; i < length; ++i) { | ||
if (idx_dptr[i]) { | ||
NDArray src = data.At(i); | ||
NDArray dst = out.At(j++); | ||
CHECK(src.shape() == dst.shape()); | ||
mxnet_op::copy(stream, dst.data(), src.data()); | ||
} | ||
} | ||
}); | ||
} | ||
|
||
template<typename xpu> | ||
inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, | ||
const OpContext &ctx, | ||
const std::vector<NDArray> &inputs, | ||
const std::vector<OpReqType> &req, | ||
const std::vector<NDArray> &outputs) { | ||
CHECK_EQ(inputs.size(), 3U); | ||
CHECK_EQ(outputs.size(), 2U); | ||
// inputs: {ograd, data, idx} | ||
// outputs: {igrad_data, igrad_idx} | ||
const NDArray& ograd = inputs[0]; | ||
const NDArray& idx = inputs[2]; | ||
const NDArray& igrad_data = outputs[0]; | ||
MSHADOW_TYPE_SWITCH(idx.dtype(), DType, { | ||
DType* idx_dptr = idx.data().dptr<DType>(); | ||
int length = idx.shape()[0]; | ||
mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>(); | ||
Fill<false>(stream, igrad_data.data(), req[0], 0); | ||
for (int i = 0, j = 0; i < length; ++i) { | ||
if (idx_dptr[i]) { | ||
NDArray src = ograd.At(j++); | ||
NDArray dst = igrad_data.At(i); | ||
CHECK(src.shape() == dst.shape()); | ||
mxnet_op::copy(stream, dst.data(), src.data()); | ||
} | ||
} | ||
}); | ||
} | ||
|
||
} // namespace op | ||
} // namespace mxnet | ||
|
||
#endif // MXNET_OPERATOR_CONTRIB_BOOLEAN_MASK_INL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file boolean_mask.cc | ||
*/ | ||
|
||
#include "./boolean_mask-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
DMLC_REGISTER_PARAMETER(BooleanMaskParam); | ||
|
||
|
||
bool BooleanMaskType(const nnvm::NodeAttrs& attrs, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
CHECK_EQ(in_attrs->size(), 2); | ||
CHECK_EQ(out_attrs->size(), 1); | ||
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); | ||
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); | ||
return out_attrs->at(0) != -1; | ||
} | ||
|
||
bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs, | ||
const int dev_mask, | ||
DispatchMode* dispatch_mode, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
CHECK_EQ(in_attrs->size(), 2); | ||
CHECK_EQ(out_attrs->size(), 1); | ||
for (int &attr : *in_attrs) { | ||
CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported"; | ||
} | ||
for (int &attr : *out_attrs) { | ||
attr = kDefaultStorage; | ||
} | ||
*dispatch_mode = DispatchMode::kFComputeEx; | ||
return true; | ||
} | ||
|
||
bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs, | ||
const int dev_mask, | ||
DispatchMode* dispatch_mode, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
CHECK_EQ(in_attrs->size(), 3); | ||
CHECK_EQ(out_attrs->size(), 2); | ||
for (int &attr : *in_attrs) { | ||
CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported"; | ||
} | ||
for (int &attr : *out_attrs) { | ||
attr = kDefaultStorage; | ||
} | ||
for (size_t i = 0; i < out_attrs->size(); i++) | ||
out_attrs->at(i) = kDefaultStorage; | ||
*dispatch_mode = DispatchMode::kFComputeEx; | ||
return true; | ||
} | ||
|
||
NNVM_REGISTER_OP(_contrib_boolean_mask) | ||
.describe(R"code( | ||
Experimental CPU-only support for boolean masking. | ||
Given an n-d NDArray data, and a 1-d NDArray index, | ||
the operator produces an un-predeterminable shaped n-d NDArray out, | ||
which stands for the rows in x where the corresonding element in index is non-zero. | ||
>>> data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]]) | ||
>>> index = mx.nd.array([0, 1, 0]) | ||
>>> out = mx.nd.contrib.boolean_mask(data, index) | ||
>>> out | ||
[[4. 5. 6.]] | ||
<NDArray 1x3 @cpu(0)> | ||
)code" ADD_FILELINE) | ||
.set_attr_parser(ParamParser<BooleanMaskParam>) | ||
.set_num_inputs(2) | ||
.set_num_outputs(1) | ||
.set_attr<nnvm::FInferType>("FInferType", BooleanMaskType) | ||
.set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskForward<cpu>) | ||
.set_attr<FInferStorageType>("FInferStorageType", BooleanMaskStorageType) | ||
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"}) | ||
.add_argument("data", "NDArray-or-Symbol", "Data") | ||
.add_argument("index", "NDArray-or-Symbol", "Mask") | ||
.add_arguments(BooleanMaskParam::__FIELDS__()); | ||
|
||
NNVM_REGISTER_OP(_backward_contrib_boolean_mask) | ||
.set_num_inputs(3) | ||
.set_num_outputs(2) | ||
.set_attr<nnvm::TIsBackward>("TIsBackward", true) | ||
.set_attr<FInferStorageType>("FInferStorageType", BooleanMaskBackStorageType) | ||
.set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskBackward<cpu>) | ||
.add_arguments(BooleanMaskParam::__FIELDS__()); | ||
|
||
} // namespace op | ||
} // namespace mxnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters