
Fix bug: torch.uint1-7 not supported in torch<2.6 #10368

Status: Closed (wants to merge 1 commit)
40 changes: 23 additions & 17 deletions in src/diffusers/quantizers/torchao/torchao_quantizer.py

@@ -33,23 +33,29 @@

if is_torch_available():
    import torch
    import torch.nn as nn

    SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
        # At the moment, only int8 is supported for integer quantization dtypes.
        # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
        # to support more quantization methods, such as intx_weight_only.
        torch.int8,
        torch.float8_e4m3fn,
        torch.float8_e5m2,
        torch.uint1,
        torch.uint2,
        torch.uint3,
        torch.uint4,
        torch.uint5,
        torch.uint6,
        torch.uint7,
    )

    if version.parse(torch.__version__) >= version.parse('2.6'):
Collaborator:
oh thanks so much!
can we use is_torch_version?

def is_torch_version(operation: str, version: str):
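The helper the collaborator points to compares the installed torch version against a requirement string. A standalone sketch of that comparison logic, not the actual diffusers implementation (`compare_versions` is a hypothetical name; the real helper reads `torch.__version__` itself):

```python
import operator

# Map comparison strings to operator functions, as a version-gating helper might.
OPS = {"<": operator.lt, "<=": operator.le, "==": operator.eq,
       "!=": operator.ne, ">": operator.gt, ">=": operator.ge}

def compare_versions(installed: str, operation: str, required: str) -> bool:
    """Compare two dotted version strings, e.g. '2.5.1+cu121' vs '2.6'."""
    def parse(v):
        # Keep only the leading numeric components: "2.5.1+cu121" -> (2, 5, 1).
        parts = []
        for p in v.split("+")[0].split("."):
            if not p.isdigit():
                break
            parts.append(int(p))
        return tuple(parts)

    a, b = parse(installed), parse(required)
    n = max(len(a), len(b))
    # Zero-pad so "2.6" compares like "2.6.0".
    return OPS[operation](a + (0,) * (n - len(a)), b + (0,) * (n - len(b)))
```

With this shape, the diff's `version.parse(torch.__version__) >= version.parse('2.6')` would read as a single call such as `compare_versions(torch.__version__, ">=", "2.6")`.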

Member:
Thank you! Apologies that I missed this 😅

I think this should be something like:

if version <= 2.5:
    # do not use any of the uintx
else:
    # use the uintx types as well

I believe the uintx data types were added in 2.5 and not 2.6

Member:
Also maybe good to have another branch that adds float8_e4m3fn and float8_e5m2 only if torch version is >= 2.2. int8 came in 2.0 I believe but since minimum diffusers requirement is torch>=1.4, another branch for it might be needed. LMK if I can help with any changes
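The layered branches the reviewer suggests could be sketched as follows. The version cutoffs (int8 in 2.0, float8 in 2.2, uintx in 2.5) are the reviewer's recollections, not verified facts, and the helper works on dtype names only, so this is illustrative rather than diffusers code:

```python
def supported_dtype_names(torch_version: str) -> tuple:
    """Sketch: choose supported dtype names by torch version (cutoffs per the review)."""
    major_minor = tuple(int(p) for p in torch_version.split(".")[:2])
    names = []
    if major_minor >= (2, 0):   # int8 (reviewer: "came in 2.0 I believe")
        names.append("int8")
    if major_minor >= (2, 2):   # float8 dtypes (reviewer's suggested cutoff)
        names += ["float8_e4m3fn", "float8_e5m2"]
    if major_minor >= (2, 5):   # uintx dtypes (reviewer: added in 2.5, not 2.6)
        names += [f"uint{i}" for i in range(1, 8)]
    return tuple(names)
```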

        SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
            # At the moment, only int8 is supported for integer quantization dtypes.
            # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
            # to support more quantization methods, such as intx_weight_only.
            torch.int8,
            torch.float8_e4m3fn,
            torch.float8_e5m2,
            torch.uint1,
            torch.uint2,
            torch.uint3,
            torch.uint4,
            torch.uint5,
            torch.uint6,
            torch.uint7,
        )
    else:
        SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
            torch.int8,
            torch.float8_e4m3fn,
            torch.float8_e5m2,
        )
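An alternative that avoids hard-coding version cutoffs entirely is to probe the module for each dtype attribute and keep only the ones that exist. This is not what the PR does, just a sketch of the pattern; the `collect_available_dtypes` helper and the stand-in namespace are hypothetical:

```python
from types import SimpleNamespace

def collect_available_dtypes(mod, names):
    """Keep only the dtype attributes the installed module actually exposes."""
    return tuple(d for name in names if (d := getattr(mod, name, None)) is not None)

# Stand-in for an older torch build that lacks the uintx dtypes (hypothetical).
fake_torch_old = SimpleNamespace(int8="int8", float8_e4m3fn="f8e4m3", float8_e5m2="f8e5m2")

NAMES = ("int8", "float8_e4m3fn", "float8_e5m2") + tuple(f"uint{i}" for i in range(1, 8))
```

The trade-off is that attribute probing silently accepts any torch that exposes the names, whereas explicit version branches document exactly which releases are supported.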

if is_torchao_available():
    from torchao.quantization import quantize_