Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
08e9095
Summary:
namgyu-youn Sep 21, 2025
db23cf3
rename for clearly: Int8PlainInt8Tensor -> Int8Tensor
namgyu-youn Sep 22, 2025
b861dbc
add flags for static/dynamic quant
namgyu-youn Sep 23, 2025
9383550
update static/dynamic quantization workflows
namgyu-youn Sep 24, 2025
2c84ba4
add kernel preference unit test
namgyu-youn Sep 24, 2025
8ddddd3
add kernel preference unit test
namgyu-youn Sep 24, 2025
bd6f58a
Merge remote-tracking branch 'upstream/main' into int8-quant
namgyu-youn Sep 24, 2025
b5cb3c8
fix missing attribute
namgyu-youn Sep 24, 2025
9a51cae
remove kernel preference args
namgyu-youn Sep 28, 2025
c53dad0
link new API with old API using version 2
namgyu-youn Sep 28, 2025
d300b02
add granularity, block size support
namgyu-youn Sep 30, 2025
c43a3ec
Merge branch 'main' into int8-quant
namgyu-youn Oct 4, 2025
590e0b7
add transpose, index selector workflows
namgyu-youn Oct 4, 2025
b3d4f3e
remove external zero point
namgyu-youn Oct 4, 2025
df79aa8
update int8 quantization API
namgyu-youn Oct 7, 2025
910906b
Merge remote-tracking branch 'upstream/main' into int8-quant
namgyu-youn Oct 7, 2025
c61b36e
add static quantization support
namgyu-youn Oct 14, 2025
0a45f90
sync with main branch
namgyu-youn Oct 16, 2025
1251187
split dispatch decorator
namgyu-youn Oct 16, 2025
844d99d
update int8-quant api
namgyu-youn Oct 17, 2025
a844678
update type-hint to prevent depenedency issue
namgyu-youn Oct 17, 2025
2c0389a
fix ci error
namgyu-youn Oct 20, 2025
bafeb43
revert unrelated changes
namgyu-youn Oct 23, 2025
7006cae
fix rebase
namgyu-youn Oct 23, 2025
49a7a89
update int8 quant api
namgyu-youn Oct 23, 2025
062f3cc
update int8
namgyu-youn Oct 24, 2025
680cec9
build setup for unit test, enable per-row/per-tensor granuarity
namgyu-youn Oct 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/quantization_overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ First we want to lay out the torchao stack::

Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc.
---------------------------------------------------------------------------------------------
Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor
Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Int8Tensor, Float8Tensor
---------------------------------------------------------------------------------------------
Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
---------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -88,6 +88,8 @@ So in general we structure Tensor subclasses by dervied dtpype and packing forma
- scaled int4
- preshuffled (special format to optimize for loading)
- float8 act + int4 weight dynamic quantization and int4 weight only quantization
* - Int8Tensor
- plain

.. note::
We don't have granularity specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor, all granularities are implemented in the same Tensor, we typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options.
Expand Down
73 changes: 73 additions & 0 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import run_tests

from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
Int8PlainInt8Tensor,
)
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt8PlainInt8Tensor(TorchAOIntegrationTestCase):
def setUp(self):
super().setUp()
torch.manual_seed(42)
self.weight_fp = torch.randn(4, 3, dtype=torch.float32)
self.input_fp = torch.randn(2, 3, dtype=torch.float32)
self.bias = torch.randn(4)
self.block_size = [4, 3]

def test_creation_and_attributes(self):
"""Test tensor creation, dtypes, and ranges"""
tensor = Int8PlainInt8Tensor.from_hp(self.weight_fp, self.block_size)

self.assertEqual(tensor.shape, (4, 3))
self.assertEqual(tensor.qdata.dtype, torch.int8)
self.assertTrue(
torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127)
)

def test_linear_operations(self):
"""Test fp+int8 and int8+int8 linear ops with quantization error check"""
weight_q8 = Int8PlainInt8Tensor.from_hp(self.weight_fp, self.block_size)
input_q8 = Int8PlainInt8Tensor.from_hp(self.input_fp, self.block_size)

reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias)
result_fp = torch.nn.functional.linear(self.input_fp, weight_q8, self.bias)
result_q8 = torch.nn.functional.linear(input_q8, weight_q8, self.bias)

self.assertEqual(result_fp.shape, reference.shape)
self.assertEqual(result_q8.shape, reference.shape)
self.assertTrue(compute_error(result_fp, reference) > 10)
self.assertTrue(compute_error(result_q8, reference) > 10)

def test_error_handling_and_dequant(self):
"""Test input validation and dequantization accuracy"""
# Test 1D tensor validation
with self.assertRaises((AssertionError, ValueError, RuntimeError)):
Int8PlainInt8Tensor.from_hp(torch.randn(5), [1])

# Test wrong block_size validation
with self.assertRaises((AssertionError, ValueError, RuntimeError)):
Int8PlainInt8Tensor.from_hp(self.weight_fp, [1])

# Test dequantization with exact values
test_data = torch.tensor([[1.0, -1.0]], dtype=torch.float32)
tensor = Int8PlainInt8Tensor.from_hp(test_data, [1, 1])

dequantized = tensor.dequantize()
self.assertEqual(dequantized.shape, test_data.shape)
self.assertLess(torch.abs(dequantized - test_data).max().item(), 0.1)


