
Commit a7d0c54: Fix redefinition error

cassiniXu authored and cassiniXu committed Nov 19, 2019
1 parent 04673bf commit a7d0c54
Showing 5 changed files with 4 additions and 179 deletions.
2 changes: 2 additions & 0 deletions python/mxnet/ndarray/numpy/_op.py
@@ -41,6 +41,7 @@
'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize',
'nan_to_num', 'where']


@set_module('mxnet.ndarray.numpy')
def zeros(shape, dtype=_np.float32, order='C', ctx=None):
"""Return a new array of given shape and type, filled with zeros.
@@ -5388,6 +5389,7 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
def where(condition, x=None, y=None):
"""where(condition, [x, y])
Return elements chosen from `x` or `y` depending on `condition`.
.. note::
When only `condition` is provided, this function is a shorthand for
``np.asarray(condition).nonzero()``. The rest of this documentation
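For context, a minimal plain-NumPy sketch of the `where` semantics this docstring describes; the MXNet numpy front end is intended to mirror this behavior, and nothing below is specific to this commit:

    import numpy as np

    condition = np.array([True, False, True])
    x = np.array([1, 2, 3])
    y = np.array([10, 20, 30])

    # Elements come from x where condition is True and from y elsewhere.
    print(np.where(condition, x, y))   # [ 1 20  3]

    # With only condition given, where() is shorthand for
    # np.asarray(condition).nonzero().
    print(np.where(condition))         # (array([0, 2]),)
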
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
@@ -137,6 +137,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'shares_memory',
'may_share_memory',
'diff',
'resize',
'where',
]

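Entries in this list are wired into NumPy's __array_function__ dispatch, so calling the official numpy function on MXNet ndarrays forwards to the mxnet.numpy implementation. A rough usage sketch, under the assumption that the build exposes mx.np and mx.npx.set_np() and that the installed NumPy (>= 1.17) has the protocol enabled:

    import mxnet as mx
    import numpy as onp  # official NumPy

    mx.npx.set_np()  # enable NumPy-compatible semantics (may not be strictly required in all versions)
    cond = mx.np.array([1, 0, 1])  # nonzero entries count as True
    x = mx.np.array([1, 2, 3])
    y = mx.np.array([10, 20, 30])

    # With 'where' registered in the dispatch list, this official-NumPy call
    # is expected to forward to mxnet.numpy.where and return an MXNet ndarray.
    out = onp.where(cond, x, y)
    print(type(out), out)
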
154 changes: 0 additions & 154 deletions src/operator/numpy/np_matrix_op-inl.h
@@ -49,16 +49,6 @@ struct NumpyTransposeParam : public dmlc::Parameter<NumpyTransposeParam> {
}
};


struct NumpyDiagflatParam : public dmlc::Parameter<NumpyDiagflatParam> {
int k;
DMLC_DECLARE_PARAMETER(NumpyDiagflatParam) {
DMLC_DECLARE_FIELD(k).set_default(0).describe("Diagonal in question. The default is 0. "
"Use k>0 for diagonals above the main diagonal, "
"and k<0 for diagonals below the main diagonal. ");
}
};

struct NumpyVstackParam : public dmlc::Parameter<NumpyVstackParam> {
int num_args;
DMLC_DECLARE_PARAMETER(NumpyVstackParam) {
@@ -147,150 +137,6 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
}
}

inline mxnet::TShape NumpyDiagflatShapeImpl(const mxnet::TShape& ishape, const int k)
{
if (ishape.ndim() == 1) {
auto s = ishape[0] + std::abs(k);
return mxnet::TShape({s, s});
}

if (ishape.ndim() >=2 ){
auto s = 1;
for(int i = 0; i < ishape.ndim(); i++){
if(ishape[i] >= 2){
s = s * ishape[i];
}
}
s = s + std::abs(k);
return mxnet::TShape({s,s});
}
return mxnet::TShape({-1,-1});
}

inline bool NumpyDiagflatOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const mxnet::TShape& ishape = (*in_attrs)[0];
if (!mxnet::ndim_is_known(ishape)) {
return false;
}
const NumpyDiagflatParam& param = nnvm::get<NumpyDiagflatParam>(attrs.parsed);

mxnet::TShape oshape = NumpyDiagflatShapeImpl(ishape,
param.k);

if (shape_is_none(oshape)) {
LOG(FATAL) << "Diagonal does not exist.";
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);

return shape_is_known(out_attrs->at(0));
}

inline bool NumpyDiagflatOpType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
return (*out_attrs)[0] != -1;
}

template<int req, bool back>
struct diagflat_gen {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i,
DType* out,
const DType* a,
mshadow::Shape<2> oshape,
int k){
using namespace mxnet_op;
auto j = unravel(i,oshape);
if (j[1] == j[0] + k){
auto l = j[0] < j[1] ? j[0] : j[1];
if (back){
KERNEL_ASSIGN(out[l],req,a[i]);
}else{
KERNEL_ASSIGN(out[i],req,a[l]);
}
}else if(!back){
KERNEL_ASSIGN(out[i],req,static_cast<DType>(0));
}
}
};

template<typename xpu, bool back>
void NumpyDiagflatOpImpl(const TBlob& in_data,
const TBlob& out_data,
const mxnet::TShape& ishape,
const mxnet::TShape& oshape,
index_t dsize,
const NumpyDiagflatParam& param,
mxnet_op::Stream<xpu> *s,
const std::vector<OpReqType>& req) {

using namespace mxnet_op;
using namespace mshadow;
MSHADOW_TYPE_SWITCH(out_data.type_flag_,DType,{
MXNET_ASSIGN_REQ_SWITCH(req[0],req_type,{
Kernel<diagflat_gen<req_type,back>, xpu>::Launch(s,
dsize,
out_data.dptr<DType>(),
in_data.dptr<DType>(),
Shape2(oshape[0],oshape[1]),
param.k);
});
});

}

template<typename xpu>
void NumpyDiagflatOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs)
{
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
CHECK_EQ(req[0], kWriteTo);
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
const mxnet::TShape& ishape = inputs[0].shape_;
const mxnet::TShape& oshape = outputs[0].shape_;
const NumpyDiagflatParam& param = nnvm::get<NumpyDiagflatParam>(attrs.parsed);
NumpyDiagflatOpImpl<xpu, false>(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req);
}

template<typename xpu>
void NumpyDiagflatOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
Stream<xpu> *s = ctx.get_stream<xpu>();

const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
const mxnet::TShape& ishape = inputs[0].shape_;
const mxnet::TShape& oshape = outputs[0].shape_;
const NumpyDiagflatParam& param = nnvm::get<NumpyDiagflatParam>(attrs.parsed);

NumpyDiagflatOpImpl<xpu, true>(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req);
}

template<typename xpu>
void NumpyColumnStackForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
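For illustration only, here is a rough pure-NumPy sketch of the mapping the removed diagflat_gen kernel performs (the duplicate is deleted here, but the same definitions remain elsewhere in the codebase). The helper name diagflat_forward is hypothetical, and the result matches np.diagflat for 1-D inputs:

    import numpy as np

    def diagflat_forward(a_flat, k):
        # Square output of side n + |k| for a flattened input of length n.
        n = a_flat.size + abs(k)
        out = np.zeros((n, n), dtype=a_flat.dtype)
        for i in range(n * n):
            r, c = divmod(i, n)               # unravel the flat output index
            if c == r + k:                    # on the k-th diagonal
                out[r, c] = a_flat[min(r, c)] # min(r, c) indexes the flattened input
        return out

    a = np.array([1, 2, 3])
    print(diagflat_forward(a, 1))
    print(np.allclose(diagflat_forward(a, 1), np.diagflat(a, 1)))  # True
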
25 changes: 1 addition & 24 deletions src/operator/numpy/np_matrix_op.cc
@@ -1274,29 +1274,6 @@ inline bool HSplitOpShape(const nnvm::NodeAttrs& attrs,
return SplitOpShapeImpl(attrs, in_attrs, out_attrs, real_axis);
}

DMLC_REGISTER_PARAMETER(NumpyDiagflatParam);
NNVM_REGISTER_OP(_npi_diagflat)
.set_attr_parser(ParamParser<NumpyDiagflatParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<mxnet::FInferShape>("FInferShape", NumpyDiagflatOpShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyDiagflatOpType)
.set_attr<FCompute>("FCompute<cpu>",NumpyDiagflatOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",ElemwiseGradUseNone{"_backward_npi_diagflat"})
.add_argument("data","NDArray-or-Symbol","Input ndarray")
.add_arguments(NumpyDiagflatParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_npi_diagflat)
.set_attr_parser(ParamParser<NumpyDiagflatParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward",true)
.set_attr<FCompute>("FCompute<cpu>",NumpyDiagflatOpBackward<cpu>);

NNVM_REGISTER_OP(_npi_hsplit)
.set_attr_parser(ParamParser<SplitParam>)
.set_num_inputs(1)
@@ -1358,7 +1335,7 @@ NNVM_REGISTER_OP(_npi_diagflat)
return std::vector<std::string>{"data"};
})
.set_attr<mxnet::FInferShape>("FInferShape", NumpyDiagflatOpShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyDiagOpType)
.set_attr<nnvm::FInferType>("FInferType", NumpyDiagflatOpType)
.set_attr<FCompute>("FCompute<cpu>",NumpyDiagflatOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",ElemwiseGradUseNone{"_backward_npi_diagflat"})
.add_argument("data","NDArray-or-Symbol","Input ndarray")
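The k offset accepted by _npi_diagflat follows the NumPy diagflat convention (k > 0 selects diagonals above the main diagonal, k < 0 below); a plain-NumPy illustration of that convention, not an MXNet call:

    import numpy as np

    v = [1, 2, 3]
    print(np.diagflat(v, k=1))    # 4x4 matrix, v on the diagonal above the main one
    print(np.diagflat(v, k=-1))   # 4x4 matrix, v on the diagonal below the main one
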
1 change: 0 additions & 1 deletion tests/python/unittest/test_numpy_interoperability.py
@@ -1381,7 +1381,6 @@ def _prepare_workloads():
_add_workload_where()
_add_workload_diff()
_add_workload_resize()
_add_workload_diagflat()


_prepare_workloads()
