-
Notifications
You must be signed in to change notification settings - Fork 191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable float8 attention support (q/k/v) #1382
base: main
Are you sure you want to change the base?
Conversation
Summary: att, right now we need to manually add quantize call for q/k/v before sdpa op, but we can explore other APIs in the future Test Plan: TBD Reviewers: Subscribers: Tasks: Tags:
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1382
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 2895626 with merge base 04d611a (): NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
q_float8_data = q_tensor_impl.float8_data | ||
# change from scalar to tensor of size [1] | ||
q_scale = q_tensor_impl.scale | ||
q_scale = torch.tensor([q_scale], device=q_scale.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are the scales on host?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you mean q_scale
before we call torch.tensor
? they should be using the same device as original weight I think, so should be on cuda before
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good, left a few comments. I think we can add the int8 kernel when it gets added as well for CPU
from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant | ||
original_dtype = v.dtype | ||
if q.shape[-1] in [64, 128, 256]: | ||
q = _float8_symmetric_per_tensor_quant(q) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We also likely need/want to apply the hadamard transform. I don't remember off hand if this is include in fav3 api
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
didn't see this, maybe we can add it after spinquant is integrated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We also likely need/want to apply the hadamard transform. I don't remember off hand if this is include in fav3 api
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we update the current SAM2 readme with all the ao optimizations we have introduced?
* Update multimodal.md Complete markup for testing * Update run-docs Add ability to run on docs/multimodal.md * Update run-readme-pr.yml
Summary:
This PR integrates flashattention 3 kernel: https://github.com/Dao-AILab/flash-attention/blob/1feb711f46563960fc10a8e659c93c300619504b/flash_attn/flash_attn_interface.py#L1102 to float8 affine quantized tensor.
To use the kernel, right now we need to manually add quantize call for q/k/v before sdpa op, but we can explore other APIs in the future
@sijiac is working on adding new variations of attention implementation in the future (per row, per column, per block scaling etc.).
Test Plan:
python test/dtypes/test_affine_quantized_float.py -k test_float8_attention
SAM2
tested on sam2 and seems to be a bit slower than before, this is reasonable because sam2 is using 16 and 32 head dimension, but fa3 requires 64 being the minimum size, we need to do some padding to make this work (pad 32 to 64) which is expected to increase runtime significantly.
llama2
llama2 without fallback: doesn't work because
attn_mask
is not supported.llama2 numerics only
(just for testing, code is not checked in) tested on llama2 (with fallback to test numerics):
since
attn_mask
is not supported in flashattention 3 kernel, it's using the fallback path: https://github.com/pytorch/ao/pull/1382/files#diff-3019e8f38b0919dbaba5aa1329a697e89fc98749e35a7bdc274c71a0d3738ec2R285Reviewers:
Subscribers:
Tasks:
Tags: