MXNET-1295 Adding integer index support to Sequence* family of operators. #13880
Adding the ability to use int32 arrays, or any castable-to-int type, as the sequence_length array to SequenceMask, SequenceLast, and SequenceReverse. Previously these operators all required sequence_length to be the same data type as the input array. See MXNet Jira ticket here: https://issues.apache.org/jira/browse/MXNET-1295 See also GitHub issues here: apache#12649 dmlc/gluon-nlp#346
@mxnet-label-bot add [pr-awaiting-review]
Thanks for contributing this, @stephenrawls. Could you add some test cases for using different input types (in tests/python/unittest/test_operator.py)?
@szha Okay, I added unit tests. As I alluded to in my original PR, the best way to do this wasn't exactly clear, because the testing code seems to make a lot of assumptions that all arguments will have the same type, and that the only possible argument types are floats. For example, there are Also, the unit tests rely on test_utils.py to convert all input arrays to a single input type, e.g. here is an example:
I ended up modifying the logic so that I could control the dtype of each array independently by just creating the original numpy array to be the data type that I want, and telling the code to keep the numpy dtype when converting to mxnet.ndarray:
If you want to use the new behavior, you have to opt in and pass "asnumpy" as the dtype. Take a look, and if you prefer a different approach just let me know. Best,
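The opt-in described above can be sketched in plain NumPy. This is only an illustration of the idea, not the code in test_utils.py: the helper name `convert_inputs` and its dictionary-based interface are hypothetical, and the real utility converts to MXNet arrays rather than NumPy arrays.

```python
import numpy as np

def convert_inputs(arrays, dtype):
    """Hypothetical helper: cast every array to `dtype`,
    unless the caller opts in with dtype="asnumpy"."""
    keep_own_dtype = isinstance(dtype, str) and dtype == "asnumpy"
    converted = {}
    for name, arr in arrays.items():
        arr = np.asarray(arr)
        # "asnumpy" keeps whatever dtype the NumPy array was created with
        converted[name] = arr if keep_own_dtype else arr.astype(dtype)
    return converted

inputs = {
    "data": np.ones((3, 2), dtype=np.float32),
    "sequence_length": np.array([1, 3], dtype=np.int32),
}
uniform = convert_inputs(inputs, np.float32)  # old behavior: one dtype for all
mixed = convert_inputs(inputs, "asnumpy")     # opt-in: per-array dtypes kept
```

With the default path every argument ends up as float32, while the "asnumpy" path leaves sequence_length as int32, which is what the mixed-type tests need.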
@stephenrawls thanks! The approach sounds reasonable.
@stephenrawls thanks for the contribution!
…ors. (apache#13880)
* Adding integer index support to Sequence* family of operators. Adding the ability to use int32 arrays, or any castable-to-int type, as the sequence_length array to SequenceMask, SequenceLast, and SequenceReverse. Previously these operators all required sequence_length to be the same data type as the input array. See MXNet Jira ticket here: https://issues.apache.org/jira/browse/MXNET-1295 See also GitHub issues here: apache#12649 dmlc/gluon-nlp#346
* Adding explicit braces to an if statement to fix g++ warning
* Fixing sequence_mask.cu by adding IType to template
* Fixing whitespace errors reported by linter
* Adding unit tests
* Fixing length of lines to pass linter
Description
Adding the ability to use int32 arrays, or any castable-to-int type, as
the sequence_length array to SequenceMask, SequenceLast, and
SequenceReverse. Previously these operators all required sequence_length
to be the same data type as the input array.
See MXNet Jira ticket here:
https://issues.apache.org/jira/browse/MXNET-1295
See also GitHub issues here:
#12649
dmlc/gluon-nlp#346
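To make the change concrete, here is a minimal NumPy sketch of the masking semantics with a float32 input and an int32 sequence_length. It is a reference illustration only, not the operator's implementation: the function name `sequence_mask_ref` is hypothetical, and it assumes the time-major layout (max_len, batch, ...) with masking along axis 0.

```python
import numpy as np

def sequence_mask_ref(data, sequence_length, value=0.0):
    """Hypothetical NumPy reference for masking along axis 0.

    data: array of shape (max_len, batch, ...), any float dtype.
    sequence_length: array of shape (batch,), any dtype castable to int.
    """
    lengths = np.asarray(sequence_length).astype(np.int64)
    max_len = data.shape[0]
    # past_end[t, b] is True where time step t lies beyond sequence b
    past_end = np.arange(max_len)[:, None] >= lengths[None, :]
    out = data.copy()
    out[past_end] = value
    return out

data = np.arange(1, 13, dtype=np.float32).reshape(3, 2, 2)
lengths = np.array([1, 3], dtype=np.int32)  # integer lengths, float data
masked = sequence_mask_ref(data, lengths)
```

The output keeps the dtype of the input data; only the lengths array is cast to an integer type internally, so float and int lengths with the same values give the same result.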
Test Coverage
I have tested that this works by building locally and running all the Sequence* operators in Python with a float32 input array and an int32 sequence_length array. I confirmed that everything works correctly.
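A manual check of this kind needs expected values to compare against. The sketch below computes them in NumPy for the other two operators, again with float32 data and int32 lengths. The function names are hypothetical, and the semantics assumed are the time-major ones: the last valid step of each sequence is selected, and only the valid prefix of each sequence is reversed while padding stays in place.

```python
import numpy as np

def sequence_last_ref(data, sequence_length):
    """Hypothetical reference: pick the last valid step of each sequence."""
    lengths = np.asarray(sequence_length).astype(np.int64)
    batch = np.arange(data.shape[1])
    return data[lengths - 1, batch]

def sequence_reverse_ref(data, sequence_length):
    """Hypothetical reference: reverse only the first lengths[b] steps."""
    lengths = np.asarray(sequence_length).astype(np.int64)
    out = data.copy()
    for b, n in enumerate(lengths):
        out[:n, b] = data[:n, b][::-1]
    return out

data = np.arange(6, dtype=np.float32).reshape(3, 2)  # (max_len=3, batch=2)
lengths = np.array([2, 3], dtype=np.int32)
last = sequence_last_ref(data, lengths)
rev = sequence_reverse_ref(data, lengths)
```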
However, when I looked at the unit tests it was not immediately clear how to add an appropriate test, because the current tests all use the check_symbolic_forward() function, which takes a single dtype for all arguments and doesn't allow clients to specify different dtypes for different input arguments.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.