Skip to content

Commit

Permalink
Add generic fake quantized embedding for QAT
Browse files Browse the repository at this point in the history
Summary: This is equivalent to #1020
but for nn.Embedding. This commit adds a generic fake quantized
embedding module to replace the uses of the existing more specific
QAT embeddings. For example, `Int4WeightOnlyQATEmbedding` can be
expressed as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.embedding import FakeQuantizedEmbedding

weight_config = FakeQuantizeConfig(
    dtype=torch.int4,
    group_size=group_size,
    is_symmetric=True,
)
fq_embedding = FakeQuantizedEmbedding(16, 32, weight_config=weight_config)
```

Test Plan:
python test/quantization/test_qat.py -k test_qat_4w_embedding
python test/quantization/test_qat.py -k test_fake_quantized_embedding_4w
  • Loading branch information
andrewor14 committed Oct 16, 2024
1 parent 48bc81c commit 997e2ce
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 53 deletions.
37 changes: 37 additions & 0 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from torchao.quantization.prototype.qat.fake_quantizer import (
FakeQuantizer,
)
from torchao.quantization.prototype.qat.embedding import (
FakeQuantizedEmbedding,
)
from torchao.quantization.prototype.qat.linear import (
FakeQuantizedLinear,
)
Expand Down Expand Up @@ -852,6 +855,40 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
baseline_out = linear_forward_4w(x2, fq_linear.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantized_embedding_4w(self):
"""
Test that we can express int4 per group symmetric weight only fake quantization
with `FakeQuantizedEmbedding`.
"""
num_embeddings = 64
embedding_dim = 128
group_size = 32
torch.manual_seed(self.SEED)
fq_embedding = FakeQuantizedEmbedding(
num_embeddings,
embedding_dim,
weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size),
)

def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Baseline for int4 per group symmetric weight only fake quantization.
"""
(s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
zp = zp.to(torch.int32)
(qmin, qmax) = _get_qmin_qmax(4)
w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size)
return F.embedding(x, w_fq)

# Compare embedding values
torch.manual_seed(self.SEED)
x = torch.randint(num_embeddings, (5, 10))
x2 = copy.deepcopy(x)
fq_out = fq_embedding(x)
baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
169 changes: 116 additions & 53 deletions torchao/quantization/prototype/qat/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,73 @@
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import TorchAODType
from .api import FakeQuantizeConfig
from .fake_quantizer import FakeQuantizer
from .utils import (
_fake_quantize_per_channel_group,
_get_qmin_qmax,
)


class FakeQuantizedEmbedding(torch.nn.Embedding):
"""
General embedding layer with fake quantized weights.
Specific target dtypes, granularity, schemes etc. are specified
through separate configs for weights and activations.
Example usage::
weight_config = FakeQuantizeConfig(
dtype=torch.int4,
group_size=8,
symmetric=True,
)
fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config)
fq_embedding(torch.LongTensor([3]))
"""

def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
weight_config: Optional[FakeQuantizeConfig] = None,
*args,
**kwargs,
) -> None:
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
*args,
**kwargs,
)
if weight_config is not None:
self.weight_fake_quantizer = FakeQuantizer(weight_config)
else:
self.weight_fake_quantizer = None

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.weight_fake_quantizer is not None:
w = self.weight_fake_quantizer(self.weight)
else:
w = self.weight
return F.embedding(
x, w, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse,
)


# ======================================
# | Embedding int4 weight-only QAT |
# ======================================
Expand All @@ -40,7 +101,7 @@ def __init__(
self.bit_width = 4
self.group_size: int = group_size
self.scale_precision: torch.dtype = scale_precision
self.zero_point_precision: torch.dtype = zero_point_precision,
self.zero_point_precision: torch.dtype = zero_point_precision

def prepare(
self,
Expand All @@ -56,16 +117,18 @@ def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_embedding = Int4WeightOnlyQATEmbedding(
group_size=self.group_size,

# other nn.Embedding args
# nn.Embedding args
num_embeddings=child.num_embeddings,
embedding_dim=child.embedding_dim,
padding_idx=child.padding_idx,
max_norm=child.max_norm,
norm_type=child.norm_type,
scale_grad_by_freq=child.scale_grad_by_freq,
sparse=child.sparse,
# quantization args
group_size=self.group_size,
scale_precision=self.scale_precision,
zero_point_precision=self.zero_point_precision,
device=child.weight.device,
)
# In distributed training, the model may be instantiated
Expand Down Expand Up @@ -98,28 +161,31 @@ def _convert_helper(self, module: torch.nn.Module):
from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
for name, child in module.named_children():
if isinstance(child, Int4WeightOnlyQATEmbedding):
group_size = child.weight_fake_quantizer.config.group_size
scale_precision = child.weight_fake_quantizer.config.scale_precision
zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision
quantized_embedding = Int4WeightOnlyEmbedding(
group_size=child.group_size,
scale_precision=child.scale_precision,
zero_point_precision=child.zero_point_precision,

# other nn.Embedding args
# nn.Embedding args
num_embeddings=child.num_embeddings,
embedding_dim=child.embedding_dim,
padding_idx=child.padding_idx,
max_norm=child.max_norm,
norm_type=child.norm_type,
scale_grad_by_freq=child.scale_grad_by_freq,
sparse=child.sparse,
# quantization args
group_size=group_size,
scale_precision=scale_precision,
zero_point_precision=zero_point_precision,
device=child.weight.device,
)
setattr(module, name, quantized_embedding)

# Load weights and qparams into quantized embedding
(qmin, qmax) = _get_qmin_qmax(self.bit_width)
(s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, child.group_size)
(s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, group_size)
q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
child.weight, s, zp, qmin, qmax, torch.int8, child.group_size,
child.weight, s, zp, qmin, qmax, torch.int8, group_size,
)
quantized_embedding.weight = q_weight
quantized_embedding.scales = s
Expand All @@ -128,7 +194,7 @@ def _convert_helper(self, module: torch.nn.Module):
self._convert_helper(child)


class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
"""
This module implements a embedding layer with int4 fake quantized
grouped per channel weights.
Expand All @@ -141,47 +207,42 @@ class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):

def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
group_size: int = 32,
scale_precision: torch.dtype = torch.float32,
zero_point_precision: torch.dtype = torch.int32,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.bit_width = 4
self.group_size = group_size
self.scale_precision = scale_precision
self.zero_point_precision = zero_point_precision
self._fake_quant_enabled = True

def forward(self, x):
weight = self.weight

if self._fake_quant_enabled:
(weight_scales, weight_zp) = get_group_qparams_symmetric(
self.weight, self.bit_width, self.group_size, self.scale_precision,
)
# TODO: pass zp dtype to `get_group_qparams_symmetric` instead
weight_zp = weight_zp.to(self.zero_point_precision)
(weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
w_fq = _fake_quantize_per_channel_group(
self.weight,
weight_scales,
weight_zp,
weight_qmin,
weight_qmax,
self.group_size,
)
else:
w_fq = self.weight

return F.embedding(
x, w_fq, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse,
weight_config = FakeQuantizeConfig(
dtype=TorchAODType.INT4,
group_size=group_size,
is_symmetric=True,
is_dynamic=True,
scale_precision=scale_precision,
zero_point_precision=zero_point_precision,
)
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
weight_config,
*args,
**kwargs,
)

def enable_fake_quant(self, enabled: bool = True):
self._fake_quant_enabled = enabled
self.weight_fake_quantizer.enabled = enabled

def disable_fake_quant(self):
self.enable_fake_quant(False)
Expand All @@ -194,25 +255,21 @@ class Int4WeightOnlyEmbedding(torch.nn.Module):
"""
def __init__(
self,
group_size: int,
scale_precision: torch.dtype,
zero_point_precision: torch.dtype,

# nn.Embedding args
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
group_size: int = 32,
scale_precision: torch.dtype = torch.float32,
zero_point_precision: torch.dtype = torch.int32,
device: torch.device = None,
):
super().__init__()
self.bit_width = 4
self.group_size = group_size
self.scale_precision = scale_precision
self.zero_point_precision = zero_point_precision

# nn.Embedding args
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
Expand All @@ -221,6 +278,12 @@ def __init__(
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse

# quantization args
self.bit_width
self.group_size = group_size
self.scale_precision = scale_precision
self.zero_point_precision = zero_point_precision

# currently storing unpacked int8 weights
self.register_buffer(
"weight",
Expand Down

0 comments on commit 997e2ce

Please sign in to comment.