Added reshape+transpose+matmul_v2 fuse pass #36759
Conversation
Thanks for your contribution!
Thank you very much for your contribution!
paddle/fluid/operators/matmul_op.cc
Outdated
- framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
-                                std::string input_name) {
+ static framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
+                                       std::string input_name) {
Suggested change:
- std::string input_name) {
+ const std::string& input_name) {
That might not be performance-critical, but it may still save us some time.
AFAIK GetDimForInput is always called with a const char*.
I think the string is constructed in place from the char pointer, so there is no copying of a string, only passing a char*. If so, it's not necessarily slower. Still, I will use a char argument instead and create a string from it explicitly inside the function; that will better show the intent. I will also validate that it is exactly either 'X' or 'Y'.
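A minimal sketch of what such a char-based signature could look like. The `ShapeContext` struct below is a hypothetical stand-in for illustration only; the real function takes `framework::InferShapeContext` and returns `framework::DDim`:

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for framework::InferShapeContext, for illustration.
struct ShapeContext {
  std::map<std::string, std::vector<int64_t>> dims;
  std::vector<int64_t> GetInputDim(const std::string& name) const {
    return dims.at(name);
  }
};

// Char-based variant discussed above: the caller passes 'X' or 'Y', the
// string is built explicitly inside, and the argument is validated first.
static std::vector<int64_t> GetDimForInput(const ShapeContext& ctx,
                                           char input_name) {
  assert(input_name == 'X' || input_name == 'Y');
  return ctx.GetInputDim(std::string(1, input_name));
}
```

This makes the two legal call sites visible in the signature itself instead of relying on callers to pass the right string.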
That sounds nice, thank you!
@@ -19,6 +19,36 @@
namespace paddle {
namespace operators {

static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx,
                                      std::string input_name) {
Same as above
same reply as above
using paddle::framework::DDim;

static DDim GetDimForInput(const ExecutionContext& ctx,
                           std::string input_name) {
Same as above
same reply as above
}

std::vector<int64_t> GetInputStrides(const ExecutionContext& ctx,
                                     std::string input_name) {
Same as above
I was actually using input_name as a read/write variable, so I liked having a copy constructed from the const char*.
I will change it to a char instead, as mentioned in the comment above.
if (!trans_x) {
  x_strides.insert(x_strides.end(), {M * K, K, 1});
if (!strides_x.empty()) {
  x_strides = strides_x;
Could you please change the names of these variables? x_strides and strides_x sound like exactly the same thing.
Ok
if (!trans_y) {
  y_strides.insert(y_strides.end(), {N * K, N, 1});
if (!strides_y.empty()) {
  y_strides = strides_y;
Same as with x_strides and strides_x
ok
TestReshapeTransposeMatMulOp3DYFloat):
    def set_op_type_and_transpose_y_name(self):
        self.op_type = "matmul_v2"
        self.transpose_y_name = "trans_y"
Why only transpose_y is tested?
Maybe it is because the only use case, in BERT-like models, had transposition of y, so the fuse was prepared only for that.
return new_x;
}

std::vector<int64_t> GetInputStrides(const ExecutionContext& ctx,
Suggested change:
- std::vector<int64_t> GetInputStrides(const ExecutionContext& ctx,
+ static std::vector<int64_t> GetInputStrides(const ExecutionContext& ctx,
Since matmul and matmul_v2 share these functions, maybe it would be nice to put just the signature in "matmul_mkldnn_op.h", so we won't need two copies of exactly the same function. What do you think about that?
I can move some function declarations there, but some function signatures are different; those I won't move unless we figure out something else.
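A sketch of the declaration/definition split being suggested, with a hypothetical simplified signature and a dummy body rather than the actual Paddle one. The header carries only the declaration, and exactly one .cc file carries the body, which avoids the duplicate-definition (ODR) errors that appear when a non-inline body placed in a header is included from several translation units. Shown in a single file for brevity:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// In a real split, this declaration would live in matmul_mkldnn_op.h ...
std::vector<int64_t> GetInputStrides(const std::string& input_name);

// ... and this single definition in exactly one .cc file (e.g.
// matmul_mkldnn_op.cc), so every other translation unit that includes the
// header sees only the declaration. The body here is a dummy for
// illustration only.
std::vector<int64_t> GetInputStrides(const std::string& input_name) {
  return input_name == "X" ? std::vector<int64_t>{6, 3, 1}
                           : std::vector<int64_t>{2, 1};
}
```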
Great job, I only have one question.
return new_x;
}

static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx,
This is a similar question to the one Jakub drew attention to. Couldn't we move this function to matmul_mkldnn_op.h for the mkldnn implementation? I can see that matmul_mkldnn_op.cc and matmul_v2_mkldnn_op.cc use different contexts, const ExecutionContext& ctx and const framework::InferShapeContext& ctx. Do you know why they are different?
Actually, this function has the same logic in matmul_op.cc, matmul_v2_op.cc, matmul_mkldnn_op.cc, and matmul_v2_mkldnn_op.cc, which gives us four copies of this code. So maybe at least for mkldnn we can deduplicate it.
I tried moving the functions to a header, but it doesn't build on CI, which reports duplicate definitions and other problems. I cannot reproduce it on my machine; maybe once I have cuda/cudnn installed and configured I could try to reproduce the error. Otherwise, testing by pushing to a PR is not a good way to debug this. The include patterns for matmul, matmul_v2, and their mkldnn counterparts are hard to understand, and some of the same functions also appear in the matmul_xpu code. I couldn't figure it out in the time I had, so I would suggest attempting the refactoring later in a separate PR, depending on priorities. I think we could rethink the namespaces in those files.
As for the contexts: I needed to use the same function in a place where I only had access to a different context variable, so I made it this way.
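One hypothetical way to fold the four copies into a single definition despite the two different context types would be a small template, sketched here with stand-in context structs (the real contexts are framework::InferShapeContext and ExecutionContext, and the real return types differ as well):

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Stand-ins for the two context types; both expose GetInputDim, which is
// all the shared helper needs.
struct InferShapeCtx {
  std::map<std::string, std::vector<int64_t>> dims;
  std::vector<int64_t> GetInputDim(const std::string& n) const {
    return dims.at(n);
  }
};
struct ExecCtx {
  std::map<std::string, std::vector<int64_t>> dims;
  std::vector<int64_t> GetInputDim(const std::string& n) const {
    return dims.at(n);
  }
};

// One template definition instead of four copies: any context type with a
// GetInputDim member works. Being a template, it can live in a header
// without violating the one-definition rule.
template <typename Ctx>
std::vector<int64_t> GetDimForInput(const Ctx& ctx, const std::string& name) {
  return ctx.GetInputDim(name);
}
```

Whether this works in practice depends on the real context interfaces actually matching closely enough, which the thread suggests is not guaranteed.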
@sfraczek Hi, more CIs failed, do you know why?
Hi, yes, I have a reproduction thanks to Joanna and I'm working on it. The problem is with WITH_UNITY_BUILD.
LGTM. Thank you very much!
@sfraczek CheckPRTemplate passed. I will try to ask for approval now.
This PR will need two approvals.
Measured on an fp32 model by running cpu_infer.py from #36962 with the passes commented out.
Built with RelWithDebInfo. From the log of fuses:
LGTM, really good work :)
LGTM
Sorry to inform you that fbcf847's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
LGTM
38d3971
@baoachun Please send more models with broadcasting. Thanks.
@sfraczek Baidu requires this PR to be split.
Sorry to inform you that 38d3971's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Opened new version: #37847
PR types
Performance optimization
PR changes
Others
Describe
This is a reshape+transpose+matmul_v2 fuse pass based on the previous identical fuse for matmul_v1 (#23754). This fuse will speed up BERT-like models.