if __name__ == "__main__":
run_tests()
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
Int4PreshuffledTensor,
Int4Tensor,
Int4TilePackedTo4dTensor,
Int8PlainInt8Tensor,
IntxOpaqueTensor,
IntxUnpackedToInt8Tensor,
)
Expand Down Expand Up @@ -168,6 +169,7 @@
"IntxOpaqueTensor",
"IntxUnpackedToInt8Tensor",
"Int4TilePackedTo4dTensor",
"Int8PlainInt8Tensor",
"Float8Tensor",
"Int4OpaqueTensor",
# smooth quant - subject to change
Expand Down
3 changes: 3 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from int8.int8_tensor import Int8PlainInt8Tensor

from .float8.float8_tensor import (
Float8Tensor,
QuantizeTensorToFloat8Kwargs,
Expand Down Expand Up @@ -36,6 +38,7 @@
"Int4MarlinSparseTensor",
"Int4PlainInt32Tensor",
"Int4TilePackedTo4dTensor",
"Int8PlainInt8Tensor",
"Float8Tensor",
"QuantizeTensorToFloat8Kwargs",
"Int4OpaqueTensor",
Expand Down
106 changes: 106 additions & 0 deletions torchao/quantization/quantize_/workflows/int8/int8_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


import torch

from torchao.utils import TorchAOBaseTensor

__all__ = ["Int8PlainInt8Tensor"]

aten = torch.ops.aten


# TODO: Implement block-wise quantization using block_size
class Int8PlainInt8Tensor(TorchAOBaseTensor):
"""
int8 quantized tensor with plain layout

Tensor Attributes:
qdata: (N, K) int8 quantized weight data
scale: scale factors for dequantization
zero_point: zero points for dequantization

Non-Tensor Attributes:
block_size: block size for quantization granularity
shape: original tensor shape
"""

tensor_data_names = ["qdata", "scale", "zero_point"]
tensor_attribute_names = ["block_size"]

def __new__(cls, qdata, scale, zero_point, block_size, shape):
kwargs = {"device": qdata.device, "dtype": scale.dtype, "requires_grad": False}
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

def __init__(self, qdata, scale, zero_point, block_size, shape):
self.qdata = qdata
self.scale = scale
self.zero_point = zero_point
self.block_size = block_size

@classmethod
def from_hp(cls, w: torch.Tensor, block_size: list[int]):
if w.dim() != 2 or len(block_size) != 2:
raise ValueError("Expected 2D tensor and block_size length 2")

# Rounding function from high precision dtype
scale = w.abs().max(dim=-1, keepdim=True)[0] / 127.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like block_size is not used? why is that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can checkout

def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias):
for expected granularity

also this should be using these quant primitive ops:

scale, zero_point = choose_qparams_affine(
input=preprocessed_w,
mapping_type=MappingType.SYMMETRIC,
block_size=block_size,
target_dtype=target_dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=1e-6,
)
wq = quantize_affine(
input=preprocessed_w,
block_size=block_size,
scale=scale,
zero_point=zero_point,
output_dtype=target_dtype,
quant_min=quant_min,
quant_max=quant_max,
)
, arguments can be found by tracing through the code path for int8 in
new_weight = to_affine_quantized_intx(
and
scale, zero_point = choose_qparams_affine(

this might require a bit too much context, let me know if you would like us to take over

Copy link
Contributor Author

@namgyu-youn namgyu-youn Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, surely want to take over! Drafted this PR for those updates, but will look into it today (6 hours later)

btw, version 2 is updated at c53dad0 (version 1 is default)

scale = scale.clamp(min=1e-6)

int_data = torch.round(w / scale).clamp(-128, 127).to(torch.int8)

return cls(
int_data,
scale.squeeze(-1),
torch.zeros_like(scale.squeeze(-1), dtype=torch.int8),
block_size,
w.shape,
)


implements = Int8PlainInt8Tensor.implements


@implements([aten.dequantize.self])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed? if not we should remove for now

def _(func, types, args, kwargs):
"""dequantization: int8 -> float"""
tensor = args[0]
return (
tensor.qdata.to(tensor.scale.dtype)
- tensor.zero_point.to(tensor.scale.dtype).unsqueeze(1)
) * tensor.scale.unsqueeze(1)


@implements([torch.nn.functional.linear, aten.linear.default])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def _(func, types, args, kwargs):
"""quantization: float -> int8"""
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)

if isinstance(input_tensor, Int8PlainInt8Tensor):
# INT8 × INT8
x_int32 = input_tensor.qdata.to(torch.int32)
w_int32 = weight_tensor.qdata.to(torch.int32).t()

result = torch.mm(x_int32.view(-1, x_int32.size(-1)), w_int32)
scale = input_tensor.scale.view(-1, 1) * weight_tensor.scale.unsqueeze(0)
result = result.to(scale.dtype) * scale
result = result.view(*input_tensor.shape[:-1], -1)
else:
# FP × INT8
result = torch.nn.functional.linear(
input_tensor, weight_tensor.dequantize(), None
)

return result + bias if bias is not None else result


Int8PlainInt8Tensor.__module__ = "torchao.quantization"
torch.serialization.add_safe_globals([Int8PlainInt8Tensor])