Testing triton when C matrix has some initial values. Possible? #170

Closed
navdeepkk opened this issue Jul 31, 2021 · 6 comments

Comments

@navdeepkk

Hi, is it possible to benchmark Triton when the C matrix is not assumed to be all zeros, i.e., when C is loaded from global memory instead of the accumulator tile being initialized with zeros?

Thanks!

@navdeepkk
Author

@ptillet any help would be appreciated. Thanks!

@ptillet
Collaborator

ptillet commented Aug 3, 2021

@navdeepkk Sorry for the delay, I was double-checking something. Right now, initializing the matmul accumulator with loaded values seems to be buggy. What you can do instead is add an accumulation after the matmul loop:

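acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)  # accumulator starts at zero
for k in range(K, 0, -BLOCK_K):
  ...
D = ... # construct pointer to the other tensor
acc += tl.load(D)  # add the loaded tile once, after the loop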

I will work on fixing the bugs so that one can instead write:

D = ... # construct pointer to the other tensor
acc = tl.load(D)  # initialize the accumulator with the loaded tile
for k in range(K, 0, -BLOCK_K):
  ...
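To see why the two orderings compute the same result, here is a minimal pure-Python emulation of a single output tile (no Triton or GPU involved). The helper names `k_loop`, `tile_epilogue` and `tile_preload` are hypothetical; the loaded tile `c` stands in for `tl.load(D)`, and the K loop stands in for the elided loop body.

```python
def k_loop(acc, a, b, block_k):
    """Add a @ b into acc, consuming K in block_k-wide slices (like the tiled loop)."""
    m, k, n = len(a), len(b), len(b[0])
    for k0 in range(0, k, block_k):
        k1 = min(k0 + block_k, k)  # the last slice may be shorter than block_k
        for i in range(m):
            for j in range(n):
                acc[i][j] += sum(a[i][kk] * b[kk][j] for kk in range(k0, k1))
    return acc


def tile_epilogue(a, b, c, block_k):
    """Workaround: start from zeros, add the loaded C tile after the loop."""
    acc = [[0.0] * len(b[0]) for _ in range(len(a))]
    acc = k_loop(acc, a, b, block_k)
    for i, row in enumerate(c):
        for j, value in enumerate(row):
            acc[i][j] += value  # acc += tl.load(D)
    return acc


def tile_preload(a, b, c, block_k):
    """Alternative: initialize the accumulator from the loaded C tile."""
    acc = [[float(value) for value in row] for row in c]  # acc = tl.load(D)
    return k_loop(acc, a, b, block_k)


a = [[1, 2, 3], [4, 5, 6]]
b = [[1, 0], [0, 1], [1, 1]]
c = [[10, 20], [30, 40]]
print(tile_epilogue(a, b, c, block_k=2))  # [[14.0, 25.0], [40.0, 51.0]]
print(tile_preload(a, b, c, block_k=2))   # [[14.0, 25.0], [40.0, 51.0]]
```

With integer-valued inputs both results are exact. With general floating-point data the two versions add the terms in a different order, so the results may differ slightly due to rounding.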

@navdeepkk
Author

Thanks! I'll try this.

@ptillet
Collaborator

ptillet commented Aug 6, 2021

Actually, I just checked and this doesn't seem to work anymore. I'll be working on it over the next few days.

@navdeepkk
Author

Oh, okay. I have some generated matmul kernels that load C from global memory, and I wanted to compare them against Triton. I'll get back to this once it's fixed. Thanks!

@ptillet
Collaborator

ptillet commented Aug 30, 2021

Hey, this should work now on top of master :) Just add, for example, c = c + tl.load(c_ptrs, mask=c_mask) here: https://github.com/openai/triton/blob/master/python/tutorials/03-matrix-multiplication.py#L253. You can also initialize the accumulator with the loaded values here: https://github.com/openai/triton/blob/master/python/tutorials/03-matrix-multiplication.py#L228, but that should be less efficient, and for now it only works if you explicitly cast the loaded values to float32. Let me know if you run into any more issues.
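The masked epilogue suggested above can be emulated in plain Python to show what the mask does on edge tiles whose rows or columns fall outside the matrix. This is a sketch, not Triton code: the names `cdiv` and `matmul_add_c` and the block sizes are hypothetical. It assumes that `c_ptrs` addresses the output matrix, so the kernel reads C's initial values and writes the result back over them (here, into a copy). It also assumes that out-of-bounds A/B elements read as zero, as a masked `tl.load` with `other=0.` would do.

```python
def cdiv(x, y):
    """Ceiling division, used to size the grid of tiles."""
    return (x + y - 1) // y


def matmul_add_c(a, b, c, block_m=2, block_n=2, block_k=2):
    """Return A @ B + C, computed tile by tile like a tiled matmul kernel."""
    M, K, N = len(a), len(b), len(b[0])
    out = [row[:] for row in c]  # stands in for the C buffer behind c_ptrs
    # each (pid_m, pid_n) pair plays the role of one program instance
    for pid_m in range(cdiv(M, block_m)):
        for pid_n in range(cdiv(N, block_n)):
            offs_m = [pid_m * block_m + r for r in range(block_m)]
            offs_n = [pid_n * block_n + s for s in range(block_n)]
            acc = [[0.0] * block_n for _ in range(block_m)]
            for k0 in range(0, K, block_k):
                for r, i in enumerate(offs_m):
                    for s, j in enumerate(offs_n):
                        for kk in range(k0, k0 + block_k):
                            # masked loads: out-of-bounds elements read as 0
                            a_val = a[i][kk] if i < M and kk < K else 0.0
                            b_val = b[kk][j] if kk < K and j < N else 0.0
                            acc[r][s] += a_val * b_val
            # epilogue: c = c + tl.load(c_ptrs, mask=c_mask), then a masked store
            for r, i in enumerate(offs_m):
                for s, j in enumerate(offs_n):
                    if i < M and j < N:  # c_mask: skip rows/columns past the edge
                        out[i][j] = acc[r][s] + c[i][j]
    return out


a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
identity = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
ones = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
print(matmul_add_c(a, identity, ones))  # [[2.0, 3.0, 4.0], [5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]
```

With 3x3 inputs and 2x2 tiles, the second row and column of tiles hang past the matrix edge. The mask keeps those lanes from being read or written, so only the valid elements of C are updated.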

ptillet closed this as completed Aug 30, 2021
B1tway pushed a commit to B1tway/triton that referenced this issue Apr 3, 2023
…_options_and_report

[FRONTEND] Fix triton-translate default options