
[FRONTEND] fix matmul int8 overflow issue#2297

Merged
ptillet merged 1 commit into triton-lang:main from lightb0x:main
Sep 17, 2023

Conversation

@lightb0x
Contributor

Previously in matmul, if the inputs were int8, the output was also int8.
This commit fixes the overflow problem by producing an int32 output.
#2296
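The overflow being fixed can be illustrated with a minimal NumPy sketch (NumPy is used here only to show the arithmetic; the actual fix lives in Triton's `triton.ops.matmul`):

```python
import numpy as np

a = np.full((4, 4), 100, dtype=np.int8)
b = np.full((4, 4), 100, dtype=np.int8)

# int8 output: each row-column dot product is 4 * 100 * 100 = 40000,
# which cannot fit in int8 and wraps modulo 256.
wrapped = a @ b

# int32 output: the accumulation fits, so the result is exact.
correct = a.astype(np.int32) @ b.astype(np.int32)

print(wrapped[0, 0])  # 64  (40000 mod 256)
print(correct[0, 0])  # 40000
```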

@lightb0x lightb0x requested a review from ptillet as a code owner September 14, 2023 10:58
@gflegar
Collaborator

gflegar commented Sep 14, 2023

Not sure this is how it should be handled. Generally, you would expect a binary operation on two operands of the same type to produce a result of that same type (or, for a combination of two types, the lower-precision type to be promoted to the higher one, producing a result of the higher type). The change as proposed would make the output type confusing.

Maybe a better option would be to handle the type of c similarly to dot_out_type: keep this default promotion logic as is, but let the user optionally specify another type, in which case the output is stored in that type?
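The scheme described above can be sketched in a few lines (the helper name `matmul_out_dtype` is hypothetical, chosen for illustration; it is not Triton's actual implementation):

```python
import numpy as np

def matmul_out_dtype(a_dtype, b_dtype, out_dtype=None):
    """Pick the output dtype for a matmul (illustrative sketch).

    Default: standard binary-op promotion, where the lower-precision
    operand type is promoted to the higher one. The caller may
    override the result type with an explicit out_dtype.
    """
    if out_dtype is not None:
        return np.dtype(out_dtype)
    return np.promote_types(a_dtype, b_dtype)

print(matmul_out_dtype(np.int8, np.int8))                       # int8
print(matmul_out_dtype(np.int8, np.int16))                      # int16
print(matmul_out_dtype(np.int8, np.int8, out_dtype=np.int32))   # int32
```

Note that the default case reproduces the surprising behavior this PR hit: int8 inputs yield an int8 matmul output unless the user opts into a wider type.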

@lightb0x
Contributor Author

Thanks for the quick and detailed review.
I agree with your point, and I will go with your suggestion via a skip_cast parameter.
I think skip_cast would be better than out_dtype because it will

  • prevent many user-level pitfalls
  • use less computation & memory

@gflegar
Collaborator

gflegar commented Sep 14, 2023

So the idea with skip_cast would be to use dot_out_type for c if the flag is set?

This does limit the number of options we could support. It's quite possible that you want to accumulate in, e.g., float32 for higher precision, but then store in float16 to save on data movement while still avoiding overflow, with your inputs in float8 or something like that.

What are the user-level pitfalls you had in mind?

Re computation & memory pressure: if you're thinking that the additional cast from a type to itself would introduce some overhead, I'm quite sure that Triton can optimize that away.

(Also, just wanted to clarify that I'm not an official reviewer, nor affiliated with the maintainers of this repository. I'm just following it closely and noticed that this might cause unexpected semantics.)

@lightb0x
Contributor Author

lightb0x commented Sep 14, 2023

On user-level pitfalls: I was mistaken there, since a manual override requires a specific datatype anyway.
FP8 <-> INT8 can serve as an example, because converting between those two datatypes loses a lot of information.

On computation & memory overhead: I'm seeing quite a lot of overhead for int <-> fp casting with torch.Tensor.to.
(Maybe that itself is an example of a user-level pitfall that skip_cast would avoid.)
I will try the Triton way of casting.

The tradeoff between the proposed code and (skip_cast or out_dtype) is simplicity (and easier maintenance) vs. an interface change.
I will try both. Let's see how it goes.

Thank you for your opinion!

@ptillet
Collaborator

ptillet commented Sep 17, 2023

I'm gonna merge this. I don't feel super strongly either way, since triton.ops.matmul isn't really part of the language and is just used for testing.

@ptillet ptillet enabled auto-merge (squash) September 17, 2023 16:21
@ptillet ptillet merged commit 2b06600 into triton-lang:main Sep 17, 2023
@github-actions

⚠️ This PR does not produce bitwise identical kernels as the branch it's merged against. Please check the artifacts for details.

alexander-zinoviev pushed a commit to alexander-zinoviev/triton that referenced this pull request Sep 21, 2023
pingzhuu pushed a commit to siliconflow/triton that referenced this pull request Apr 2, 2024
