Improve primitives for FP6 quant #248
Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/248. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 78e79ac with merge base a7bc592. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchao/dtypes/fp6.py (Outdated)

return torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2)


def to_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor:
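For context, the line above packs four 6-bit codes into three bytes. A minimal sketch of one such layout, assuming the codes sit in the low bits of a uint8 tensor (the exact bit order is my assumption, not necessarily the one this PR uses):

import torch

def pack_float6(codes: torch.Tensor) -> torch.Tensor:
    # codes: uint8 tensor, last dim a multiple of 4, each element a 6-bit value in 0..63
    v = codes.unflatten(-1, (-1, 4))
    bits0 = (v[..., 0] << 2) | (v[..., 1] >> 4)          # aaaaaabb
    bits1 = ((v[..., 1] & 0xF) << 4) | (v[..., 2] >> 2)  # bbbbcccc
    bits2 = ((v[..., 2] & 0x3) << 6) | v[..., 3]         # ccdddddd
    # same final step as the reviewed line: 4 codes -> 3 packed bytes
    return torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2)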
nit: thoughts about naming your dtype float6_e3m2 instead of fp6? This would be consistent with the naming of other PyTorch low precision dtypes such as float8_e4m3|e5m2 from PyTorch core, as well as the upcoming MX dtypes, which include float6_e3m2 and float6_e2m3.
I was thinking the same thing too! Will update the name.
Where can I read more about MX dtypes? This particular FP6 used by the FP6-LLM paper does not represent +/-inf and NaN, so not sure if we should signal that in the name somehow too? (like float8_e4m3fn)
You can check out https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf; page 12 describes the supported float6 flavors. I plan to add the MX code to torchao soon.
For the fn suffix... I'm planning to follow the OCP spec naming, which does not include naming qualifiers for special value handling, and to replace fp with float to be consistent with other PyTorch dtype names. I think the fn suffix made sense for float8, where different flavors had different special value handling, but none of these sub-8-bit dtypes support special values.
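To make the "no special values" point concrete, here is a small illustration (my own sketch, assuming the usual e3m2 layout with exponent bias 3) that decodes all 64 possible codes; every bit pattern maps to a finite number, topping out at 28.0 with a smallest positive subnormal of 0.0625:

def decode_float6_e3m2(code: int, bias: int = 3) -> float:
    # 1 sign bit, 3 exponent bits, 2 mantissa bits
    sign = -1.0 if (code >> 5) & 0x1 else 1.0
    exp = (code >> 2) & 0x7
    man = code & 0x3
    if exp == 0:  # subnormal: no implicit leading one
        return sign * (man / 4) * 2.0 ** (1 - bias)
    return sign * (1 + man / 4) * 2.0 ** (exp - bias)

values = [decode_float6_e3m2(c) for c in range(64)]
assert max(values) == 28.0  # all 64 codes are finite; no patterns are reserved for inf/NaN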
Cool! It seems like the FP6 I used here is exactly the same as MX FP6 E3M2, just without the scale (the FP6-LLM authors use 1 scale per row). Perhaps in the future the MX dtype can replace this.
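As a rough sketch of the 1-scale-per-row scheme mentioned here (to_float6_e3m2 and its dtypes import path are taken from this thread, but its exact signature is an assumption):

import torch
from torchao.dtypes import to_float6_e3m2  # import path assumed from this PR

def quantize_rowwise(weight: torch.Tensor, max_e3m2: float = 28.0):
    # one scale per row, so the largest entry of each row maps to the largest e3m2 value
    scale = weight.abs().amax(dim=-1, keepdim=True) / max_e3m2
    fp6_codes = to_float6_e3m2(weight / scale)  # packed 6-bit codes
    return fp6_codes, scale.half()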
@@ -14,6 +8,13 @@
from . import _C
Need to import _C first since to/from_float6_e3m2() (from dtypes) calls the C++ extension for CPU.
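In other words, the package init needs to load the extension before the dtype helpers, roughly like this (a sketch of the ordering only, assuming this hunk sits in torchao/__init__.py):

from . import _C      # load the compiled extension first,
from . import dtypes  # so to/from_float6_e3m2() can find its CPU kernels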
@@ -120,49 +119,14 @@ void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit,
}
}

void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) {
Replaced with from_float6_e3m2()
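Presumably the Python-side replacement looks something like this (from_float6_e3m2 is named in this thread, but the signature, import path, and scale handling below are assumptions):

import torch
from torchao.dtypes import from_float6_e3m2  # import path assumed

def dequantize_rowwise(fp6_codes: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # hypothetical dequant path: unpack the FP6 codes, then re-apply the per-row scale
    return from_float6_e3m2(fp6_codes).to(scale.dtype) * scale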
Addresses #208

TODO:

Benchmarked on an (8192, 8192) tensor, Ryzen 5600 and 4070 Ti SUPER.

NOTE: torchao.ops.fp16_to_fp6_original() comes from the original FP6-LLM repo plus qtorch quantization logic. It does not support CUDA.

Benchmarked on (8192, 8192) FP6 input, Ryzen 5600 and 4070 Ti SUPER.

NOTE: fp6_weight_dequant() (the original implementation) is slow, probably because the author uses the CUDA intrinsics __float2half() and __half2float() on the CPU, where they have to be implemented via bit manipulation.
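To see why that is costly, here is a sketch of the per-element bit manipulation a software __half2float ends up doing when the CUDA intrinsic is not available on CPU (my own illustration of standard binary16 decoding, not the code in this repo):

def half_bits_to_float(h: int) -> float:
    # decode one IEEE 754 binary16 value: 1 sign, 5 exponent, 10 mantissa bits
    sign = -1.0 if (h >> 15) & 0x1 else 1.0
    exp = (h >> 10) & 0x1F
    man = h & 0x3FF
    if exp == 0:  # subnormal
        return sign * (man / 1024) * 2.0 ** -14
    if exp == 0x1F:  # inf / NaN
        return sign * float("inf") if man == 0 else float("nan")
    return sign * (1 + man / 1024) * 2.0 ** (exp - 15)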