[Bug] Missing broadcast_to before batch_matmul for CuBLAS #7730

Closed
comaniac opened this issue Mar 24, 2021 · 6 comments · Fixed by #8229

@comaniac
Contributor

comaniac commented Mar 24, 2021

PR #7348 removed broadcast_to before batch_matmul because batch_matmul already supports implicit broadcasting. However, the CuBLAS implementation wasn't updated accordingly, which causes the following case to fail:

import numpy as np

import tvm
from tvm import relay
from tvm.contrib import graph_runtime

sa = (4, 128, 768)
sb = (1, 768, 768)

a = relay.var("a", shape=sa)
b = relay.var("b", shape=sb)
c = relay.nn.batch_matmul(a, b)
f = relay.Function([a, b], c)
mod = tvm.ir.IRModule.from_expr(f)
mod = relay.transform.InferType()(mod)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda")  # changing the target to "cuda -libs=cublas" makes this fail

ctx = tvm.gpu(0)
m = graph_runtime.GraphModule(lib["default"](ctx))
p = np.random.uniform(0, 1, sa).astype("float32")  # cast to match the Relay vars' default float32 dtype
q = np.random.uniform(0, 1, sb).astype("float32")
m.set_input("a", p)
m.set_input("b", q)

ftimer = m.module.time_evaluator("run", ctx, number=1, repeat=10)
prof_res = np.array(ftimer().results) * 1000
print(np.mean(prof_res))

I guess we need to either add the broadcast_to back or support implicit broadcasting in the CuBLAS implementation.
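
For reference, explicitly re-inserting the broadcast on the Relay side should make the cuBLAS target work for the snippet above. A minimal workaround sketch (not a proposed fix; shapes are the ones from the example):

# Workaround sketch: broadcast b's batch dim to match a before batch_matmul.
a = relay.var("a", shape=sa)
b = relay.var("b", shape=sb)
b_bc = relay.broadcast_to(b, (sa[0], sb[1], sb[2]))  # (1, 768, 768) -> (4, 768, 768)
c = relay.nn.batch_matmul(a, b_bc)
f = relay.Function([a, b], c)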

cc @masahi @jwfromm

@masahi
Member

masahi commented Mar 24, 2021

I see, this is the same issue raised by @csullivan in #6616 (review)

What was the solution to this problem? @jwfromm @csullivan

@csullivan
Contributor

csullivan commented Mar 24, 2021

Thanks @comaniac @masahi. Yes, the problem is that different targets, and target-specific TOPI implementations, can support different optimizations. When the BLAS libraries available for a target are used, implicit broadcast is not supported.

One option that comes to mind is to add a shape legalization pass that inserts the broadcast when a target has specific attributes (e.g. libs=cublas/rocblas etc.). However, this isn't sufficient: depending on the op strategy priorities or the applied tuning configs, it's possible that the BLAS library implementation won't be used. A better option could be to make use of #7518 and do the shape legalization after the primitive functions have been lowered to TIR and can be inspected.
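
To make the first option concrete, here is a rough sketch of such a Relay-level legalization (illustrative only, not an actual TVM pass; it assumes static shapes and a type-inferred module, and it has exactly the limitation described above, since it rewrites before the op strategy picks an implementation):

import tvm
from tvm import relay


class LegalizeBatchMatmulBroadcast(relay.ExprMutator):
    # Re-insert broadcast_to when the batch dims of a batch_matmul differ.
    # A real pass would also check the target attributes (e.g. libs=cublas/rocblas).
    def visit_call(self, call):
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "nn.batch_matmul":
            # Read shapes from the original, type-inferred arguments.
            sa = [int(d) for d in call.args[0].checked_type.shape]
            sb = [int(d) for d in call.args[1].checked_type.shape]
            a = self.visit(call.args[0])
            b = self.visit(call.args[1])
            if sa[0] != sb[0]:
                batch = max(sa[0], sb[0])
                if sa[0] == 1:
                    a = relay.broadcast_to(a, (batch, sa[1], sa[2]))
                if sb[0] == 1:
                    b = relay.broadcast_to(b, (batch, sb[1], sb[2]))
            return relay.nn.batch_matmul(a, b)
        return super().visit_call(call)


def legalize_batch_matmul(mod):
    mod = relay.transform.InferType()(mod)
    mod["main"] = LegalizeBatchMatmulBroadcast().visit(mod["main"])
    return relay.transform.InferType()(mod)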

We could also disable implicit broadcast, but that can increase memory use (from folding the constant broadcasts), which we've seen overflow device memory for larger batch sizes.

@comaniac
Contributor Author

Another direction I can think of is adding broadcast support to the CuBLAS batch_matmul implementation, so that the batch_matmul op behaves uniformly in Relay and nothing else needs to change. Do you think that's reasonable and doable?
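
One way this could look on the TOPI/strategy side is to materialize the broadcast before handing the tensors to the library kernel. A hedged sketch of the idea (illustrative only, not the actual TVM code; batch_matmul_cublas_broadcast is a hypothetical helper and static shapes are assumed):

from tvm import topi
from tvm.contrib import cublas


def batch_matmul_cublas_broadcast(x, y):
    # x: (bx, M, K), y: (by, N, K); broadcast the size-1 batch dim so the
    # library kernel always sees matching batch dimensions.
    bx, by = int(x.shape[0]), int(y.shape[0])
    if bx != by:
        batch = max(bx, by)
        if bx == 1:
            x = topi.broadcast_to(x, (batch, x.shape[1], x.shape[2]))
        if by == 1:
            y = topi.broadcast_to(y, (batch, y.shape[1], y.shape[2]))
    # The cuBLAS extern computes x * y^T (transb=True), matching Relay's nn.batch_matmul.
    return cublas.batch_matmul(x, y, transa=False, transb=True)

A real implementation might instead avoid materializing the broadcast (e.g. by indexing into the size-1 batch dimension), given the memory concern raised above.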

@csullivan
Contributor

Reasonable and doable for the short term. The downside is that it only fixes the problem for one target at a time; we would also need to add broadcast support to RocBLAS and CBLAS/MKL to avoid the issue for those targets.

@tqchen
Member

tqchen commented Jun 4, 2021

Gentle ping @comaniac to see if you got a chance to follow up on this issue.

@comaniac
Contributor Author

comaniac commented Jun 4, 2021

While @csullivan proposed a long-term solution to resolve the implementation difference between targets, this issue on CUDA has been worked around in the PyTorch frontend in the PR mentioned above. Specifically, if either one of the two inputs of matmul is 2D, the PyTorch frontend now reshapes the 3D tensor to 2D and uses dense, instead of expanding the 2D tensor to 3D and using batch_matmul. Other frontends may still have this issue, so I'll see if I can find time to file a PR to fix the CuBLAS issue next week.
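
The reshape-to-dense trick looks roughly like this (a minimal sketch of the idea, not the frontend code itself; shapes are taken from the example above):

import tvm
from tvm import relay

b, m, k, n = 4, 128, 768, 768
A = relay.var("A", shape=(b, m, k))  # the 3-D operand
B = relay.var("B", shape=(k, n))     # the 2-D operand

# Flatten A to 2-D, run nn.dense (which expects (units, in_features) weights),
# then reshape back to (b, m, n). No batch_matmul is emitted, so the cuBLAS
# broadcast limitation is never hit.
A2d = relay.reshape(A, (b * m, k))
Bt = relay.transpose(B, (1, 0))
out = relay.reshape(relay.nn.dense(A2d, Bt), (b, m, n))
func = relay.Function([A, B], out)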
