This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

MXNET-1295 Adding integer index support to Sequence* family of operators. #13880

Merged (6 commits) on Jan 27, 2019
37 changes: 24 additions & 13 deletions python/mxnet/test_utils.py
@@ -620,8 +620,11 @@ def _parse_location(sym, location, ctx, dtype=default_dtype()):
*In either case, value of all the arguments must be provided.*
ctx : Context
Device context.
dtype: np.float16 or np.float32 or np.float64
Datatype for mx.nd.array.
dtype: "asnumpy" or np.float16 or np.float32 or np.float64
If dtype is "asnumpy" then the mx.nd.array created will have the same
type as the numpy array from which it is copied.
Otherwise, dtype is the explicit datatype for all mx.nd.array objects
created in this function.

Returns
-------
@@ -643,16 +646,16 @@ def _parse_location(sym, location, ctx, dtype=default_dtype()):
ValueError: Symbol arguments and keys of the given location do not match.
"""
assert isinstance(location, (dict, list, tuple))
assert dtype in (np.float16, np.float32, np.float64)
assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
if isinstance(location, dict):
if set(location.keys()) != set(sym.list_arguments()):
raise ValueError("Symbol arguments and keys of the given location do not match."
"symbol args:%s, location.keys():%s"
% (str(set(sym.list_arguments())), str(set(location.keys()))))
else:
location = {k: v for k, v in zip(sym.list_arguments(), location)}
location = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) if isinstance(v, np.ndarray) \
else v for k, v in location.items()}
location = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \
if isinstance(v, np.ndarray) else v for k, v in location.items()}
return location


@@ -677,8 +680,11 @@ def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()):
*In either case, all aux states of `sym` must be provided.*
ctx : Context
Device context.
dtype: np.float16 or np.float32 or np.float64
Datatype for mx.nd.array.
dtype: "asnumpy" or np.float16 or np.float32 or np.float64
If dtype is "asnumpy" then the mx.nd.array created will have the same
type as the numpy array from which it is copied.
Otherwise, dtype is the explicit datatype for all mx.nd.array objects
created in this function.

Returns
-------
@@ -702,7 +708,7 @@ def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()):
>>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states}, None)
ValueError: Symbol aux_states names and given aux_states do not match.
"""
assert dtype in (np.float16, np.float32, np.float64)
assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
if aux_states is not None:
if isinstance(aux_states, dict):
if set(aux_states.keys()) != set(sym.list_auxiliary_states()):
Expand All @@ -713,7 +719,8 @@ def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()):
elif isinstance(aux_states, (list, tuple)):
aux_names = sym.list_auxiliary_states()
aux_states = {k:v for k, v in zip(aux_names, aux_states)}
aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) for k, v in aux_states.items()}
aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \
for k, v in aux_states.items()}
return aux_states


@@ -962,8 +969,11 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,
Contains the mapping between names of auxiliary states and their values.
ctx : Context, optional
running context
dtype: np.float16 or np.float32 or np.float64
Datatype for mx.nd.array.
dtype: "asnumpy" or np.float16 or np.float32 or np.float64
If dtype is "asnumpy" then the mx.nd.array created will have the same
type as the numpy array from which it is copied.
Otherwise, dtype is the explicit datatype for all mx.nd.array objects
created in this function.

equal_nan: Boolean
if True, `nan` is a valid value for checking equivalency (ie `nan` == `nan`)
@@ -979,7 +989,7 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,
>>> ret_expected = np.array([[19, 22], [43, 50]])
>>> check_symbolic_forward(sym_dot, [mat1, mat2], [ret_expected])
"""
assert dtype in (np.float16, np.float32, np.float64)
assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
if ctx is None:
ctx = default_context()

@@ -988,7 +998,8 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,
dtype=dtype)
if isinstance(expected, dict):
expected = [expected[k] for k in sym.list_outputs()]
args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx, dtype=dtype) for k, v in location.items()}
args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \
for k, v in location.items()}

executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states)
for g in executor.grad_arrays:
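Note (not part of the diff): a minimal sketch of how the new dtype="asnumpy" option to these helpers could be exercised, assuming the stock mx.sym.SequenceLast API. With "asnumpy", the float32 data and the int32 sequence_length each keep their own numpy dtype when the helpers build mx.nd.array inputs; shapes and values below are illustrative only.

```python
# Illustrative sketch only (not from this PR): exercising dtype="asnumpy" so the
# int32 sequence_length is not cast to the float dtype of the data.
import numpy as np
import mxnet as mx
from mxnet.test_utils import check_symbolic_forward

data = mx.sym.Variable('data')
seq_len = mx.sym.Variable('sequence_length')
sym = mx.sym.SequenceLast(data=data, sequence_length=seq_len, use_sequence_length=True)

x = np.random.rand(3, 2, 4).astype(np.float32)   # (max_seq_len, batch, feature)
lengths = np.array([3, 2], dtype=np.int32)        # valid length per batch element
expected = [np.stack([x[2, 0], x[1, 1]])]         # last valid step of each batch element

# dtype="asnumpy": data stays float32, sequence_length stays int32.
check_symbolic_forward(sym, {'data': x, 'sequence_length': lengths}, expected,
                       dtype="asnumpy")
```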
30 changes: 14 additions & 16 deletions src/operator/sequence_last-inl.h
@@ -65,9 +65,9 @@ struct SequenceLastParam : public dmlc::Parameter<SequenceLastParam> {

template <int req>
struct SequenceLastKernel {
template <typename DType>
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
const DType *idx, int offset1, int offset2,
const IType *idx, int offset1, int offset2,
mshadow::Shape<2> oshape) {
const auto opos = mxnet_op::unravel(i, oshape);
const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
@@ -77,9 +77,9 @@ struct SequenceLastKernel {
};

struct SequenceLastGradKernel {
template <typename DType>
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType *in_grad, const DType *out_grad,
const DType *idx, int offset1, int offset2,
const IType *idx, int offset1, int offset2,
mshadow::Shape<2> oshape) {
const auto opos = mxnet_op::unravel(i, oshape);
const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
@@ -88,14 +88,14 @@ struct SequenceLastGradKernel {
}
};

template <typename xpu, typename DType>
template <typename xpu, typename DType, typename IType>
class SequenceLastOp : public Operator {
public:
explicit SequenceLastOp(SequenceLastParam p) { this->param_ = p; }

void sequence_last(const mshadow::Tensor<xpu, 3, DType> &data,
const mshadow::Tensor<xpu, 2, DType> &out,
const mshadow::Tensor<xpu, 1, DType> &indices,
const mshadow::Tensor<xpu, 1, IType> &indices,
const OpReqType req, mshadow::Stream<xpu> *const s) {
using namespace mshadow;
using namespace mshadow::expr;
@@ -115,7 +115,7 @@ class SequenceLastOp : public Operator {

void sequence_last_grad(const mshadow::Tensor<xpu, 3, DType> &in_grad,
const mshadow::Tensor<xpu, 2, DType> &out_grad,
const mshadow::Tensor<xpu, 1, DType> &indices,
const mshadow::Tensor<xpu, 1, IType> &indices,
mshadow::Stream<xpu> *const s) {
using namespace mshadow;
using namespace mshadow::expr;
@@ -163,11 +163,11 @@ class SequenceLastOp : public Operator {
Tensor<xpu, 2, DType> out =
out_data[seq_last::kOut].get_with_shape<xpu, 2, DType>(
Shape2(batch, rest_size), s);
Tensor<xpu, 1, DType> indices =
Tensor<xpu, 1, IType> indices =
param_.use_sequence_length
? in_data[seq_last::kSequenceLength].get<xpu, 1, DType>(s)
? in_data[seq_last::kSequenceLength].get<xpu, 1, IType>(s)
: ctx.requested[seq_last::kTempSpace]
.get_space_typed<xpu, 1, DType>(Shape1(batch), s);
.get_space_typed<xpu, 1, IType>(Shape1(batch), s);
if (!param_.use_sequence_length) indices = max_seq_len;

sequence_last(data, out, indices, req[seq_last::kOut], s);
@@ -206,11 +206,11 @@ class SequenceLastOp : public Operator {
Tensor<xpu, 2, DType> output_grad =
out_grad[seq_last::kOut].get_with_shape<xpu, 2, DType>(
Shape2(batch, rest_size), s);
Tensor<xpu, 1, DType> indices =
Tensor<xpu, 1, IType> indices =
param_.use_sequence_length
? in_data[seq_last::kSequenceLength].get<xpu, 1, DType>(s)
? in_data[seq_last::kSequenceLength].get<xpu, 1, IType>(s)
: ctx.requested[seq_last::kTempSpace]
.get_space_typed<xpu, 1, DType>(Shape1(batch), s);
.get_space_typed<xpu, 1, IType>(Shape1(batch), s);

if (req[seq_last::kData] == kWriteTo) data_grad = 0.0f;
sequence_last_grad(data_grad, output_grad, indices, s);
Expand All @@ -221,7 +221,7 @@ class SequenceLastOp : public Operator {
}; // class SequenceLastOp

template <typename xpu>
Operator *CreateOp(SequenceLastParam param, int dtype);
Operator *CreateOp(SequenceLastParam param, int dtype, int itype);

#if DMLC_USE_CXX11
class SequenceLastProp : public OperatorProperty {
@@ -281,8 +281,6 @@ class SequenceLastProp : public OperatorProperty {
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
}
}
out_type->clear();
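Note (not part of the diff): to show what the extra IType template parameter buys at the Python level, a small illustrative example, assuming the standard mx.nd.SequenceLast API, where sequence_length is an int32 array rather than being forced to the float dtype of the data.

```python
# Illustrative only: sequence_length as an int32 NDArray, data as float32.
import numpy as np
import mxnet as mx

x = mx.nd.array(np.arange(24).reshape(4, 3, 2), dtype=np.float32)  # (seq, batch, feat)
lengths = mx.nd.array([4, 2, 3], dtype=np.int32)                    # one length per batch

out = mx.nd.SequenceLast(data=x, sequence_length=lengths, use_sequence_length=True)
# out[b] == x[lengths[b] - 1, b]: the last valid time step of each batch element.
print(out.asnumpy())   # rows taken from time steps 3, 1 and 2 respectively
```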
16 changes: 12 additions & 4 deletions src/operator/sequence_last.cc
@@ -28,18 +28,26 @@
namespace mxnet {
namespace op {
template <>
Operator *CreateOp<cpu>(SequenceLastParam param, int dtype) {
Operator *CreateOp<cpu>(SequenceLastParam param, int dtype, int itype) {
Operator *op = nullptr;
MSHADOW_TYPE_SWITCH(dtype, DType,
{ op = new SequenceLastOp<cpu, DType>(param); })
MSHADOW_TYPE_SWITCH(dtype, DType, {
MSHADOW_TYPE_SWITCH(itype, IType, {
op = new SequenceLastOp<cpu, DType, IType>(param);
});
});
return op;
}

// DO_BIND_DISPATCH comes from operator_common.h
Operator *SequenceLastProp::CreateOperatorEx(Context ctx,
std::vector<TShape> *in_shape,
std::vector<int> *in_type) const {
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
if (in_type->size() >= 2 && (*in_type)[1] != -1) {
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]);
}

// sequence_length not passed in, so fall back to using input array dtype for second argument
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]);
}

DMLC_REGISTER_PARAMETER(SequenceLastParam);
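Note (not part of the diff): the fallback branch above applies when use_sequence_length=False, because no sequence_length input (and hence no second in_type entry) is present, so itype simply mirrors the data dtype. A small illustrative example of that path:

```python
# Illustrative only: with use_sequence_length=False there is no sequence_length
# input, so dispatch falls back to itype == dtype and the full sequence length
# is used for every batch element.
import mxnet as mx

x = mx.nd.arange(24).reshape((4, 3, 2))            # (seq, batch, feat)
out = mx.nd.SequenceLast(data=x, use_sequence_length=False)
# Equivalent to x[-1]: the final time step of each batch element.
print(out.asnumpy())
```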
9 changes: 6 additions & 3 deletions src/operator/sequence_last.cu
@@ -28,10 +28,13 @@

namespace mxnet {
namespace op {
template <> Operator *CreateOp<gpu>(SequenceLastParam param, int dtype) {
template <> Operator *CreateOp<gpu>(SequenceLastParam param, int dtype, int itype) {
Operator *op = NULL;
MSHADOW_TYPE_SWITCH(dtype, DType,
{ op = new SequenceLastOp<gpu, DType>(param); })
MSHADOW_TYPE_SWITCH(dtype, DType, {
MSHADOW_TYPE_SWITCH(itype, IType, {
op = new SequenceLastOp<gpu, DType, IType>(param);
});
});
return op;
}

24 changes: 11 additions & 13 deletions src/operator/sequence_mask-inl.h
@@ -68,8 +68,8 @@ struct SequenceMaskParam : public dmlc::Parameter<SequenceMaskParam> {
// (seqlen, batch, rest) case
template <int req>
struct SequenceMask0Kernel {
template <typename DType>
MSHADOW_XINLINE static void Map(int b, DType *in, const DType *idx,
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
index_t max_s_len, index_t batch_size,
index_t restsize, DType value) {
const index_t seqpos = static_cast<int>(idx[b]);
@@ -86,8 +86,8 @@ struct SequenceMask0Kernel {
// (batch, seqlen, rest) case
template <int req>
struct SequenceMask1Kernel {
template <typename DType>
MSHADOW_XINLINE static void Map(int b, DType *in, const DType *idx,
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
index_t max_s_len, index_t batch_size,
index_t restsize, DType value) {
const index_t seqpos = static_cast<int>(idx[b]);
@@ -101,13 +101,13 @@ struct SequenceMask1Kernel {
}
};

template <typename xpu, typename DType>
template <typename xpu, typename DType, typename IType>
class SequenceMaskOp : public Operator {
public:
explicit SequenceMaskOp(SequenceMaskParam p) { this->param_ = p; }

void sequence_mask(const mshadow::Tensor<xpu, 3, DType> &data,
const mshadow::Tensor<xpu, 1, DType> &indices,
const mshadow::Tensor<xpu, 1, IType> &indices,
const OpReqType req, mshadow::Stream<xpu> *const s,
DType val) {
using namespace mshadow;
@@ -153,8 +153,8 @@ class SequenceMaskOp : public Operator {
// Actual implementation of masking
Assign(out, req[seq_mask::kOut], F<mshadow_op::identity>(data));
if (param_.use_sequence_length) {
Tensor<xpu, 1, DType> indices =
in_data[seq_mask::kSequenceLength].get<xpu, 1, DType>(s);
Tensor<xpu, 1, IType> indices =
in_data[seq_mask::kSequenceLength].get<xpu, 1, IType>(s);
sequence_mask(out, indices, req[seq_mask::kOut], s,
static_cast<DType>(param_.value));
}
@@ -190,8 +190,8 @@ class SequenceMaskOp : public Operator {
if (!param_.use_sequence_length) {
Assign(data_g, req[seq_mask::kData], F<mshadow_op::identity>(out_g));
} else {
Tensor<xpu, 1, DType> indices =
in_data[seq_mask::kSequenceLength].get<xpu, 1, DType>(s);
Tensor<xpu, 1, IType> indices =
in_data[seq_mask::kSequenceLength].get<xpu, 1, IType>(s);
if (req[seq_mask::kData] == kAddTo) {
Tensor<xpu, 3, DType> out_g_temp =
ctx.requested[seq_mask::kTempSpace].get_space_typed<xpu, 3, DType>(
@@ -212,7 +212,7 @@ }
}; // class SequenceMaskOp

template <typename xpu>
Operator *CreateOp(SequenceMaskParam param, int dtype);
Operator *CreateOp(SequenceMaskParam param, int dtype, int itype);

#if DMLC_USE_CXX11
class SequenceMaskProp : public OperatorProperty {
@@ -270,8 +270,6 @@ class SequenceMaskProp : public OperatorProperty {
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
}
}
out_type->clear();
16 changes: 12 additions & 4 deletions src/operator/sequence_mask.cc
@@ -28,18 +28,26 @@
namespace mxnet {
namespace op {
template <>
Operator *CreateOp<cpu>(SequenceMaskParam param, int dtype) {
Operator *CreateOp<cpu>(SequenceMaskParam param, int dtype, int itype) {
Operator *op = nullptr;
MSHADOW_TYPE_SWITCH(dtype, DType,
{ op = new SequenceMaskOp<cpu, DType>(param); })
MSHADOW_TYPE_SWITCH(dtype, DType, {
MSHADOW_TYPE_SWITCH(itype, IType, {
op = new SequenceMaskOp<cpu, DType, IType>(param);
});
});
return op;
}

// DO_BIND_DISPATCH comes from operator_common.h
Operator *SequenceMaskProp::CreateOperatorEx(Context ctx,
std::vector<TShape> *in_shape,
std::vector<int> *in_type) const {
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
if (in_type->size() >= 2 && (*in_type)[1] != -1) {
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]);
}

// sequence_length not passed in, so fall back to using input array dtype for second argument
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]);
}

DMLC_REGISTER_PARAMETER(SequenceMaskParam);
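Note (not part of the diff): as with SequenceLast, the practical effect is that the per-batch lengths for SequenceMask can now be an integer array. An illustrative example, assuming the stock mx.nd.SequenceMask API:

```python
# Illustrative only: mask out time steps beyond each int32 length with `value`.
import numpy as np
import mxnet as mx

x = mx.nd.ones((4, 2, 3))                       # (seq, batch, feat)
lengths = mx.nd.array([2, 3], dtype=np.int32)   # valid steps per batch element

masked = mx.nd.SequenceMask(data=x, sequence_length=lengths,
                            use_sequence_length=True, value=-1.0)
# masked[t, b] == -1.0 wherever t >= lengths[b]; earlier steps are left unchanged.
print(masked.asnumpy())
```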
9 changes: 6 additions & 3 deletions src/operator/sequence_mask.cu
@@ -29,10 +29,13 @@
namespace mxnet {
namespace op {

template <> Operator *CreateOp<gpu>(SequenceMaskParam param, int dtype) {
template <> Operator *CreateOp<gpu>(SequenceMaskParam param, int dtype, int itype) {
Operator *op = NULL;
MSHADOW_TYPE_SWITCH(dtype, DType,
{ op = new SequenceMaskOp<gpu, DType>(param); })
MSHADOW_TYPE_SWITCH(dtype, DType, {
MSHADOW_TYPE_SWITCH(itype, IType, {
op = new SequenceMaskOp<gpu, DType, IType>(param);
});
});
return op;
}
