Add Int4CPULayout and update int4 woq (#1278)
* Add Int4CPULayout and update int4 woq

* Apply automatic Ruff fixes

* Fix CI

* Remove nightly

* Apply automatic Ruff fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
yanbing-j and github-actions[bot] authored Nov 27, 2024
1 parent 543209b commit 719440e
Showing 14 changed files with 448 additions and 93 deletions.
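
For orientation, this change adds an Int4CPULayout so that int4 weight-only quantization (woq) can target CPU. A minimal usage sketch, assuming torch >= 2.6 on CPU (the module and shapes below are illustrative; on CUDA the default TensorCoreTiledLayout path is unchanged):

import torch
from torchao.dtypes import Int4CPULayout
from torchao.quantization import int4_weight_only, quantize_

# Illustrative model; int4 weight-only quantization operates on bfloat16 linear weights.
model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16)

# CPU path introduced by this PR (torch >= 2.6); omitting layout keeps the CUDA default.
quantize_(model, int4_weight_only(group_size=32, layout=Int4CPULayout()))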
1 change: 1 addition & 0 deletions .github/workflows/regression_test.yml
@@ -70,6 +70,7 @@ jobs:
             torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.1"
+
           - name: CPU 2.3
             runs-on: linux.4xlarge
             torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
55 changes: 29 additions & 26 deletions test/dtypes/test_affine_quantized.py
@@ -8,7 +8,7 @@
     run_tests,
 )
 
-from torchao.dtypes import SemiSparseLayout
+from torchao.dtypes import Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
     int4_weight_only,
@@ -17,20 +17,25 @@
     int8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
 
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
-def get_quantization_functions(do_sparse: bool, do_int4: bool):
+def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
     base_functions = [
         int8_weight_only(),
         int8_dynamic_activation_int4_weight(),
         int8_dynamic_activation_int8_weight(),
         int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
-        base_functions.append(int4_weight_only(group_size=32))
+        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+            base_functions.append(
+                int4_weight_only(group_size=32, layout=Int4CPULayout())
+            )
+        else:
+            base_functions.append(int4_weight_only(group_size=32))
 
     if do_sparse:
         base_functions.append(
@@ -152,30 +157,28 @@ class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.bfloat16]
 
-    @common_utils.parametrize("apply_quant", get_quantization_functions(False, True))
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
-    def test_flatten_unflatten(self, apply_quant, device, dtype):
-        if device == "cpu":
-            self.skipTest(f"Temporarily skipping for {device}")
-
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-        ql = apply_quant(linear)
-        lp_tensor = ql.weight
-        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
-        tensor_data_dict = {
-            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
-        }
-        outer_size = lp_tensor.size()
-        outer_stride = lp_tensor.stride()
-        reconstructed = type(lp_tensor).__tensor_unflatten__(
-            tensor_data_dict, tensor_attributes, outer_size, outer_stride
-        )
-        example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
-        ref = ql(*example_inputs)
-        ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
-        reconstruct_res = ql(*example_inputs)
-        self.assertEqual(reconstruct_res, ref)
+    def test_flatten_unflatten(self, device, dtype):
+        apply_quant_list = get_quantization_functions(False, True, device)
+        for apply_quant in apply_quant_list:
+            linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+            ql = apply_quant(linear)
+            lp_tensor = ql.weight
+            tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
+            tensor_data_dict = {
+                name: getattr(lp_tensor, name) for name in tensor_data_name_dict
+            }
+            outer_size = lp_tensor.size()
+            outer_stride = lp_tensor.stride()
+            reconstructed = type(lp_tensor).__tensor_unflatten__(
+                tensor_data_dict, tensor_attributes, outer_size, outer_stride
+            )
+            example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
+            ref = ql(*example_inputs)
+            ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
+            reconstruct_res = ql(*example_inputs)
+            self.assertEqual(reconstruct_res, ref)


common_utils.instantiate_parametrized_tests(TestAffineQuantized)
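
As a usage note, a sketch of how the device-aware helper above is intended to be driven, mirroring the loop in test_flatten_unflatten (shapes taken from the test; the device string is illustrative):

# On "cpu" with torch >= 2.6 the int4 entry carries Int4CPULayout;
# otherwise the default int4_weight_only(group_size=32) is used.
for apply_quant in get_quantization_functions(False, True, device="cpu"):
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cpu")
    quantized_linear = apply_quant(linear)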
18 changes: 14 additions & 4 deletions test/integration/test_integration.py
@@ -19,7 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
-from torchao.dtypes import TensorCoreTiledLayout
+from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -93,6 +93,7 @@
     is_fbcode,
     benchmark_model
 )
+from torchao.dtypes.utils import is_device
 
 logger = logging.getLogger("INFO")
 
@@ -133,7 +134,10 @@ def _int8da_int8w_api(mod):
         change_linear_weights_to_int8_dqtensors(mod)
 
 def _int4wo_api(mod):
-    if TORCH_VERSION_AT_LEAST_2_4:
+    if is_device(next(mod.parameters()).device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+        quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False)
+        unwrap_tensor_subclass(mod)
+    elif TORCH_VERSION_AT_LEAST_2_4:
         quantize_(mod, int4_weight_only(), set_inductor_config=False)
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
@@ -935,10 +939,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
             self.skipTest(f"Temporarily skipping for {device}")
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
+        layout_list = []
+        if device == 'cpu' and TORCH_VERSION_AT_LEAST_2_6:
+            layout_list.append(Int4CPULayout())
+        else:
+            for inner_k_tiles in [4, 2]:
+                layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
-                for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}
+                for layout in layout_list:
+                    kwargs = {"groupsize": groupsize, "layout": layout}
 
                     def api(mod):
                         kwargs_copy = kwargs.copy()
10 changes: 7 additions & 3 deletions test/quantization/test_quant_primitives.py
@@ -33,6 +33,7 @@
     TORCH_VERSION_AT_LEAST_2_6,
     is_fbcode,
 )
+from torchao.dtypes.utils import is_device
 
 _SEED = 1234
 torch.manual_seed(_SEED)
@@ -102,7 +103,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .reshape_as(w)
     )
     if TORCH_VERSION_AT_LEAST_2_5:
-        w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
+        if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+            w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
 
     return w_int4x8
 
@@ -524,8 +526,10 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         groupsize = 128
 
         if TORCH_VERSION_AT_LEAST_2_5:
-            input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
+            input_tmp = input
+            if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+                input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_tmp, scales, zeros, n_bit, groupsize)
         else:
             w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
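
The guarded packing above stores two 4-bit values per uint8 byte: even columns become the high nibble and odd columns the low nibble. On CPU with torch >= 2.6 the helpers now skip this step and keep the values unpacked. A small self-contained sketch of the packing itself (values are made up):

import torch

# One illustrative row of 4-bit values (0..15).
w_int4x8 = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)

# Same expression as in the diff: even columns -> high nibble, odd columns -> low nibble.
packed = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
# packed == tensor([[18, 52]], dtype=torch.uint8), i.e. 0x12 and 0x34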
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -16,6 +16,7 @@
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
     BlockSparseLayout,
+    Int4CPULayout,
     MarlinQQQLayout,
     MarlinSparseLayout,
     SemiSparseLayout,
@@ -48,4 +49,5 @@
     "UintxLayout",
     "MarlinQQQTensor",
     "MarlinQQQLayout",
+    "Int4CPULayout",
 ]
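
With the re-export above, the new layout is importable from the package root, e.g. (a trivial sketch):

from torchao.dtypes import Int4CPULayout

# Int4CPULayout takes no arguments; it is passed as int4_weight_only(layout=Int4CPULayout())
# to select the CPU int4 weight-only path on torch >= 2.6.
cpu_layout = Int4CPULayout()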
2 changes: 2 additions & 0 deletions torchao/dtypes/uintx/__init__.py
@@ -11,6 +11,7 @@
     SemiSparseLayout,
 )
 from .tensor_core_tiled_layout import (
+    Int4CPULayout,
     TensorCoreTiledLayout,
 )
 from .uintx_layout import (
@@ -23,5 +24,6 @@
     "MarlinSparseLayout",
     "SemiSparseLayout",
     "TensorCoreTiledLayout",
+    "Int4CPULayout",
     "MarlinQQQLayout",
 ]