This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1330] Bring nnvm::Tuple to mxnet::Tuple #14270

Merged
9 commits merged on Mar 1, 2019
Changes from 3 commits
22 changes: 11 additions & 11 deletions docs/architecture/overview.md
@@ -301,9 +301,9 @@ The `OperatorProperty` interface consists of:
* **InferShape:**

```c++
-virtual bool InferShape(std::vector<TShape> *in_shape,
-                        std::vector<TShape> *out_shape,
-                        std::vector<TShape> *aux_shape) const = 0;
+virtual bool InferShape(mxnet::ShapeVector *in_shape,
+                        mxnet::ShapeVector *out_shape,
+                        mxnet::ShapeVector *aux_shape) const = 0;
```
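As an editorial aside (not part of this diff), here is a minimal sketch of how a single-input, single-output operator property might implement this method against the new `mxnet::ShapeVector` signature; the checking logic is illustrative only:

```c++
bool InferShape(mxnet::ShapeVector *in_shape,
                mxnet::ShapeVector *out_shape,
                mxnet::ShapeVector *aux_shape) const override {
  CHECK_EQ(in_shape->size(), 1U) << "expects exactly one input";
  const mxnet::TShape &dshape = in_shape->at(0);
  if (dshape.ndim() == 0) return false;  // input shape not known yet
  out_shape->clear();
  out_shape->push_back(dshape);          // output mirrors the input shape
  aux_shape->clear();                    // this operator keeps no aux states
  return true;
}
```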

This interface has two purposes:
@@ -322,9 +322,9 @@ MXNet defines two interfaces to achieve this:

```c++
virtual std::vector<ResourceRequest> ForwardResource(
-    const std::vector<TShape> &in_shape) const;
+    const mxnet::ShapeVector &in_shape) const;
virtual std::vector<ResourceRequest> BackwardResource(
-    const std::vector<TShape> &in_shape) const;
+    const mxnet::ShapeVector &in_shape) const;
```
The `ResourceRequest` structure (in `resource.h`) currently contains only a type flag:
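The flag definition itself falls outside this hunk. As an editorial aside (not part of this diff), an operator property that needs scratch memory for its forward pass might override `ForwardResource` under the new signature roughly as sketched below; `ResourceRequest::kTempSpace` is the type flag assumed from `resource.h`:

```c++
std::vector<ResourceRequest> ForwardResource(
    const mxnet::ShapeVector &in_shape) const override {
  // Ask the engine for one block of temporary workspace.
  return {ResourceRequest(ResourceRequest::kTempSpace)};
}
```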

@@ -473,7 +473,7 @@ To do so, you could define a `ConvolutionParam` structure, as follows:
```c++
#include <dmlc/parameter.h>
struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
-  TShape kernel, stride, pad;
+  mxnet::TShape kernel, stride, pad;
uint32_t num_filter, num_group, workspace;
bool no_bias;
};
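// Editorial aside (not part of this diff): with dmlc::Parameter, the fields
// above are normally registered inside the struct body through the
// DMLC_DECLARE_PARAMETER / DMLC_DECLARE_FIELD macros, roughly:
//
//   DMLC_DECLARE_PARAMETER(ConvolutionParam) {
//     DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (h, w)");
//     DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000)
//       .describe("Number of output filters.");
//     DMLC_DECLARE_FIELD(no_bias).set_default(false)
//       .describe("Whether to disable the bias term.");
//   }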
@@ -582,10 +582,10 @@ must be provided before any calculation occurs.
let's check input data shape consistency and provide output shape.

```cpp
-typedef TShape (*UnaryShapeFunction)(const TShape& src,
+typedef mxnet::TShape (*UnaryShapeFunction)(const mxnet::TShape& src,
                                             const EnvArguments& env);
-typedef TShape (*BinaryShapeFunction)(const TShape& lhs,
-                                      const TShape& rhs,
+typedef mxnet::TShape (*BinaryShapeFunction)(const mxnet::TShape& lhs,
+                                             const mxnet::TShape& rhs,
                                              const EnvArguments& env);
```
You can use `mshadow::TShape` to check input data shape and designate output data shape.
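For contrast with the unary example in the next hunk, here is a minimal sketch (not taken from the diff) of a `BinaryShapeFunction` written against the migrated signature, assuming the two operands must share one shape:

```cpp
inline mxnet::TShape SameShapeBinary_(const mxnet::TShape& lhs,
                                      const mxnet::TShape& rhs,
                                      const EnvArguments& env) {
  CHECK_EQ(lhs, rhs) << "both inputs must have identical shapes";
  return lhs;  // the output shares the common input shape
}
```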
@@ -611,9 +611,9 @@ In our smooth l1 loss example, it's okay to use the default behavior whereby the
Written explicitly, it is:

```cpp
-inline TShape SmoothL1Shape_(const TShape& src,
+inline mxnet::TShape SmoothL1Shape_(const mxnet::TShape& src,
                                     const EnvArguments& env) {
-  return TShape(src);
+  return mxnet::TShape(src);
}
```

8 changes: 4 additions & 4 deletions docs/faq/add_op_in_backend.md
@@ -175,8 +175,8 @@ element-wise multiplication and addition.
For our `quadratic` operator, shape inference possesses quite similar logic.
```cpp
inline bool QuadraticOpShape(const nnvm::NodeAttrs& attrs,
-                             std::vector<TShape>* in_attrs,
-                             std::vector<TShape>* out_attrs) {
+                             mxnet::ShapeVector* in_attrs,
+                             mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

@@ -216,8 +216,8 @@ The function `QuadraticOpShape` posted here is for the purpose of illustration o
```cpp
template<int n_in, int n_out>
inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
-                          std::vector<TShape> *in_attrs,
-                          std::vector<TShape> *out_attrs);
+                          mxnet::ShapeVector *in_attrs,
+                          mxnet::ShapeVector *out_attrs);
```
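As an editorial aside, a one-input, one-output operator can reuse `ElemwiseShape` at registration time roughly as sketched below; the `mxnet::FInferShape` attribute type is assumed to be the post-migration spelling:

```cpp
NNVM_REGISTER_OP(quadratic)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>);
```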

The same logic goes for data type inference. We will leave the analysis of
2 changes: 1 addition & 1 deletion docs/faq/new_op.md
@@ -258,7 +258,7 @@ can add argument descriptions in bulk with `.add_arguments(ActivationParam::__FI

#### FInferShape or TIsBackward (for Backward Only Ops)

-Normally operators need to have `FInferShape` with prototype `bool(const nnvm::NodeAttrs& attrs, std::vector<TShape> *in_attrs, std::vector<TShape> *out_attrs)`. `FInferShape` fills unknown shapes (`shape.ndim() == 0`) in in_attrs/out_attrs based on known shapes in in_attrs/out_attrs. Use `ElemwiseShape<n_in, n_out>` for simple operators with uniform shapes.
+Normally operators need to have `FInferShape` with prototype `bool(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs)`. `FInferShape` fills unknown shapes (`shape.ndim() == 0`) in in_attrs/out_attrs based on known shapes in in_attrs/out_attrs. Use `ElemwiseShape<n_in, n_out>` for simple operators with uniform shapes.

Operators that are only used for a backward pass can instead register `.set_attr<nnvm::TIsBackward>("TIsBackward", true)`
and their shapes will be copied from the corresponding forward operators.
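As an editorial aside, a minimal `FInferShape` implementation against the migrated prototype could look like the sketch below; `SHAPE_ASSIGN_CHECK` is assumed to come from `operator_common.h`:

```cpp
inline bool IdentityShape(const nnvm::NodeAttrs& attrs,
                          mxnet::ShapeVector *in_attrs,
                          mxnet::ShapeVector *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  // Propagate whichever side is already known to the other side.
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
  SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
  return out_attrs->at(0).ndim() != 0;  // known once ndim is non-zero
}
```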
3 changes: 1 addition & 2 deletions include/mxnet/base.h
@@ -36,6 +36,7 @@
#include "nnvm/tuple.h"
#include "nnvm/symbolic.h"
#include "libinfo.h"
#include "tuple.h"


/*!
@@ -95,8 +96,6 @@ typedef mshadow::gpu gpu;
typedef mshadow::index_t index_t;
/*! \brief data type that will be used to store ndarray */
typedef mshadow::default_real_t real_t;
-/*! \brief Shape data structure used to record shape information */
-using TShape = nnvm::TShape;
/*! \brief operator structure from NNVM */
using Op = nnvm::Op;

4 changes: 2 additions & 2 deletions include/mxnet/executor.h
@@ -121,7 +121,7 @@ class Executor {
const bool allow_up_sizing,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
-const std::unordered_map<std::string, TShape>&
+const std::unordered_map<std::string, mxnet::TShape>&
provided_arg_shapes,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grads,
@@ -155,7 +155,7 @@
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
-const std::unordered_map<std::string, TShape>& arg_shape_map,
+const std::unordered_map<std::string, mxnet::TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::unordered_map<std::string, int>& arg_stype_map,
const std::vector<OpReqType>& grad_req_types,