This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

fix mkl batch gemm #343

Merged: 7 commits merged into dmlc:master on Mar 6, 2019

Conversation

@TaoLv (Member) commented Jun 26, 2018

Since the same m/n/k is used for all the individual GEMMs, we can put them all into one group of MKL batch GEMM.
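
For illustration only (not the PR's exact code), here is a minimal sketch of what a single-group MKL batch GEMM looks like, assuming float data, column-major layout, and contiguous inputs; the function name batch_gemm_one_group is hypothetical:

#include <mkl.h>
#include <vector>

// Minimal sketch (not the exact PR code): run `batch` GEMMs of identical
// shape, C_i = A_i * B_i, as ONE group of a single cblas_sgemm_batch call.
// Column-major layout is assumed, so lda = m, ldb = k, ldc = m.
void batch_gemm_one_group(const float* A, const float* B, float* C,
                          MKL_INT m, MKL_INT n, MKL_INT k, MKL_INT batch) {
  CBLAS_TRANSPOSE p_transa[1] = {CblasNoTrans};
  CBLAS_TRANSPOSE p_transb[1] = {CblasNoTrans};
  MKL_INT p_m[1] = {m}, p_n[1] = {n}, p_k[1] = {k};
  MKL_INT p_lda[1] = {m}, p_ldb[1] = {k}, p_ldc[1] = {m};
  float p_alpha[1] = {1.0f}, p_beta[1] = {0.0f};
  MKL_INT p_group_sizeb[1] = {batch};  // all `batch` GEMMs fall in one group

  // One pointer per matrix; the shared m/n/k is described once per group.
  std::vector<const float*> pp_A(batch), pp_B(batch);
  std::vector<float*> pp_C(batch);
  for (MKL_INT i = 0; i < batch; ++i) {
    pp_A[i] = A + i * m * k;
    pp_B[i] = B + i * k * n;
    pp_C[i] = C + i * m * n;
  }

  cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
                    p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda,
                    pp_B.data(), p_ldb, p_beta, pp_C.data(), p_ldc,
                    1, p_group_sizeb);  // group_count = 1
}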

@yajiedesign @piiswrong please review again.

@TaoLv (Member, Author) commented Jun 26, 2018

@sxjscience

@sxjscience (Member) commented Jun 26, 2018 via email

@TaoLv (Member, Author) commented Jun 26, 2018

@sxjscience I think the performance should be the same, and I have verified that at the MXNet level. The main purpose of this PR is to refine the code and reduce some memory usage.

@sxjscience (Member) commented Jun 26, 2018 via email

@xinyu-intel (Member) commented:

This commit may cause a performance regression and we are working on it. Please do not merge for now. Thanks!

@TaoLv (Member, Author) commented Jul 18, 2018

Talked with @xinyu-intel offline. His performance regression was caused by other changes and is not related to this PR. The code change here passes the MXNet operator unit test: https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L2576.

I also used the following code to verify the performance change and got similar results before and after this PR:

import mxnet as mx
import numpy as np
import time

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
input_x = np.random.rand(200, 256, 256)
input_y = np.random.rand(200, 256, 512)

sym = mx.symbol.batch_dot(x, y).bind(mx.cpu(), {'x': mx.nd.array(input_x), 'y': mx.nd.array(input_y)})

start = time.time()
for i in range(1010):
    if i == 10:
        # restart the timer after 10 warm-up iterations
        start = time.time()
    # asnumpy() blocks until the asynchronous computation finishes
    sym.forward(is_train=False)[0].asnumpy()

print((time.time() - start))

Is it okay to merge? @sxjscience @piiswrong

(Review comment on this call:)

cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
                  p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
                  p_ldb, p_beta, pp_C.data(), p_ldc, 1, p_group_sizeb);
Member


Why not simply use &m, &n, &k for p_m, p_n, p_k?

Member Author


To make the code easier to understand. MKL_INT p_m[1] = {m}; means this batched GEMM has only one group, and all GEMMs in this group share the same m value. In the future we could extend this to MKL_INT p_m[2] = {m1, m2};, giving two groups in one batched GEMM, where the first group uses m1 and the second uses m2. Passing &m to this API would hide this structure and make it a little confusing.
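
For illustration, a hedged sketch of that hypothetical two-group call; batch_gemm_two_groups and its parameters are illustrative names, not code from this PR:

#include <mkl.h>
#include <vector>

// Hypothetical two-group extension (illustrative, not part of this PR):
// group 0 runs `batch1` GEMMs with dims m1/n1/k1, group 1 runs `batch2`
// GEMMs with dims m2/n2/k2, all dispatched in ONE cblas_sgemm_batch call.
// The pointer arrays hold batch1 + batch2 entries, group 0's first.
void batch_gemm_two_groups(const std::vector<const float*>& pp_A,
                           const std::vector<const float*>& pp_B,
                           const std::vector<float*>& pp_C,
                           MKL_INT m1, MKL_INT n1, MKL_INT k1, MKL_INT batch1,
                           MKL_INT m2, MKL_INT n2, MKL_INT k2, MKL_INT batch2) {
  CBLAS_TRANSPOSE p_transa[2] = {CblasNoTrans, CblasNoTrans};
  CBLAS_TRANSPOSE p_transb[2] = {CblasNoTrans, CblasNoTrans};
  MKL_INT p_m[2] = {m1, m2}, p_n[2] = {n1, n2}, p_k[2] = {k1, k2};
  MKL_INT p_lda[2] = {m1, m2}, p_ldb[2] = {k1, k2}, p_ldc[2] = {m1, m2};
  float p_alpha[2] = {1.0f, 1.0f}, p_beta[2] = {0.0f, 0.0f};
  MKL_INT p_group_sizeb[2] = {batch1, batch2};

  cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
                    p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda,
                    pp_B.data(), p_ldb, p_beta, pp_C.data(), p_ldc,
                    2, p_group_sizeb);  // group_count = 2
}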

@TaoLv (Member, Author) commented Dec 12, 2018

@sxjscience @piiswrong @eric-haibin-lin is this good to merge?

@pengzhao-intel commented:

@TaoLv could you rebase the code and run CI again?

@eric-haibin-lin (Member) commented:

Is there a reference PR in MXNet that tests the mshadow change end to end?

@szha merged commit c9d2f01 into dmlc:master on Mar 6, 2019
@TaoLv mentioned this pull request on Mar 11, 2019