Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add a zero operator with no default dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
apeforest committed Sep 12, 2018
1 parent e239f15 commit dcc5f78
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {

if (v.empty()) {
nnvm::NodePtr ng = nnvm::Node::Create();
ng->attrs.op = zeros_op;
ng->attrs.name = "zeros";
ng->attrs.op = Op::Get("_zeros_no_default");
ng->attrs.name = "zeros_no_default";
ng->attrs.op->attr_parser(&(ng->attrs));
return nnvm::NodeEntry{ng, 0, 0};
}
Expand Down
3 changes: 1 addition & 2 deletions src/executor/infer_graph_attr_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
dispatch_mode = &dispatch_modes[nid];
if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false;
}
auto finfer = (inode.source->op() == Op::Get("_zeros")) ? fdefault :
finfer_shape.get(inode.source->op(), fdefault);
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
if (!forward_known) {
if (finfer != nullptr) {
// Call inference function of the operator.
Expand Down
12 changes: 12 additions & 0 deletions src/operator/tensor/init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,21 @@ namespace op {

DMLC_REGISTER_PARAMETER(InitOpParam);
DMLC_REGISTER_PARAMETER(InitOpWithScalarParam);
DMLC_REGISTER_PARAMETER(InitOpNoDefaultParam);
DMLC_REGISTER_PARAMETER(RangeParam);
DMLC_REGISTER_PARAMETER(EyeParam);

NNVM_REGISTER_OP(_zeros_no_default)
.describe("fill target with zeros with no default type")
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InitOpNoDefaultParam>)
.set_attr<nnvm::FInferShape>("FInferShape", InitShape<InitOpNoDefaultParam>)
.set_attr<nnvm::FInferType>("FInferType", InitType<InitOpNoDefaultParam>)
.set_attr<FInferStorageType>("FInferStorageType", InitStorageType<InitOpNoDefaultParam, true, true>)
.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 0>)
.set_attr<FComputeEx>("FComputeEx<cpu>", FillComputeZerosEx<cpu>)
.add_arguments(InitOpNoDefaultParam::__FIELDS__());

NNVM_REGISTER_OP(_zeros)
.describe("fill target with zeros")
Expand Down
3 changes: 3 additions & 0 deletions src/operator/tensor/init_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ void FillZerosCsrImpl(mshadow::Stream<mshadow::gpu> *s, const NDArray& dst) {
});
}

NNVM_REGISTER_OP(_zeros_no_default)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>)
.set_attr<FComputeEx>("FComputeEx<gpu>", FillComputeZerosEx<gpu>);

NNVM_REGISTER_OP(_zeros)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>)
Expand Down
18 changes: 18 additions & 0 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,24 @@ struct InitOpParam : public dmlc::Parameter<InitOpParam> {
}
};

struct InitOpNoDefaultParam : public dmlc::Parameter<InitOpNoDefaultParam> {
TShape shape;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(InitOpNoDefaultParam) {
DMLC_DECLARE_FIELD(shape)
.set_default(TShape())
.describe("The shape of the output");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype)
.set_default(-1)
.describe("Target data type.");
}
};

struct EyeParam : public dmlc::Parameter<EyeParam> {
nnvm::dim_t N;
nnvm::dim_t M;
Expand Down

0 comments on commit dcc5f78

Please sign in to comment.