Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable the CPU int4 with HQQ quant #1824

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,15 @@ def _int8da_int8w_api(
change_linear_weights_to_int8_dqtensors(mod)


def _int4wo_api(mod):
def _int4wo_api(mod, use_hqq=False):
if (
is_device(next(mod.parameters()).device.type, "cpu")
and TORCH_VERSION_AT_LEAST_2_6
):
quantize_(
mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False
mod,
int4_weight_only(layout=Int4CPULayout(), use_hqq=use_hqq),
set_inductor_config=False,
)
unwrap_tensor_subclass(mod)
elif TORCH_VERSION_AT_LEAST_2_4:
Expand Down Expand Up @@ -1042,8 +1044,6 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_api(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
for test_shape in [(16, 1024, 16)] + (
Expand All @@ -1053,6 +1053,20 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
_int4wo_api, device, 15, test_shape=test_shape, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "int4 hqq requires torch nightly.")
def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
for test_shape in [(16, 1024, 16), (1, 1024, 256)]:
api = partial(
_int4wo_api,
use_hqq=True,
)
self._test_lin_weight_subclass_api_impl(
api, device, 15, test_shape=test_shape, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests needs torch 2.5 or greater"
Expand Down Expand Up @@ -1103,8 +1117,6 @@ def test_gemlite_layout(self, device, dtype):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
layout_list = []
Expand Down
10 changes: 8 additions & 2 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,8 @@ def reset_memory():
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
@common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@common_utils.parametrize("x_dim", [2, 3])
def test_int4wo_cpu(self, dtype, x_dim):
@common_utils.parametrize("use_hqq", [True, False])
def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
from torchao.dtypes import Int4CPULayout

device = "cpu"
Expand All @@ -791,7 +792,12 @@ def test_int4wo_cpu(self, dtype, x_dim):
example_inputs = (example_inputs[0].unsqueeze(0),)

with torch.no_grad():
quantize_(m, int4_weight_only(group_size=32, layout=Int4CPULayout()))
quantize_(
m,
int4_weight_only(
group_size=32, layout=Int4CPULayout(), use_hqq=use_hqq
),
)
# ensure the expected op is in the code
_, code = torch._inductor.utils.run_and_get_code(
torch.compile(m, fullgraph=True, dynamic=True),
Expand Down
3 changes: 2 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def from_hp_to_intx(
else input_float.dtype
)
device = input_float.device
from torchao.dtypes import Int4CPULayout
from torchao.dtypes.uintx import TensorCoreTiledLayout

data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(
Expand All @@ -235,7 +236,7 @@ def from_hp_to_intx(
device=device,
verbose=False,
raw_output=not isinstance(
_layout, (TensorCoreTiledLayout, PlainLayout)
_layout, (TensorCoreTiledLayout, PlainLayout, Int4CPULayout)
),
# raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint)
# note in choose_qparams_affine, preserve_zero = False does this same thing while also controlling whether
Expand Down
Loading