Dot Product Computes Wrong Values #217

Closed
calclavia opened this issue Aug 17, 2021 · 23 comments

@calclavia

First off, this is a really cool library! I'm trying to implement the causal dot product kernel, but I'm running into some issues. It could be either a bug or my misunderstanding of the documentation.

Triton code: https://github.com/calclavia/Triton-Transformer/blob/master/ttx/attention/causal_product.py#L26

The above implements Algorithm 1 from the "Transformers are RNNs" paper in Triton. In summary, I'm trying to batch-parallelize a for loop that computes a prefix sum. This is a simple O(n) implementation (not the more efficient O(log(n)) version).

The equivalent naive Pytorch implementation is here: https://github.com/calclavia/Triton-Transformer/blob/master/ttx/attention/causal_product.py#L6
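
For reference, a minimal sketch of that naive causal dot product in PyTorch (my own paraphrase of the algorithm, not the exact code in the repo; [batch, heads, length, dim] shapes are assumed):

import torch

def causal_dot_product_naive(q, k, v):
    # q, k: [batch, heads, length, k_dim]; v: [batch, heads, length, v_dim]
    b, h, n, k_dim = k.shape
    v_dim = v.shape[-1]
    state = torch.zeros(b, h, k_dim, v_dim, device=q.device, dtype=q.dtype)
    outputs = []
    for t in range(n):
        # Running prefix sum of the outer products k_t v_t^T.
        state = state + torch.einsum('bhk,bhv->bhkv', k[:, :, t], v[:, :, t])
        # Output at step t is q_t applied to the accumulated state.
        outputs.append(torch.einsum('bhk,bhkv->bhv', q[:, :, t], state))
    return torch.stack(outputs, dim=2)  # [batch, heads, length, v_dim]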

When running a simple unit test, I'm getting very different values.

ref_output tensor([[[[0.0157, 0.6251, 0.6993],
          [0.7910, 2.1930, 2.1000],
          [0.7413, 1.7217, 1.4139],
          [0.8863, 1.4795, 1.5222],
          [1.2453, 2.5844, 1.9665]]]])
triton_output tensor([[[0.4268, 0.6251, 0.6993],
         [2.4132, 0.6186, 0.3389],
         [0.8975, 0.4929, 0.2288],
         [0.8470, 0.0080, 0.3058],
         [1.1330, 0.7073, 0.0776]]])

I tested with a single-dimensional vector and was unable to get matching values. The logic seems correct, but I suspect the issue is related to tl.dot. If anyone has insights, I would appreciate comments/feedback!
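
The unit test is essentially a comparison of the two implementations on random inputs, along these lines (a sketch; causal_dot_product_naive and causal_dot_product_triton are stand-ins for the functions in the repo):

import torch

def check_causal_product(causal_dot_product_naive, causal_dot_product_triton):
    torch.manual_seed(0)
    # Small random inputs matching the shapes printed above.
    q, k, v = (torch.rand(1, 1, 5, 3, device='cuda') for _ in range(3))
    ref_output = causal_dot_product_naive(q, k, v)
    triton_output = causal_dot_product_triton(q, k, v)
    print('ref_output', ref_output)
    print('triton_output', triton_output)
    assert torch.allclose(ref_output, triton_output.view_as(ref_output), atol=1e-4)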

@ptillet
Collaborator

ptillet commented Aug 17, 2021

Hmm, this is interesting. Could you share a self-contained file that contains the content of causal_product_naive plus your failing unit test? Thanks!

@calclavia
Author

calclavia commented Aug 18, 2021

@ptillet You can clone that repository and run the test via:

python -m unittest

See the README. It only requires PyTorch and Triton.
The repo currently contains nothing but the causal dot product.

@calclavia
Author

@ptillet Any updates on this? I'm hoping to figure out what's missing.

@ptillet
Collaborator

ptillet commented Aug 22, 2021

Hey! Sorry, I haven't had time to look into it. I've been busy working on older issues (#170 and #176) that require quite a bit of work.

@calclavia
Author

calclavia commented Aug 23, 2021

I made and pushed some fixes, and now I'm running into a strange situation where the first column of my output doesn't match, but the rest of the elements pass the test:

ref_output tensor([[[[-1.8713, -3.2907,  2.6836,  1.0433]]]], device='cuda:0')
triton_output tensor([[[ 7.9835, -3.2907,  2.6836,  1.0433]]], device='cuda:0')

To replicate, run the test on this commit:

python -m unittest tests.test_causal_dotprod.TestCausalDotProduct.test_causal_product_fwd_triton

Something seems to be wrong with how tl.dot handles the first column, and it's hard to debug why. I tried multiplying the dot product operands by constants to narrow down the cause:

Changing:

output = tl.dot(q[None, :], state)

into

output = tl.dot(q[None, :] * 100, state * 0)

Now, I would expect the above matmul to just zero out the entire output, but the malformed first column is still there!

ref_output tensor([[[[-1.8713, -3.2907,  2.6836,  1.0433]]]], device='cuda:0')
triton_output tensor([[[79834.8906,     0.0000,     0.0000,     0.0000]]], device='cuda:0')

It seems like the dot product did not properly multiply the first element of the matrix, but did multiply the rest of them?

Another strange behavior: my kernel only "works" to the extent above if num_warps=1. If I don't set num_warps, it outputs all zeros.

@lucidrains
Contributor

@calclavia I'm really interested in this as well! Phil had told me over email that it wasn't possible, but maybe I didn't explain it correctly.

@calclavia
Author

@lucidrains Great to see you're interested! I was hoping this could be a drop-in replacement for your Performers repo to solve this issue (lucidrains/performer-pytorch#44).

@lucidrains
Contributor

@calclavia yes, me too... i feel like the whole "linear" attention line of research is being held back by this

@lucidrains
Contributor

@calclavia have you tried cuda-python yet?

@ptillet
Collaborator

ptillet commented Sep 1, 2021

Hey~ Sorry for the delay. I'll be looking into this tonight or tomorrow. Since the dot unit tests pass, I wonder what is causing the results to be incorrect in this case. Maybe some interaction between broadcasting and tensor cores.

@calclavia
Author

calclavia commented Sep 2, 2021

@lucidrains I haven't tried CUDA Python. I'm not an expert at CUDA programming, so Triton seems like a nice alternative that is easy to learn and can get the job done.

A brief look at EPFL's CUDA code suggests they're using an O(n) implementation in their fallback path: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/causal_product/causal_product_cuda.cu#L1204 but the optimized version is too complex for me to understand. It should be possible to implement the O(log(n)) version in Triton; it'll just take some time to figure out how to port the prefix sum algorithm described by NVIDIA.
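
For illustration, the scan itself is easy to express outside a kernel. A minimal Hillis-Steele-style inclusive prefix sum in PyTorch (a sketch of the O(log(n)) idea only, not a fused Triton kernel; the function name is just for illustration):

import torch

def inclusive_scan(x):
    # Hillis-Steele inclusive prefix sum along dim 0 in O(log n) parallel steps.
    out = x.clone()
    n, shift = out.shape[0], 1
    while shift < n:
        # Every position i >= shift accumulates the partial sum from position i - shift.
        out[shift:] = out[shift:] + out[:-shift]
        shift *= 2
    return out

Applied to a stack of per-step outer products (e.g. torch.einsum('nk,nv->nkv', k, v)), this recovers every intermediate state S_t in one pass, after which each output is q_t applied to S_t.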

@calclavia
Author

@ptillet If it helps, I ran my tests on an NVIDIA RTX 3080 GPU, which is on the Ampere architecture.

@ptillet
Collaborator

ptillet commented Sep 2, 2021

Hey! I looked into this and there are multiple issues at play:

  • I think what you ran into is an issue with dot products when the blocks are too small. It's quite tricky to get the thread partitioning right when the compiler needs to distribute 16 elements over hundreds of threads. This is one of the areas where the compiler is buggy at the moment.
  • It seems like with the most recent dev version, your causal attention makes the shared memory allocator hang. So that's another issue that I'll have to look into.

FWIW I think causal attention will be much easier to implement when triton.language contains prefix sum primitives.

@calclavia
Author

@ptillet Thanks for the analysis. So I'm assuming these are fixes needed on the Triton compiler side, and not a bug in my kernel implementation.

In terms of building a prefix sum primitive: how difficult do you think it would be to implement, and what kind of performance gains would a primitive likely yield? Ideally, wouldn't the compiler be able to "detect" these types of loops and optimize them to use a primitive behind the scenes?

@ptillet
Collaborator

ptillet commented Sep 3, 2021

@calclavia It's a bit tricky. Our goal is actually to extend Triton-IR to support indexing operations; then prefix sum could just be implemented using an algorithm like https://en.wikipedia.org/wiki/Prefix_sum#Parallel_algorithms without dedicated Triton-IR instructions. The compiler would then look at all the indexing ops and add shared memory barriers (and warp shuffles) accordingly to maximize performance without breaking correctness.

Pattern matching across loop boundaries tends to be pretty hard and brittle. So in order to avoid that, Triton tries really hard to find abstractions that make sense :p It'll take a while before indexing ops are part of the ecosystem (I'm thinking maybe by Jan - Feb), but hopefully we'll get a prototype before then that will be enough for causal attention.

Anyhow, the issues you're having are unrelated and I'll try to address them sooner than that :D

@tdzdog

tdzdog commented Sep 30, 2021

I'm seeing the same wrong values for small blocks here! Is there any plan to fix this bug?

@calclavia
Author

calclavia commented Feb 16, 2022

@ptillet I checked Triton v1.1 and I think this issue is still present (I pushed an update to my repo with the upgrade, so anyone can test it). Any updates on this?

@ptillet
Collaborator

ptillet commented Feb 16, 2022

This should be fixed on the latest master branch or the latest dev wheel.

@calclavia
Author

calclavia commented Feb 20, 2022

@ptillet I've updated Triton, but I'm getting a segmentation fault on version triton 2.0.0.dev20220211.

Here's what I'm trying to run...

Simple Pytorch function that does Causal Dot Product:
https://github.com/calclavia/Triton-Transformer/blob/lrn/ttx/lrn/torch.py

Triton version (incomplete):
https://github.com/calclavia/Triton-Transformer/blob/lrn/ttx/lrn/triton_2d.py

Running my unit test python -m unittest tests.test_lrn yields the following error (I'm not sure what it means):

Traceback (most recent call last):
  File "/workspace/tests/test_lrn.py", line 25, in test_triton
    triton_output = lrn_fwd_triton(q, k, v, z1, z2, state)
  File "/workspace/ttx/lrn/triton.py", line 120, in lrn_fwd_triton
    V_BLOCK_SIZE=v_block_size
  File "/opt/conda/lib/python3.7/site-packages/triton/code_gen.py", line 783, in __call__
    return self.kernel(*wargs, **kwargs, grid=self.grid)
  File "/opt/conda/lib/python3.7/site-packages/triton/code_gen.py", line 774, in __call__
    self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
  File "/opt/conda/lib/python3.7/site-packages/triton/code_gen.py", line 724, in add_to_cache
    constants=constants,
  File "/opt/conda/lib/python3.7/site-packages/triton/code_gen.py", line 675, in _compile
    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
IndexError: map::at

@calclavia
Author

calclavia commented Feb 23, 2022

@ptillet I did more digging to see which part of the kernel causes the error, and added some comments. In particular, if I try to load another vector q = tl.load(q_ptr + k_offsets[:, None], mask=k_mask[:, None], other=0) in the loop, I get a segfault. The tl.store at the end of the loop causes the IndexError: map::at error. The kernel seems to run if I comment out the tl.store.

    for _ in range(0, length, 1):

        # Load a single row of K and V as matrices.
        # [K_BLOCK_SIZE, 1]
        k = tl.load(k_ptr + k_offsets,
                    mask=k_mask, other=0)[:, None]
        # [1, V_BLOCK_SIZE]
        v = tl.load(v_ptr + v_offsets, mask=v_mask, other=0)[None, :]

        # Compute context [V, 1] x [1, K] => [V, K]
        context = tl.dot(v, k)
        state += context

        # Load a single row of Q of shape [K, 1]
        # TODO: Loading this causes segfault
        # q = tl.load(q_ptr + k_offsets[:, None], mask=k_mask[:, None], other=0)

        # Compute output = S * Q: [V, K] x [K, 1] => [V, 1]
        # TODO: the correct equation uses q (commented out below); using k as a
        # placeholder because loading q above segfaults
        # output = tl.dot(state, q)
        output = tl.dot(state, k)

        # TODO: Storing output causes IndexError: map::at
        tl.store(
            output_ptr + v_offsets[:, None],
            output,
            mask=v_mask[:, None]
        )

        # Move to next row
        k_offsets += kdim
        v_offsets += vdim

@calclavia
Author

I think this issue is related to #375.
If I use tl.dot within a loop, it tends to crash with a segmentation fault.

@ptillet
Collaborator

ptillet commented Feb 24, 2022

Hey!

First of all, thanks for digging into this issue. Sorry for having been unresponsive; I've been quite busy with other matters, and I am well aware of the instabilities of Triton 😅 I will look into this problem. FWIW, we are planning a rewrite of the backend that should greatly improve stability on these kinds of issues.

@ptillet
Collaborator

ptillet commented Feb 22, 2023

Closing this, as the compiler now throws an error when the blocks are too small for tl.dot. Please submit a new issue with a repro if you have an alternative kernel that compiles.

ptillet closed this as completed Feb 22, 2023