From 17632558fe843019e1b549386c14ce72e164097c Mon Sep 17 00:00:00 2001
From: stephenrawls <10453511+stephenrawls@users.noreply.github.com>
Date: Sat, 26 Jan 2019 22:04:09 -0800
Subject: [PATCH] MXNET-1295 Adding integer index support to Sequence* family of operators. (#13880)

* Adding integer index support to Sequence* family of operators.

Adding the ability to use int32 arrays, or any castable-to-int type, as the
sequence_length array to SequenceMask, SequenceLast, and SequenceReverse.
Previously these operators all required sequence_length to be the same data
type as the input array.

See the MXNet Jira ticket here:
https://issues.apache.org/jira/browse/MXNET-1295

See also the GitHub issues here:
https://github.com/apache/incubator-mxnet/issues/12649
https://github.com/dmlc/gluon-nlp/issues/346

* Adding explicit braces to an if statement to fix g++ warning

* fixing sequence_mask.cu by adding IType to template

* Fixing whitespace errors reported by linter

* Adding unit tests

* Fixing length of lines to pass linter
---
 python/mxnet/test_utils.py             | 37 +++++++++------
 src/operator/sequence_last-inl.h       | 30 ++++++-------
 src/operator/sequence_last.cc          | 16 +++++--
 src/operator/sequence_last.cu          |  9 ++--
 src/operator/sequence_mask-inl.h       | 24 +++++-----
 src/operator/sequence_mask.cc          | 16 +++++--
 src/operator/sequence_mask.cu          |  9 ++--
 src/operator/sequence_reverse-inl.h    | 20 ++++-----
 src/operator/sequence_reverse.cc       | 17 +++++--
 src/operator/sequence_reverse.cu       |  8 ++--
 tests/python/unittest/test_operator.py | 62 +++++++++++++-------------
 11 files changed, 144 insertions(+), 104 deletions(-)

diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 0a4d17dc2668..4138e4d2d755 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -620,8 +620,11 @@ def _parse_location(sym, location, ctx, dtype=default_dtype()): *In either case, value of all the arguments must be provided.* ctx : Context Device context. - dtype: np.float16 or np.float32 or np.float64 - Datatype for mx.nd.array. + dtype: "asnumpy" or np.float16 or np.float32 or np.float64 + If dtype is "asnumpy" then the mx.nd.array created will have the same + type as the numpy array from which it is copied. + Otherwise, dtype is the explicit datatype for all mx.nd.array objects + created in this function. Returns ------- @@ -643,7 +646,7 @@ def _parse_location(sym, location, ctx, dtype=default_dtype()): ValueError: Symbol arguments and keys of the given location do not match. """ assert isinstance(location, (dict, list, tuple)) - assert dtype in (np.float16, np.float32, np.float64) + assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64) if isinstance(location, dict): if set(location.keys()) != set(sym.list_arguments()): raise ValueError("Symbol arguments and keys of the given location do not match." % (str(set(sym.list_arguments())), str(set(location.keys())))) else: location = {k: v for k, v in zip(sym.list_arguments(), location)} - location = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) if isinstance(v, np.ndarray) \ else v for k, v in location.items()} + location = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \ if isinstance(v, np.ndarray) else v for k, v in location.items()} return location @@ -677,8 +680,11 @@ def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()): *In either case, all aux states of `sym` must be provided.* ctx : Context Device context.
- dtype: np.float16 or np.float32 or np.float64 - Datatype for mx.nd.array. + dtype: "asnumpy" or np.float16 or np.float32 or np.float64 + If dtype is "asnumpy" then the mx.nd.array created will have the same + type as the numpy array from which it is copied. + Otherwise, dtype is the explicit datatype for all mx.nd.array objects + created in this function. Returns ------- @@ -702,7 +708,7 @@ def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()): >>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states}, None) ValueError: Symbol aux_states names and given aux_states do not match. """ - assert dtype in (np.float16, np.float32, np.float64) + assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64) if aux_states is not None: if isinstance(aux_states, dict): if set(aux_states.keys()) != set(sym.list_auxiliary_states()): @@ -713,7 +719,8 @@ def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()): elif isinstance(aux_states, (list, tuple)): aux_names = sym.list_auxiliary_states() aux_states = {k:v for k, v in zip(aux_names, aux_states)} - aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) for k, v in aux_states.items()} + aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \ + for k, v in aux_states.items()} return aux_states @@ -962,8 +969,11 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, Contains the mapping between names of auxiliary states and their values. ctx : Context, optional running context - dtype: np.float16 or np.float32 or np.float64 - Datatype for mx.nd.array. + dtype: "asnumpy" or np.float16 or np.float32 or np.float64 + If dtype is "asnumpy" then the mx.nd.array created will have the same + type as the numpy array from which it is copied. + Otherwise, dtype is the explicit datatype for all mx.nd.array objects + created in this function.
equal_nan: Boolean if True, `nan` is a valid value for checking equivalency (ie `nan` == `nan`) @@ -979,7 +989,7 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, >>> ret_expected = np.array([[19, 22], [43, 50]]) >>> check_symbolic_forward(sym_dot, [mat1, mat2], [ret_expected]) """ - assert dtype in (np.float16, np.float32, np.float64) + assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64) if ctx is None: ctx = default_context() @@ -988,7 +998,8 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, dtype=dtype) if isinstance(expected, dict): expected = [expected[k] for k in sym.list_outputs()] - args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx, dtype=dtype) for k, v in location.items()} + args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \ + for k, v in location.items()} executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states) for g in executor.grad_arrays: diff --git a/src/operator/sequence_last-inl.h b/src/operator/sequence_last-inl.h index 1a59473cfc3a..61506c2af3de 100644 --- a/src/operator/sequence_last-inl.h +++ b/src/operator/sequence_last-inl.h @@ -65,9 +65,9 @@ struct SequenceLastParam : public dmlc::Parameter { template struct SequenceLastKernel { - template + template MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in, - const DType *idx, int offset1, int offset2, + const IType *idx, int offset1, int offset2, mshadow::Shape<2> oshape) { const auto opos = mxnet_op::unravel(i, oshape); const int seqpos = static_cast(idx[opos[0]]) - 1; @@ -77,9 +77,9 @@ struct SequenceLastKernel { }; struct SequenceLastGradKernel { - template + template MSHADOW_XINLINE static void Map(int i, DType *in_grad, const DType *out_grad, - const DType *idx, int offset1, int offset2, + const IType *idx, int offset1, int offset2, mshadow::Shape<2> oshape) { const auto opos = mxnet_op::unravel(i, oshape); const int seqpos = static_cast(idx[opos[0]]) - 1; @@ -88,14 +88,14 @@ struct SequenceLastGradKernel { } }; -template +template class SequenceLastOp : public Operator { public: explicit SequenceLastOp(SequenceLastParam p) { this->param_ = p; } void sequence_last(const mshadow::Tensor &data, const mshadow::Tensor &out, - const mshadow::Tensor &indices, + const mshadow::Tensor &indices, const OpReqType req, mshadow::Stream *const s) { using namespace mshadow; using namespace mshadow::expr; @@ -115,7 +115,7 @@ class SequenceLastOp : public Operator { void sequence_last_grad(const mshadow::Tensor &in_grad, const mshadow::Tensor &out_grad, - const mshadow::Tensor &indices, + const mshadow::Tensor &indices, mshadow::Stream *const s) { using namespace mshadow; using namespace mshadow::expr; @@ -163,11 +163,11 @@ class SequenceLastOp : public Operator { Tensor out = out_data[seq_last::kOut].get_with_shape( Shape2(batch, rest_size), s); - Tensor indices = + Tensor indices = param_.use_sequence_length - ? in_data[seq_last::kSequenceLength].get(s) + ? in_data[seq_last::kSequenceLength].get(s) : ctx.requested[seq_last::kTempSpace] - .get_space_typed(Shape1(batch), s); + .get_space_typed(Shape1(batch), s); if (!param_.use_sequence_length) indices = max_seq_len; sequence_last(data, out, indices, req[seq_last::kOut], s); @@ -206,11 +206,11 @@ class SequenceLastOp : public Operator { Tensor output_grad = out_grad[seq_last::kOut].get_with_shape( Shape2(batch, rest_size), s); - Tensor indices = + Tensor indices = param_.use_sequence_length - ? 
in_data[seq_last::kSequenceLength].get(s) + ? in_data[seq_last::kSequenceLength].get(s) : ctx.requested[seq_last::kTempSpace] - .get_space_typed(Shape1(batch), s); + .get_space_typed(Shape1(batch), s); if (req[seq_last::kData] == kWriteTo) data_grad = 0.0f; sequence_last_grad(data_grad, output_grad, indices, s); @@ -221,7 +221,7 @@ class SequenceLastOp : public Operator { }; // class SequenceLastOp template -Operator *CreateOp(SequenceLastParam param, int dtype); +Operator *CreateOp(SequenceLastParam param, int dtype, int itype); #if DMLC_USE_CXX11 class SequenceLastProp : public OperatorProperty { @@ -281,8 +281,6 @@ class SequenceLastProp : public OperatorProperty { for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; - } else { - UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); } } out_type->clear(); diff --git a/src/operator/sequence_last.cc b/src/operator/sequence_last.cc index 345524b38134..f2388a8efbf3 100644 --- a/src/operator/sequence_last.cc +++ b/src/operator/sequence_last.cc @@ -28,10 +28,13 @@ namespace mxnet { namespace op { template <> -Operator *CreateOp(SequenceLastParam param, int dtype) { +Operator *CreateOp(SequenceLastParam param, int dtype, int itype) { Operator *op = nullptr; - MSHADOW_TYPE_SWITCH(dtype, DType, - { op = new SequenceLastOp(param); }) + MSHADOW_TYPE_SWITCH(dtype, DType, { + MSHADOW_TYPE_SWITCH(itype, IType, { + op = new SequenceLastOp(param); + }); + }); return op; } @@ -39,7 +42,12 @@ Operator *CreateOp(SequenceLastParam param, int dtype) { Operator *SequenceLastProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); + if (in_type->size() >= 2 && (*in_type)[1] != -1) { + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]); + } + + // sequence_length not passed in, so fall back to using input array dtype for second argument + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]); } DMLC_REGISTER_PARAMETER(SequenceLastParam); diff --git a/src/operator/sequence_last.cu b/src/operator/sequence_last.cu index dfc4e59d1b85..fb5ae8471c48 100644 --- a/src/operator/sequence_last.cu +++ b/src/operator/sequence_last.cu @@ -28,10 +28,13 @@ namespace mxnet { namespace op { -template <> Operator *CreateOp(SequenceLastParam param, int dtype) { +template <> Operator *CreateOp(SequenceLastParam param, int dtype, int itype) { Operator *op = NULL; - MSHADOW_TYPE_SWITCH(dtype, DType, - { op = new SequenceLastOp(param); }) + MSHADOW_TYPE_SWITCH(dtype, DType, { + MSHADOW_TYPE_SWITCH(itype, IType, { + op = new SequenceLastOp(param); + }); + }); return op; } diff --git a/src/operator/sequence_mask-inl.h b/src/operator/sequence_mask-inl.h index c93ffb5f17b6..c2584abd4178 100644 --- a/src/operator/sequence_mask-inl.h +++ b/src/operator/sequence_mask-inl.h @@ -68,8 +68,8 @@ struct SequenceMaskParam : public dmlc::Parameter { // (seqlen, batch, rest) case template struct SequenceMask0Kernel { - template - MSHADOW_XINLINE static void Map(int b, DType *in, const DType *idx, + template + MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx, index_t max_s_len, index_t batch_size, index_t restsize, DType value) { const index_t seqpos = static_cast(idx[b]); @@ -86,8 +86,8 @@ struct SequenceMask0Kernel { // (batch, seqlen, rest) case template struct SequenceMask1Kernel { - template - MSHADOW_XINLINE static void Map(int b, DType *in, const DType *idx, + template + MSHADOW_XINLINE static void Map(int b, DType 
*in, const IType *idx, index_t max_s_len, index_t batch_size, index_t restsize, DType value) { const index_t seqpos = static_cast(idx[b]); @@ -101,13 +101,13 @@ struct SequenceMask1Kernel { } }; -template +template class SequenceMaskOp : public Operator { public: explicit SequenceMaskOp(SequenceMaskParam p) { this->param_ = p; } void sequence_mask(const mshadow::Tensor &data, - const mshadow::Tensor &indices, + const mshadow::Tensor &indices, const OpReqType req, mshadow::Stream *const s, DType val) { using namespace mshadow; @@ -153,8 +153,8 @@ class SequenceMaskOp : public Operator { // Actual implementation of masking Assign(out, req[seq_mask::kOut], F(data)); if (param_.use_sequence_length) { - Tensor indices = - in_data[seq_mask::kSequenceLength].get(s); + Tensor indices = + in_data[seq_mask::kSequenceLength].get(s); sequence_mask(out, indices, req[seq_mask::kOut], s, static_cast(param_.value)); } @@ -190,8 +190,8 @@ class SequenceMaskOp : public Operator { if (!param_.use_sequence_length) { Assign(data_g, req[seq_mask::kData], F(out_g)); } else { - Tensor indices = - in_data[seq_mask::kSequenceLength].get(s); + Tensor indices = + in_data[seq_mask::kSequenceLength].get(s); if (req[seq_mask::kData] == kAddTo) { Tensor out_g_temp = ctx.requested[seq_mask::kTempSpace].get_space_typed( @@ -212,7 +212,7 @@ class SequenceMaskOp : public Operator { }; // class SequenceMaskOp template -Operator *CreateOp(SequenceMaskParam param, int dtype); +Operator *CreateOp(SequenceMaskParam param, int dtype, int itype); #if DMLC_USE_CXX11 class SequenceMaskProp : public OperatorProperty { @@ -270,8 +270,6 @@ class SequenceMaskProp : public OperatorProperty { for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; - } else { - UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); } } out_type->clear(); diff --git a/src/operator/sequence_mask.cc b/src/operator/sequence_mask.cc index e02c57bfd917..76e58386b8ad 100644 --- a/src/operator/sequence_mask.cc +++ b/src/operator/sequence_mask.cc @@ -28,10 +28,13 @@ namespace mxnet { namespace op { template <> -Operator *CreateOp(SequenceMaskParam param, int dtype) { +Operator *CreateOp(SequenceMaskParam param, int dtype, int itype) { Operator *op = nullptr; - MSHADOW_TYPE_SWITCH(dtype, DType, - { op = new SequenceMaskOp(param); }) + MSHADOW_TYPE_SWITCH(dtype, DType, { + MSHADOW_TYPE_SWITCH(itype, IType, { + op = new SequenceMaskOp(param); + }); + }); return op; } @@ -39,7 +42,12 @@ Operator *CreateOp(SequenceMaskParam param, int dtype) { Operator *SequenceMaskProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); + if (in_type->size() >= 2 && (*in_type)[1] != -1) { + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]); + } + + // sequence_length not passed in, so fall back to using input array dtype for second argument + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]); } DMLC_REGISTER_PARAMETER(SequenceMaskParam); diff --git a/src/operator/sequence_mask.cu b/src/operator/sequence_mask.cu index 2ca88322428e..cec627c4c697 100644 --- a/src/operator/sequence_mask.cu +++ b/src/operator/sequence_mask.cu @@ -29,10 +29,13 @@ namespace mxnet { namespace op { -template <> Operator *CreateOp(SequenceMaskParam param, int dtype) { +template <> Operator *CreateOp(SequenceMaskParam param, int dtype, int itype) { Operator *op = NULL; - MSHADOW_TYPE_SWITCH(dtype, DType, - { op = new SequenceMaskOp(param); }) + 
MSHADOW_TYPE_SWITCH(dtype, DType, { + MSHADOW_TYPE_SWITCH(itype, IType, { + op = new SequenceMaskOp(param); + }); + }); return op; } diff --git a/src/operator/sequence_reverse-inl.h b/src/operator/sequence_reverse-inl.h index 5c48729e18ff..eb9f71ccce9e 100644 --- a/src/operator/sequence_reverse-inl.h +++ b/src/operator/sequence_reverse-inl.h @@ -65,14 +65,14 @@ struct SequenceReverseParam : public dmlc::Parameter { }; struct ReverseKernel { - template + template MSHADOW_XINLINE static void Map(const int i, DType *const out_data, const DType *const in_data, const OpReqType req, const index_t max_seq_len, const index_t batch_size, const index_t other_dim, const index_t numel, - const DType *const indices) { + const IType *const indices) { for (index_t batch = 0; batch < batch_size; ++batch) { const index_t num_seq = indices ? static_cast(indices[batch]) : max_seq_len; @@ -102,13 +102,13 @@ struct ReverseKernel { } }; -template +template class SequenceReverseOp : public Operator { public: explicit SequenceReverseOp(SequenceReverseParam p) { this->param_ = p; } void sequence_reverse(const mshadow::Tensor &data, const mshadow::Tensor &out, - const OpReqType req, const DType *const indices, + const OpReqType req, const IType *const indices, mshadow::Stream *const s) { using namespace mshadow; using namespace mshadow::expr; @@ -145,9 +145,9 @@ class SequenceReverseOp : public Operator { Tensor out = out_data[seq_reverse::kOut].get_with_shape(s3, s); - const DType *const indices = + const IType *const indices = param_.use_sequence_length - ? in_data[seq_reverse::kSequenceLength].dptr() + ? in_data[seq_reverse::kSequenceLength].dptr() : nullptr; sequence_reverse(data, out, req[seq_reverse::kOut], indices, s); @@ -179,9 +179,9 @@ class SequenceReverseOp : public Operator { Tensor output_grad = out_grad[seq_reverse::kOut].get_with_shape(s3, s); - const DType *const indices = + const IType *const indices = param_.use_sequence_length - ? in_data[seq_reverse::kSequenceLength].dptr() + ? 
in_data[seq_reverse::kSequenceLength].dptr() : nullptr; sequence_reverse(output_grad, data_grad, req[seq_reverse::kData], indices, @@ -193,7 +193,7 @@ class SequenceReverseOp : public Operator { }; // class SequenceReverseOp template -Operator *CreateOp(SequenceReverseParam param, int dtype); +Operator *CreateOp(SequenceReverseParam param, int dtype, int itype); #if DMLC_USE_CXX11 class SequenceReverseProp : public OperatorProperty { @@ -249,8 +249,6 @@ class SequenceReverseProp : public OperatorProperty { for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; - } else { - UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); } } out_type->clear(); diff --git a/src/operator/sequence_reverse.cc b/src/operator/sequence_reverse.cc index 21cab7891101..9225b6b5dae2 100644 --- a/src/operator/sequence_reverse.cc +++ b/src/operator/sequence_reverse.cc @@ -28,10 +28,13 @@ namespace mxnet { namespace op { template <> -Operator *CreateOp(SequenceReverseParam param, int dtype) { +Operator *CreateOp(SequenceReverseParam param, int dtype, int itype) { Operator *op = nullptr; - MSHADOW_TYPE_SWITCH(dtype, DType, - { op = new SequenceReverseOp(param); }) + MSHADOW_TYPE_SWITCH(dtype, DType, { + MSHADOW_TYPE_SWITCH(itype, IType, { + op = new SequenceReverseOp(param); + }); + }); return op; } @@ -39,7 +42,13 @@ Operator *CreateOp(SequenceReverseParam param, int dtype) { Operator *SequenceReverseProp::CreateOperatorEx( Context ctx, std::vector *in_shape, std::vector *in_type) const { - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); + + if (in_type->size() >= 2 && (*in_type)[1] != -1) { + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]); + } + + // sequence_length not passed in, so fall back to using input array dtype for second argument + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]); } DMLC_REGISTER_PARAMETER(SequenceReverseParam); diff --git a/src/operator/sequence_reverse.cu b/src/operator/sequence_reverse.cu index 1edc9c13d493..db5b416b32d7 100644 --- a/src/operator/sequence_reverse.cu +++ b/src/operator/sequence_reverse.cu @@ -28,11 +28,13 @@ namespace mxnet { namespace op { -template <> Operator *CreateOp(SequenceReverseParam param, int dtype) { +template <> Operator *CreateOp(SequenceReverseParam param, int dtype, int itype) { Operator *op = nullptr; MSHADOW_TYPE_SWITCH(dtype, DType, { - op = new SequenceReverseOp(param); - }) + MSHADOW_TYPE_SWITCH(itype, IType, { + op = new SequenceReverseOp(param); + }); + }); return op; } diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 3f34ade448dc..670cc7eb15e0 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3292,36 +3292,38 @@ def check_sequence_func(ftype, mask_value=0, axis=0): L = mx.symbol.Variable('L') # lengths shapes = [(3, 4), (1, 1), (3, 4, 3, 1, 1)] for seqlenQ in [True, False]: - for s in shapes: - x = mx.random.uniform(-1, 1, s, ctx=mx.cpu()).copyto(xpu) - batch = s[1] if (axis == 0) else s[0] - seqlen = s[axis] - l_np = np.random.randint(1, seqlen + 1, batch) - l = mx.nd.array(l_np, ctx=mx.cpu()).copyto(xpu) - if not seqlenQ: - l_np = None - args = {'data':X, 'use_sequence_length':seqlenQ, "axis":axis} - if seqlenQ: - args['sequence_length'] = L - if ftype == "last": - Y = mx.symbol.SequenceLast(**args) - np_out = sequence_last_numpy(x.asnumpy(), l_np, axis) - elif ftype == "mask": - args['value'] = mask_value - Y = mx.symbol.SequenceMask(**args) - np_out = 
sequence_mask_numpy(x.asnumpy(), l_np, axis, mask_value) - elif ftype == "reverse": - Y = mx.symbol.SequenceReverse(**args) - np_out = sequence_reverse_numpy(x.asnumpy(), l_np, axis) - fargs = [x, l] if seqlenQ else [x] - gargs = [x.asnumpy(), l_np] if seqlenQ else [x.asnumpy()] - check_symbolic_forward(Y, fargs, [np_out]) - check_numeric_gradient(Y, gargs, grad_nodes={'X':'write'}, - numeric_eps=1e-2, rtol=1e-2) - check_numeric_gradient(Y, gargs, grad_nodes={'X':'add'}, - numeric_eps=1e-3, rtol=1e-2, atol=1E-4) - check_numeric_gradient(Y, gargs, grad_nodes={'X':'null'}, - numeric_eps=1e-3, rtol=1e-2, atol=1E-4) + for ary_dtype in [np.float32]: + for idx_dtype in [np.int32, np.float32]: + for s in shapes: + x = mx.random.uniform(-1, 1, s, ctx=mx.cpu()).astype(ary_dtype).copyto(xpu) + batch = s[1] if (axis == 0) else s[0] + seqlen = s[axis] + l_np = np.random.randint(1, seqlen + 1, batch) + l = mx.nd.array(l_np, ctx=mx.cpu(), dtype=idx_dtype).copyto(xpu) + if not seqlenQ: + l_np = None + args = {'data':X, 'use_sequence_length':seqlenQ, "axis":axis} + if seqlenQ: + args['sequence_length'] = L + if ftype == "last": + Y = mx.symbol.SequenceLast(**args) + np_out = sequence_last_numpy(x.asnumpy(), l_np, axis) + elif ftype == "mask": + args['value'] = mask_value + Y = mx.symbol.SequenceMask(**args) + np_out = sequence_mask_numpy(x.asnumpy(), l_np, axis, mask_value) + elif ftype == "reverse": + Y = mx.symbol.SequenceReverse(**args) + np_out = sequence_reverse_numpy(x.asnumpy(), l_np, axis) + fargs = [x, l] if seqlenQ else [x] + gargs = [x.asnumpy(), l_np] if seqlenQ else [x.asnumpy()] + check_symbolic_forward(Y, fargs, [np_out], dtype="asnumpy") + check_numeric_gradient(Y, gargs, grad_nodes={'X':'write'}, + numeric_eps=1e-2, rtol=1e-2) + check_numeric_gradient(Y, gargs, grad_nodes={'X':'add'}, + numeric_eps=1e-3, rtol=1e-2, atol=1E-4) + check_numeric_gradient(Y, gargs, grad_nodes={'X':'null'}, + numeric_eps=1e-3, rtol=1e-2, atol=1E-4) @with_seed()
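
Usage note (editor's sketch, not part of the patch): after this change the sequence_length
argument no longer has to share the input array's dtype, so a plain int32 lengths array can
drive float32 data. A minimal example of the new behavior, assuming an MXNet build that
includes this commit:

    import mxnet as mx

    # Time-major layout (default axis=0): (max_seq_len, batch_size, feature_dim)
    data = mx.nd.random.uniform(-1, 1, shape=(4, 3, 2), dtype='float32')
    # One valid length per batch element; int32 is now accepted directly
    lengths = mx.nd.array([2, 4, 1], dtype='int32')

    masked = mx.nd.SequenceMask(data, sequence_length=lengths,
                                use_sequence_length=True, value=0.0)
    last = mx.nd.SequenceLast(data, sequence_length=lengths,
                              use_sequence_length=True)
    rev = mx.nd.SequenceReverse(data, sequence_length=lengths,
                                use_sequence_length=True)

Before this patch, the lengths array above would have had to be float32 to match data,
since the operators required both inputs to share a single dtype.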