Conversation
@haojin2 Can you look into the CI failures on this?
@mxnet-label-bot Add [Operator, pr-awaiting-review]
Force-pushed from 1e400f0 to 1ad0561
Force-pushed from 06f25aa to 1abd460
Force-pushed from 1abd460 to 1edf38f
Force-pushed from 1edf38f to 95394a7
@szha @eric-haibin-lin Finally the CI passed... Please give a review when you have time.
}
} else {
if (shape.ndim() == 2) {
SoftmaxGrad<OP1, OP2, Req, negate, DType>(
MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, {
Do we really need to iterate over different int types for length? Can we just cast the type to int64_t?
It may not necessarily be int64_t, and what do you mean by "cast the type to int64_t"? Allocating a new buffer on the fly within the operator to hold the casted length input? I would consider that a big performance bottleneck.
It does not necessarily have to be int64_t. I think you can use the same data type as M, which is index_t.
I understand length is a tensor here. I guess my question is: do we really need to care about the dtype of the values in the length tensor and use an MXNET_INT_TYPE_SWITCH macro here? Can we simply cast them to index_t regardless of the dtype of the length tensor?
1. There's no iteration at all; TYPE_SWITCH is always a switch, not a loop. I would suggest that you read the code more carefully.
2. Whether or not you cast to some other type within the kernel, you will not get rid of the TYPE_SWITCH. One way of doing a constant cast to index_t is within the computation kernel, like:
template<typename IType>
MSHADOW_XINLINE static void Map(int i, IType *buffer) {
index_t val = static_cast<index_t>(buffer[i]);
}
And this would still require a TYPE_SWITCH when you're launching the kernel, since this kernel is still templated and the input type is not limited.
Or, if you really hate the TYPE_SWITCH for the compute kernel, you would have to cast your input buffer first:
Tensor<xpu, 1, index_t> index_t_buffer = ctx.requested[0].get_space_typed<xpu, 1, index_t>(<some shape>);
// Cast your buffer to index_t_buffer
// Launch the compute kernel with index_t_buffer, now you only need one TYPE_SWITCH for your input data type.
I would really dislike this approach because it brings no benefit at all. Alternatively, you could support only a length buffer of type int64, but then you would still need MSHADOW_IDX_TYPE_SWITCH to provide the error message when length has some other data type; or you could insert an additional check on the data type before the kernel launches, but that would drop support for other integer types.
3. The fact that length is a tensor means that length=M is not an option; maybe you mean length=nullptr? That would introduce additional if branches within the kernel and would especially hurt performance on GPU. The way this feature is implemented now puts that if branch on the host rather than within the kernel, which has minimal performance impact.
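To make the dispatch point concrete, here is a minimal standalone sketch (plain C++ with made-up names such as DType and SumLengths; this is not MXNet's actual MXNET_INT_TYPE_SWITCH definition) of what a TYPE_SWITCH-style macro boils down to: a single host-side switch that selects one template instantiation, while any cast to index_t happens inside the instantiated kernel.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

using index_t = int64_t;               // stand-in for mshadow's index_t

enum class DType { kInt32, kInt64 };   // hypothetical runtime dtype tags

// Templated "kernel": the element type must be fixed at compile time,
// even though every value is immediately cast to index_t.
template <typename IType>
void SumLengths(const void* data, std::size_t n, index_t* out) {
  const IType* buf = static_cast<const IType*>(data);
  index_t total = 0;
  for (std::size_t i = 0; i < n; ++i)
    total += static_cast<index_t>(buf[i]);   // the cast lives inside the kernel
  *out = total;
}

// Host-side dispatch: a single switch (not a loop) picks the instantiation.
// This is roughly the shape a TYPE_SWITCH-style macro expands to.
void SumLengthsDispatch(DType dtype, const void* data, std::size_t n, index_t* out) {
  switch (dtype) {
    case DType::kInt32: SumLengths<int32_t>(data, n, out); break;
    case DType::kInt64: SumLengths<int64_t>(data, n, out); break;
  }
}

int main() {
  std::vector<int32_t> lengths = {3, 5, 2};
  index_t total = 0;
  SumLengthsDispatch(DType::kInt32, lengths.data(), lengths.size(), &total);
  std::printf("total = %lld\n", static_cast<long long>(total));
  return 0;
}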
I think your communication was not clear enough in the first place. I've now added a cast to index_t in some existing code too; are you comfortable with the changes now?
Thanks for making the changes. Sorry if my earlier comments caused you confusion.
Looking at the current code, I see that SoftmaxWithLength() and Softmax() are very similar now (they differ by one if block: https://github.com/apache/incubator-mxnet/pull/15169/files#diff-be02d7c5660bf5cd623601a501fc7abeR134). Do you think we can combine these two functions into one now? :)
No, because you are only looking at the CPU version of the code.
Are you referring to these two kernel functions?
https://github.com/apache/incubator-mxnet/blob/50c1c7ca7f2e6a7864e7a0aedc775ae7dd8be091/src/operator/nn/softmax-inl.h#L283
https://github.com/apache/incubator-mxnet/blob/50c1c7ca7f2e6a7864e7a0aedc775ae7dd8be091/src/operator/nn/softmax-inl.h#L339
They also differ by only a few lines. Can you think of a better way to consolidate them?
I have a separate PR #15545 to optimize the softmax GPU implementation; I can look into merging those two kernels in that PR.
Can we consolidate the SoftmaxWithLength function with Softmax by using a default length? Having two copies of the function to serve one extra argument seems like an overkill design.
@apeforest Believe it or not, I tried what you're suggesting on my end well before this PR was raised, but found it could not easily be done without significantly hurting the existing softmax's performance. I think I'll stick to this version for now.
If you have tried this earlier, can you post the performance degradation results from the other approach? It may help us decide whether we want to trade off some performance for code simplicity here.
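For reference, here is a minimal standalone sketch of the kind of "default length" consolidation being discussed; the names (SoftmaxRows, lengths) are illustrative, and this is not the code in this PR nor the variant that was benchmarked. A null lengths pointer gives ordinary row-wise softmax, while a non-null pointer restricts each row to its valid prefix and zeroes the rest. The concern raised in the thread is whether the extra per-row branching slows down the plain-softmax path.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// One routine instead of separate Softmax() / SoftmaxWithLength():
// `lengths` is optional; nullptr means "use the full row".
void SoftmaxRows(const float* in, float* out,
                 std::size_t rows, std::size_t cols,
                 const int64_t* lengths = nullptr) {
  for (std::size_t r = 0; r < rows; ++r) {
    const std::size_t len = lengths ? static_cast<std::size_t>(lengths[r]) : cols;
    const float* x = in + r * cols;
    float* y = out + r * cols;

    float row_max = -INFINITY;
    for (std::size_t c = 0; c < len; ++c) row_max = std::max(row_max, x[c]);

    float sum = 0.f;
    for (std::size_t c = 0; c < len; ++c) {
      y[c] = std::exp(x[c] - row_max);
      sum += y[c];
    }
    for (std::size_t c = 0; c < len; ++c) y[c] /= sum;
    for (std::size_t c = len; c < cols; ++c) y[c] = 0.f;   // masked-out tail
  }
}

int main() {
  std::vector<float> in  = {1.f, 2.f, 3.f, 4.f,
                            1.f, 1.f, 1.f, 1.f};
  std::vector<float> out(in.size());
  std::vector<int64_t> lengths = {3, 4};   // row 0 only uses its first 3 entries
  SoftmaxRows(in.data(), out.data(), 2, 4, lengths.data());
  for (float v : out) std::printf("%.4f ", v);
  std::printf("\n");
  return 0;
}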
Force-pushed from c5e5363 to 50c1c7c
Force-pushed from 50c1c7c to b940812
Force-pushed from b940812 to 50d8ee7
This PR is more like patching an existing function to implement a feature that is not generally needed. From a software engineering standpoint, this PR introduces unnecessary redundancy in the code, and the performance impact of the alternative approach has not been well measured. As we discussed offline, if this feature is urgently needed and the code has to be patched this way to maintain backward compatibility, it may be okay to ship as is. However, we should add a TODO for MXNet 2.0 to rewrite softmax from scratch, taking into consideration all the required scenarios and how to make it extensible. @szha for review and final approval.
Force-pushed from 50d8ee7 to 8d1fc65
Approved. Softmax with variable length is very common in a number of applications and should be worth trading some extra complexity for the performance.
Approving this PR to unblock the needed feature.
#15545 introduces a new softmax kernel and may need to be copied for this function as well.
Thanks all for the approval, merging this now. @ptrendx @szha @apeforest
@apeforest did you add the TODO for 2.0 somewhere? #15169 (comment)
Sorry for missing the review. Do we need to change MKL-DNN softmax to accommodate the change?
* softmax with length forward
* softmax with length backward
* new macro to reduce compile-time heap usage and limit length to integers only
* address comments
@@ -92,6 +101,13 @@ Example::
No documentation for example inputs with length?
Description
Softmax with an extra length input to specify the valid length of each line of the input along an axis.
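As a sketch of the intended semantics (the exact masking convention is defined by the operator code, so treat this as an approximation): for a line x of size M along the softmax axis with valid length L <= M, and m = \max_{0 \le k < L} x_k, the output is

y_j =
\begin{cases}
  \dfrac{\exp(x_j - m)}{\sum_{k=0}^{L-1} \exp(x_k - m)} & 0 \le j < L \\[1ex]
  0 & L \le j < M
\end{cases}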
Checklist
Essentials
Changes
Comments
Flakiness Check:
Benchmark results:
CPU: ~1.82x speedup
GPU: ~1.59x speedup
Benchmark script: