Added oneDNN reduce_op GRAD kernel #32280
Conversation
Thanks for your contribution!
@jczaja @arogowie-intel Could you please review this code?
paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc
case 4:
  return mkldnn::memory::format_tag::abcd;
default:
  return mkldnn::memory::format_tag::abcde;
Why is a 5-dim tensor the default case?
I have added a restriction in GetExpectedKernelType that dims must be in the range <1,5>. I had to assure the compiler that this function always returns a value. I could delete the default statement and just leave the return outside the switch block. What do you think?
I'd add case 5: and, in the default statement, throw an error that an invalid argument was passed.
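The suggested fix could be sketched as follows. This is a minimal, self-contained illustration of the pattern, not the PR's actual code: `format_tag` here is a hypothetical stand-in enum for `mkldnn::memory::format_tag`, which in the real kernel comes from the oneDNN headers.

```cpp
#include <stdexcept>

// Hypothetical stand-in for mkldnn::memory::format_tag.
enum class format_tag { a, ab, abc, abcd, abcde };

// Maps tensor rank to a plain (non-blocked) format tag. With an explicit
// case 5, ranks outside <1,5> now fail loudly in the default branch
// instead of silently mapping to the 5-dim tag.
format_tag GetPlainFormatTag(int ndims) {
  switch (ndims) {
    case 1: return format_tag::a;
    case 2: return format_tag::ab;
    case 3: return format_tag::abc;
    case 4: return format_tag::abcd;
    case 5: return format_tag::abcde;
    default:
      throw std::invalid_argument("Tensor rank must be in range <1,5>.");
  }
}
```

Every case returns, so the compiler no longer warns about a missing return value, and invalid ranks surface as an explicit error.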
void RunKernel(const framework::ExecutionContext& ctx,
               dnnl::algorithm binary_type, float scale_x,
               float scale_y) const {
  auto& dev_ctx =
Since you're not modifying dev_ctx:
-    auto& dev_ctx =
+    const auto& dev_ctx =
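The rationale behind this suggestion: binding through a const reference documents read-only intent and lets the compiler reject accidental mutation. A tiny sketch with a hypothetical `DeviceCtx` stand-in (not the real device-context type):

```cpp
// Hypothetical stand-in for the device context object.
struct DeviceCtx { int id = 7; };

int ReadDeviceId(const DeviceCtx& holder) {
  // Read-only alias: writing through it, e.g. `dev_ctx.id = 0;`,
  // would be a compile-time error.
  const auto& dev_ctx = holder;
  return dev_ctx.id;
}
```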
  auto dims = ctx.Attr<std::vector<int>>("dim");
  auto* input_dy = ctx.Input<Tensor>(framework::GradVarName("Out"));

Please remove this blank line.
  auto* output_dx = ctx.Output<Tensor>(framework::GradVarName("X"));

  output_dx->mutable_data<T>(ctx.GetPlace());
  auto src_dx_memory = handler.AcquireSrcMemory(output_dx);
  const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy);
@@ -83,8 +71,8 @@ def setUp(self):
     }

 @skip_check_grad_ci(reason="not implemented")
-class TestReduceSum5DReduceAllKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp):
+class TestReduceSum5DReduceAllKeepDimsONEDNNOp(
-class TestReduceSum5DReduceAllKeepDimsONEDNNOp(
+class TestReduceSum5DReduceAllKeepDimsOneDNNOp(
 @skip_check_grad_ci(reason="not implemented")
-class TestReduceSum5DReduceAllKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp):
+class TestReduceSum5DReduceAllKeepDimsONEDNNOp(
+    TestReduceDefaultWithGradONEDNNOp):
-    TestReduceDefaultWithGradONEDNNOp):
+    TestReduceDefaultWithGradOneDNNOp):
@@ -95,8 +83,7 @@ def setUp(self):
     }

 @skip_check_grad_ci(reason="not implemented")
-class TestReduceSum4DReduceAllONEDNNOp(TestReduceSumDefaultONEDNNOp):
+class TestReduceSum4DReduceAllONEDNNOp(TestReduceDefaultWithGradONEDNNOp):
-class TestReduceSum4DReduceAllONEDNNOp(TestReduceDefaultWithGradONEDNNOp):
+class TestReduceSum4DReduceAllOneDNNOp(TestReduceDefaultWithGradOneDNNOp):
@@ -154,8 +141,7 @@ def setUp(self):
     }

 @skip_check_grad_ci(reason="not implemented")
-class TestReduceMean3DONEDNNOp(TestReduceSumDefaultONEDNNOp):
+class TestReduceMean3DONEDNNOp(TestReduceDefaultWithGradONEDNNOp):
-class TestReduceMean3DONEDNNOp(TestReduceDefaultWithGradONEDNNOp):
+class TestReduceMean3DOneDNNOp(TestReduceDefaultWithGradOneDNNOp):
@@ -166,8 +152,7 @@ def setUp(self):
     }

 @skip_check_grad_ci(reason="not implemented")
-class TestReduceMean4DReduceAllONEDNNOp(TestReduceSumDefaultONEDNNOp):
+class TestReduceMean4DReduceAllONEDNNOp(TestReduceDefaultWithGradONEDNNOp):
-class TestReduceMean4DReduceAllONEDNNOp(TestReduceDefaultWithGradONEDNNOp):
+class TestReduceMean4DReduceAllOneDNNOp(TestReduceDefaultWithGradOneDNNOp):
LGTM
@arogowie-intel I have implemented all your suggestions except one. Could you please re-review?
Good job!
@luotao1 Could you please start your review?
PR types
New features
PR changes
OPs
Describe
Added oneDNN reduce_op GRAD fp32 and bf16 kernels (reduce_sum, reduce_mean), enabling the Word2Vec model.
Forward operator PR link: Reduce FWD
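For context, the backward of a reduction broadcasts the output gradient back over the reduced elements, scaled by 1/N for reduce_mean. A minimal sketch of those semantics for the reduce-all case, independent of the actual oneDNN kernels in this PR (`ReduceAllGrad` is a hypothetical name):

```cpp
#include <cstddef>
#include <vector>

// Backward of a full (reduce-all) reduction over n elements:
//   reduce_sum:  every input element receives the output gradient dy.
//   reduce_mean: the same, scaled by 1/n, since each element contributed
//                with weight 1/n in the forward pass.
std::vector<float> ReduceAllGrad(float dy, std::size_t n, bool is_mean) {
  float g = is_mean ? dy / static_cast<float>(n) : dy;
  return std::vector<float>(n, g);
}
```

The real kernels handle arbitrary reduced axes and keep_dim handling, but the per-element math is the same.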