
fp8dq requires both dimensions to be divisible by 16 #1268

Closed · piotr-bazan-nv opened this issue Nov 12, 2024 · 7 comments

piotr-bazan-nv commented Nov 12, 2024

When trying to quantize a model, the following exception is raised:

TorchRuntimeError: Failed running call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(2, 32)), LinearActivationQuantizedTensor(AffineQuantizedTensor(layout_tensor=Float8AQTLayout(
float8_data=FakeTensor(..., device='cuda:0', size=(15, 32), dtype=torch.float8_e4m3fn),
scale=FakeTensor(..., device='cuda:0', size=()),
transposed=False, layout_type=Float8LayoutType(mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False))), block_size=torch.Size([15, 32]), shape=torch.Size([15, 32]), device=cuda:0, dtype=torch.float32, requires_grad=False), functools.partial(<function _input_activation_quant_func_fp8 at 0x7a94b4f4d120>, activation_granularity=PerTensor(), activation_dtype=torch.float8_e4m3fn)), Parameter(FakeTensor(..., device='cuda:0', size=(15,), requires_grad=True))), **{}):
Expected both dimensions of mat2 to be divisble by 16 but got torch.Size([32, 15])

Minimal code to reproduce the issue:

import torch
from torchao.quantization import (
    float8_dynamic_activation_float8_weight,
    quantize_,
)
dim1 = 32
dim2 = 15

class ToyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(dim1, dim2)

    def forward(self, x):
        return self.model(x)

model = ToyModel().to("cuda").eval()

quantize_(model, float8_dynamic_activation_float8_weight())
model = torch.compile(model=model, fullgraph=True, mode="max-autotune")
model(torch.randn(2, 32).to('cuda'))

Is this by design, or is it a bug? Currently this prevents many models from being quantized.


HDCharles commented Nov 12, 2024

Hey, yes, that's a requirement of scaled_mm in general though

"Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41)."

you can use something like this (from ao/test/float8/test_base.py, lines 776 to 782 at commit 4120526):

def module_filter_fn(mod, fqn):
    return (
        mod.in_features >= size_limit
        and mod.out_features >= size_limit
        and mod.in_features % 16 == 0
        and mod.out_features % 16 == 0
    )

passed as the filter_fn argument to quantize_
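
Concretely, wiring that into the repro above might look like this sketch; size_limit is just an illustrative threshold, and filter_fn is the quantize_ argument that receives each module and its fully qualified name:

size_limit = 16  # illustrative threshold, not a torchao constant

quantize_(
    model,
    float8_dynamic_activation_float8_weight(),
    filter_fn=module_filter_fn,  # only quantize layers _scaled_mm can handle
)

With dim2 = 15, the filter rejects the toy model's layer, so it stays in high precision instead of erroring.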

we're working on other kernels that are more flexible

jerryzh168 commented

We should probably add this check to apply_float8_dynamic_activation_quant, if it applies to the float8 quant method itself.
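
For illustration, the check being suggested might look something like the sketch below; the helper name _is_scaled_mm_compatible is hypothetical, not existing torchao API:

import torch

def _is_scaled_mm_compatible(weight: torch.Tensor) -> bool:
    # Hypothetical guard: torch._scaled_mm requires both weight
    # dimensions to be multiples of 16.
    out_features, in_features = weight.shape
    return out_features % 16 == 0 and in_features % 16 == 0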


drisspg commented Nov 19, 2024

import torch
from torchao.quantization import (
    float8_dynamic_activation_float8_weight,
    quantize_,
)
import logging

logging.getLogger("torchao").setLevel(logging.INFO)

logging.basicConfig(level=logging.INFO)
dim1 = 32
dim2 = 15

class ToyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(dim1, dim2)

    def forward(self, x):
        return self.model(x)

model = ToyModel().to("cuda").eval()

quantize_(model, float8_dynamic_activation_float8_weight())
model = torch.compile(model=model, fullgraph=True, mode="max-autotune")
model(torch.randn(2, 32).to('cuda'))

we do now properly log and skip:

INFO:torchao.quantization.quant_api:Skipping float8 quantization: weight shape torch.Size([15, 32]) is not compatible with _scaled_mm. Both input dimension (32) and output dimension (15) must be multiples of 16. 
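
As a quick sanity check (run right after quantize_, before torch.compile, assuming the ToyModel from the repro): printing the weight shows whether the layer was skipped, since a quantized weight prints as a LinearActivationQuantizedTensor while a skipped one prints as a plain tensor.

# Inspect the weight to see whether quantization was applied or skipped.
print(model.model.weight)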

jerryzh168 commented

Maybe it's an issue with the torchao version. @piotr-bazan-nv, what torchao version are you using?

piotr-bazan-nv commented

@jerryzh168 It's 0.6.1

jerryzh168 commented

#1194 was added after that release, I think; you should be able to get the change in the nightly build or in 0.7.
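
A quick way to confirm which version you are running; the fix from #1194 needs a nightly build or 0.7+:

import torchao

print(torchao.__version__)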

piotr-bazan-nv commented

Thanks @jerryzh168. Closing the issue then.
