Dot Product Computes Wrong Values #217
Hmm, this is interesting. Could you share a self-contained file that contains the content of
@ptillet You can clone that repository and run the test via:
See the README. It only requires PyTorch and Triton.
@ptillet Any updates on this? I'm hoping to figure out what's missing.
I made and pushed some fixes, and now I'm running into a strange situation where the first column of my output doesn't match, but the rest of the elements pass the test:
Replicate test on this commit:
There seems to be something wrong here. Changing:
into:
Now, I would expect the above matmul to just zero out the entire output, but the malformed first column is still there!
It seems like the dot product did not properly multiply the first element in the matrix, but multiplied the rest of them. Another strange behavior is that:
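When only part of a tensor disagrees (like a single bad column here), a plain `allclose` check hides the pattern. A small helper along these lines prints exactly which indices mismatch; the function name is made up for illustration, not from the repo:

```python
import torch

def report_mismatches(out, ref, atol=1e-5, max_show=5):
    """Return the indices where two tensors disagree beyond `atol`.

    Useful for localizing partial failures, e.g. a single wrong
    column, instead of getting one pass/fail boolean.
    """
    bad = (out - ref).abs() > atol
    idx = bad.nonzero()  # shape (num_mismatches, ndim)
    if idx.numel():
        print(f"{idx.shape[0]} mismatched elements; first few indices:")
        print(idx[:max_show])
    else:
        print("all elements match")
    return idx

# Example: a 4x4 result whose first column is wrong.
ref = torch.zeros(4, 4)
out = ref.clone()
out[:, 0] = 1.0
report_mismatches(out, ref)  # 4 mismatches, all in column 0
```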
@calclavia I'm really interested in this as well! Phil had told me over email that it wasn't possible, but maybe I didn't explain it correctly.
@lucidrains Great to see you're interested! I was hoping this could be a drop-in replacement for your Performers repo to solve this issue (lucidrains/performer-pytorch#44)
@calclavia yes, me too... i feel like the whole "linear" attention line of research is being held back by this
@calclavia have you tried cuda-python yet?
Hey~ Sorry for the delay. I'll be looking into this tonight or tomorrow. Since the
@lucidrains I haven't tried CUDA Python. I'm not an expert at CUDA programming, so Triton seems like a nice alternative that is easy to learn and can get the job done. A brief look at EPFL's CUDA seems to indicate they're using an
@ptillet If it helps, I ran my tests on an NVIDIA RTX 3080 GPU, which uses the Ampere architecture.
Hey! I looked into this and there are multiple issues at play:
FWIW I think causal attention will be much easier to implement when
@ptillet Thanks for the analysis. So I'm assuming these are Triton compiler-side fixes, and not a bug in my implementation. In terms of building a prefix-sum primitive: how difficult do you think it would be to implement, and what performance gains would a primitive likely yield? Ideally, wouldn't the compiler be able to "detect" these types of loops and optimize them to use a primitive behind the scenes?
@calclavia It's a bit tricky. Our goal is actually to extend Triton-IR to support indexing operations; then prefix sum could just be implemented using an algorithm like https://en.wikipedia.org/wiki/Prefix_sum#Parallel_algorithms without dedicated Triton-IR instructions. The compiler would look at all the indexing ops and add shared memory barriers (and warp shuffles) accordingly, to maximize performance without breaking correctness. Pattern matching across loop boundaries tends to be pretty hard and brittle, so in order to avoid that, Triton tries really hard to find abstractions that make sense :p It'll take a while before indexing ops are part of the ecosystem (I'm thinking maybe by Jan or Feb), but hopefully we'll get a prototype before then that will be enough for causal attention. Anyhow, the issues you're having are unrelated, and I'll try to address them sooner than that :D
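For reference, the Hillis–Steele variant of the parallel scan linked above can be simulated in a few lines of plain Python. This is only an illustration of the algorithm, not Triton code: each `offset` round stands in for one parallel step on the GPU.

```python
def inclusive_scan(xs):
    """Inclusive prefix sum via the Hillis-Steele scheme.

    Runs in O(log n) rounds; within a round, every position
    would update simultaneously on real parallel hardware.
    """
    xs = list(xs)
    n, offset = len(xs), 1
    while offset < n:
        # One parallel round: element i adds the element `offset` back.
        xs = [x + (xs[i - offset] if i >= offset else 0)
              for i, x in enumerate(xs)]
        offset *= 2
    return xs

print(inclusive_scan([1, 2, 3, 4]))  # -> [1, 3, 6, 10]
```

This does O(n log n) total work versus O(n) for the sequential loop, but its O(log n) depth is what makes it attractive on a GPU.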
I'm seeing the same wrong values for small blocks. Is there any plan to fix this bug?
@ptillet I checked Triton v1.1 and I think this issue is still here (I pushed an update to my repo with the upgrade, so anyone can test it). Any updates on this? |
this should be fixed on the latest master branch, or the latest dev wheel |
@ptillet I've updated Triton, but I'm getting a segmentation fault on version
Here's what I'm trying to run:
Simple PyTorch function that does the causal dot product:
Triton version (incomplete):
Running my unit test:
@ptillet I did more digging to see which part of the kernel causes the error, and added some comments. In particular, the crash happens if I try to load another vector
I think this issue is related to #375 |
Hey! First of all, thanks for digging into this issue. Sorry for having been unresponsive, I've been quite busy with other matters, and I am well aware of the instabilities of triton 😅 I will look into this problem. FWIW, we are planning a rewrite of the backend that should greatly improve stability on these kinds of issues. |
Closing this, as the compiler now throws an error when the blocks are too small for |
First off, this is a really cool library! I'm trying to implement a causal dot product kernel, but I'm running into some issues. It could be either a bug or my misunderstanding of the documentation.
Triton code: https://github.com/calclavia/Triton-Transformer/blob/master/ttx/attention/causal_product.py#L26
The above implements Algorithm 1 from the "Transformers are RNNs" paper in Triton. In summary, I'm trying to batch-parallelize a for loop that computes a prefix sum. This is the simple O(n) implementation (not the more efficient O(log n) version).
The equivalent naive Pytorch implementation is here: https://github.com/calclavia/Triton-Transformer/blob/master/ttx/attention/causal_product.py#L6
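For readers without the repo handy, the O(n) recurrence from Algorithm 1 looks roughly like this in PyTorch. This is a sketch with assumed shapes, not the repo's exact code: keep a running sum S of the outer products k_j v_j^T and contract it with each query.

```python
import torch

def causal_dot_product(q, k, v):
    """Naive O(n) causal linear attention (Katharopoulos et al., Alg. 1).

    q, k: (batch, seq, dim); v: (batch, seq, dim_v).
    S_i = sum_{j <= i} k_j v_j^T is a running prefix sum; out_i = q_i S_i.
    """
    b, n, d = q.shape
    dv = v.shape[-1]
    S = q.new_zeros(b, d, dv)
    out = q.new_empty(b, n, dv)
    for i in range(n):
        # Prefix-sum step: accumulate the outer product k_i v_i^T.
        S = S + torch.einsum('bd,be->bde', k[:, i], v[:, i])
        # Contract the query with the accumulated state.
        out[:, i] = torch.einsum('bd,bde->be', q[:, i], S)
    return out
```

This is mathematically equal to masked quadratic attention, out_i = sum over j <= i of (q_i · k_j) v_j, which makes a handy reference implementation for a unit test.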
When running a simple unit test, I'm getting very different values.
I tested a single-dimensional vector and was unable to get matching values. The logic seems correct, but I suspect the issue is related to `tl.dot`. If anyone has insights, I would appreciate comments/feedback!