2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -52,6 +52,8 @@ transforms:
# see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
optimize_rope:
stage: pattern_matcher
quantize_int4_linear_from_config:
stage: pattern_matcher
quantize_fp8_linear_from_config:
stage: pattern_matcher
quantize_nvfp4_linear_from_config:
43 changes: 43 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
@@ -4,6 +4,7 @@

from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
cutlass_fp4_scale_to_modelopt_fp4_scale,
unpack_uint8_to_int4_weight_2d,
)

# FP4 tables (E2M1)
@@ -276,3 +277,45 @@ def torch_fake_quant_nvfp4_linear(
weight_zp: List[torch.Tensor],
) -> torch.Tensor:
return torch.ops.aten.linear(input, weight_quantized.repeat(1, 2).to(input.dtype), bias)


@torch.library.custom_op("auto_deploy::torch_fake_quant_int4_linear", mutates_args=())
def torch_fake_quant_int4_linear(
input: torch.Tensor, # [..., K]
weight_quantized: torch.Tensor, # [N//2, K] uint8 (packed)
bias: Optional[torch.Tensor], # [N] or None
input_scale: List[torch.Tensor], # [ pre_quant_scale ]
weight_scale: List[torch.Tensor], # [ weight_scale ]
input_zp: List[torch.Tensor],
weight_zp: List[torch.Tensor],
) -> torch.Tensor:
BLOCK_SIZE = 128
# activation pre-scale
pre_quant_scale = input_scale[0].to(dtype=input.dtype)
x_scaled = torch.mul(input, pre_quant_scale)

q_int4 = unpack_uint8_to_int4_weight_2d(weight_quantized, weight_scale[0]) # (N,K), int8
amax_2d = (weight_scale[0] * 7.0).to(input.dtype) # (N, K//128)

scale_blocks = (7.0 / amax_2d).to(torch.float32) # (N, K//128)
scale_full = scale_blocks.repeat_interleave(BLOCK_SIZE, dim=1) # (N,K)

# Dequantize
w_deq = (q_int4.to(torch.float32) / scale_full).to(input.dtype)

return torch.ops.auto_deploy.torch_linear_simple.default(x_scaled, w_deq, bias)


@torch_fake_quant_int4_linear.register_fake
def _fake(
input: torch.Tensor,
weight_quantized: torch.Tensor,
bias: Optional[torch.Tensor],
input_scale: List[torch.Tensor],
weight_scale: List[torch.Tensor],
input_zp: List[torch.Tensor],
weight_zp: List[torch.Tensor],
) -> torch.Tensor:
N_half = weight_quantized.shape[-2]
N = N_half * 2
return torch.empty((*input.shape[:-1], N), dtype=input.dtype, device=input.device)
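Review note: since `weight_scale` stores per-block `amax / 7`, the dequantization in `torch_fake_quant_int4_linear` reduces to `w_deq = q_int4 * weight_scale` expanded over each 128-column block; the op merely expresses it as a division by the recomputed inverse scale `7 / amax`. A minimal standalone sketch of that arithmetic in plain PyTorch (tensor names and sizes are illustrative, not taken from the op):

```python
import torch

BLOCK_SIZE = 128
N, K = 4, 256

# Illustrative inputs: signed INT4 codes in [-8, 7] and per-block scales storing amax / 7.
q_int4 = torch.randint(-8, 8, (N, K), dtype=torch.int8)
weight_scale = torch.rand(N, K // BLOCK_SIZE, dtype=torch.float32) + 0.1

# Same steps as the op: amax = scale * 7, inverse scale = 7 / amax, dequant = q / inverse.
amax_2d = weight_scale * 7.0
scale_full = (7.0 / amax_2d).repeat_interleave(BLOCK_SIZE, dim=1)  # (N, K)
w_deq = q_int4.to(torch.float32) / scale_full

# Equivalent direct form: per-block multiply by weight_scale.
w_direct = q_int4.to(torch.float32) * weight_scale.repeat_interleave(BLOCK_SIZE, dim=1)
torch.testing.assert_close(w_deq, w_direct)
```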
@@ -83,6 +83,8 @@ def has(cls, reader_cls: str) -> bool:

@QuantConfigReaderRegistry.register("modelopt")
class ModelOPTQuantConfigReader(QuantConfigReader):
_ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens")

def read_config(self, config: Dict) -> Dict:
producer = config.get("producer", {}).get("name")
# sanity check
@@ -91,7 +93,10 @@ def read_config(self, config: Dict) -> Dict:

quant_config = config.get("quantization", {})
# Inject default exclusions, adding "model.embed_tokens" for the "tie_word_embeddings: true" case
quant_config.setdefault("exclude_modules", ["lm_head", "model.embed_tokens"])
excludes = quant_config.get("exclude_modules", [])
quant_config["exclude_modules"] = excludes + [
n for n in self._ALWAYS_EXCLUDE if n not in excludes
]
# Update dtype
if quant_config.get("quant_algo") == "NVFP4":
quant_config["torch_dtype"] = "float16"
@@ -393,6 +393,89 @@ def convert_amax_hook(self, state_dict, prefix, *args, scale_name: str, amax_nam
state_dict[scale_name] = scale


@TransformRegistry.register("quantize_int4_linear_from_config")
class INT4LinearQuantizationFromConfig(Quantization):
"""Config-based INT4 (AWQ) for the unified ModelOpt checkpoints."""

algo_name = "W4A16_AWQ"

@staticmethod
def target_op():
return torch.ops.auto_deploy.torch_fake_quant_int4_linear.default

@staticmethod
def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor:
N, K = original_weight.shape
return torch.empty((N // 2, K), dtype=torch.uint8, device=original_weight.device)

@staticmethod
def scale_names() -> List[str]:
return ["pre_quant_scale", "weight_scale"]

@staticmethod
def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
N, K = original_weight_shape
BLOCK = 128
assert K % BLOCK == 0, "K must be divisible by 128 for INT4 block quant."
return {
"pre_quant_scale": torch.ones(K, dtype=torch.float32),
"weight_scale": torch.empty((N, K // BLOCK), dtype=torch.float32),
}

@staticmethod
def build_custom_args_for_linear(scales: Dict[str, Node]) -> Tuple[object, ...]:
return ([scales["pre_quant_scale"]], [scales["weight_scale"]], [], [])

@staticmethod
def load_hook(state_dict, prefix, *args, weight_name: str):
"""
Unified ckpt passthrough:
- weight: keep packed uint8 (N//2, K)
- pre_quant_scale buffer: (K,) or ones(K) if missing
- weight_scale buffer: (N, K//128) float32 (no reshape, no *7 here)
"""
if weight_name not in state_dict:
return
BLOCK = 128

mod_prefix, _, _ = weight_name.rpartition(".")
pre_qs_ckpt = f"{mod_prefix}.pre_quant_scale" # may be absent
wscale_ckpt = f"{mod_prefix}.weight_scale" # required

pre_qs_buf = f"{mod_prefix}.pre_quant_scale"
wscale_buf = f"{mod_prefix}.weight_scale"

w_packed = state_dict[weight_name]
if w_packed.dtype != torch.uint8:
return

assert wscale_ckpt in state_dict, f"Missing {wscale_ckpt}"
wscale_mat = state_dict[wscale_ckpt] # (N, K//128) float32

N_half, K = w_packed.shape
N = N_half * 2
assert K % BLOCK == 0
assert wscale_mat.shape == (N, K // BLOCK), (
f"weight_scale shape {wscale_mat.shape} != {(N, K // BLOCK)}"
)

# pre_quant_scale: use if present else ones(K)
if pre_qs_ckpt in state_dict:
pre_qs_val = state_dict[pre_qs_ckpt].to(torch.float32)
if pre_qs_val.dim() == 0:
pre_qs_val = pre_qs_val.expand(K).clone()
else:
assert pre_qs_val.numel() == K, (
f"{pre_qs_ckpt} has {pre_qs_val.numel()} elems, expected {K}"
)
else:
pre_qs_val = torch.ones(K, dtype=torch.float32)

state_dict[weight_name] = w_packed # (N//2, K) uint8
state_dict[pre_qs_buf] = pre_qs_val # (K,) float32
state_dict[wscale_buf] = wscale_mat.to(torch.float32) # (N, K//128)
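Review note: the tensor shapes the load hook expects from a unified ModelOpt checkpoint, as a small sketch (module prefix and sizes are hypothetical):

```python
import torch

N, K, BLOCK = 128, 512, 128
prefix = "model.layers.0.mlp.up_proj"  # hypothetical module prefix

state_dict = {
    # Packed INT4 weight: two output rows per uint8 byte.
    f"{prefix}.weight": torch.zeros(N // 2, K, dtype=torch.uint8),
    # Per-block scale (amax / 7), one value per 128 input channels.
    f"{prefix}.weight_scale": torch.ones(N, K // BLOCK, dtype=torch.float32),
    # Optional; the hook falls back to ones(K) when this key is absent.
    f"{prefix}.pre_quant_scale": torch.ones(K, dtype=torch.float32),
}

# load_hook(state_dict, "", weight_name=f"{prefix}.weight") leaves weight and
# weight_scale untouched and only normalizes pre_quant_scale to a float32 (K,) vector.
```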


@TransformRegistry.register("quantize_fp8_bmm_from_config")
class FP8BMMQuantizationFromConfig(Quantization):
algo_name = "FP8"
64 changes: 64 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
@@ -140,3 +140,67 @@ def extract_scales_from_node(node: Node, scale_names: list[str]) -> Dict[str, Op
scales[name] = args[3 + i]

return scales


def unpack_uint8_to_int4_weight_2d(
packed_weight: torch.Tensor, weights_scaling_factor: torch.Tensor
) -> torch.Tensor:
"""
Reverse of `modelopt.torch.export.quant_utils.pack_int4_in_uint8` for the 2D case.
Args:
packed_weight: (out_dim//2, in_dim), uint8
weights_scaling_factor: (out_dim, in_dim//block_size) [used for shape/block inference]
Returns:
int8 weights in [-8,7], shape (out_dim, in_dim)
"""
assert packed_weight.dim() == 2
assert packed_weight.dtype == torch.uint8

out_half, in_dim = packed_weight.shape
out_dim = out_half * 2
block_size = in_dim // weights_scaling_factor.shape[-1]
assert in_dim % block_size == 0

# inverse of: reshaped = int8_tensor.T.reshape(in_dim, out_dim//2, 2)
pw = packed_weight.T.contiguous() # (in_dim, out_dim//2)

low = (pw & 0x0F).to(torch.int16)
high = ((pw >> 4) & 0x0F).to(torch.int16)

low = torch.where(low >= 8, low - 16, low).to(torch.int8)
high = torch.where(high >= 8, high - 16, high).to(torch.int8)

rebuilt = torch.stack([low, high], dim=-1) # (in_dim, out_dim//2, 2)
int8_T = rebuilt.reshape(in_dim, out_dim) # (in_dim, out_dim)
int8_W = int8_T.T.contiguous() # (out_dim, in_dim)
return int8_W


# copied from modelopt.torch.export.quant_utils.pack_int4_in_uint8
def pack_int4_in_uint8(weight, weights_scaling_factor):
"""Packs the INT4 weights into uint8 tensor."""
out_dim = weight.shape[-2]
assert out_dim % 2 == 0, f"Cannot pack weight. Out dimension {out_dim} is not an even number."
in_dim = weight.shape[-1]
block_size = weight.shape[-1] // weights_scaling_factor.shape[-1]
int8_tensor = (
(weight / weights_scaling_factor[..., :, torch.arange(in_dim) // block_size])
.round()
.clamp(-8, 7)
.to(torch.int8)
)
# -- Handle the MoE (3D) case vs. the 2D case --
if int8_tensor.dim() == 3:
transpose = int8_tensor.permute(0, 2, 1)
transpose = transpose.reshape(-1, in_dim, out_dim // 2, 2)
val0 = transpose[..., 0] & 0x0F
val1 = transpose[..., 1] & 0x0F
packed_byte = val0 | (val1 << 4)
return packed_byte.permute(0, 2, 1).contiguous().view(torch.uint8)
else:
# 2D weights: shape typically (out_dim, in_dim)
reshaped = int8_tensor.T.reshape(in_dim, out_dim // 2, 2)
val0 = reshaped[..., 0] & 0x0F
val1 = reshaped[..., 1] & 0x0F
packed_byte = val0 | (val1 << 4)
return packed_byte.T.contiguous().view(torch.uint8)
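To make the 2D nibble layout concrete: packed row `i`, column `k` stores `W[2*i, k]` in the low nibble and `W[2*i+1, k]` in the high nibble. A tiny worked round-trip with the scale fixed to 1 so the INT4 codes equal the weights (values kept inside the [-8, 7] range):

```python
import torch

from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
    pack_int4_in_uint8,
    unpack_uint8_to_int4_weight_2d,
)

# out_dim=4, in_dim=2, block_size = in_dim // scale.shape[-1] = 2.
W = torch.tensor([[1.0, -2.0],
                  [3.0, -4.0],
                  [5.0, -6.0],
                  [7.0, -8.0]])
scale = torch.ones(4, 1)  # (out_dim, in_dim // block_size)

packed = pack_int4_in_uint8(W, scale)  # (2, 2) uint8; e.g. packed[0, 0] = 0x31 -> rows 0 and 1, column 0
unpacked = unpack_uint8_to_int4_weight_2d(packed, scale)

assert packed.shape == (2, 2) and packed.dtype == torch.uint8
assert torch.equal(unpacked, W.to(torch.int8))  # exact round-trip on the integer grid
```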
@@ -4,11 +4,16 @@
from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available

import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
fp4_global_scale,
pack_int4_in_uint8,
unpack_uint8_to_int4_weight_2d,
)

torch.manual_seed(0)

SCALING_VECTOR_SIZE = 16 # NVFP4 block size along K
INT4_BLOCK_SIZE = 128


@pytest.mark.parametrize("bias", [torch.rand(32).to("cuda") * 10, None])
@@ -194,3 +199,102 @@ def test_quant_linear_nvfp4_matches_fused_op(bias):

assert out_unified.shape == out_fused.shape
torch.testing.assert_close(out_unified, out_fused, rtol=1e-3, atol=5e-3)


def test_int4awq_unpack_roundtrip():
"""Pack (via provided pack_int4_in_uint8) -> Unpack should exactly recover the signed INT4 grid."""
device = "cuda"
dtype = torch.float32

N, K = 64, 256
assert K % INT4_BLOCK_SIZE == 0 and N % 2 == 0

W = (torch.randn(N, K, device=device, dtype=dtype) * 2.0).contiguous()

# Per-block amax along K -> shape (N, K//INT4_BLOCK_SIZE)
Wv = W.view(N, K // INT4_BLOCK_SIZE, INT4_BLOCK_SIZE)
amax_blocks = Wv.abs().amax(dim=-1).to(torch.float32) # (N, K//128)

# The packer expects weights_scaling_factor with shape (N, K//INT4_BLOCK_SIZE)
# and quantizes as: round(W / factor).clamp(-8,7)
weights_scaling_factor = (amax_blocks / 7.0).to(torch.float32)

# Build the exact expected INT4 integers ([-8, 7]) the packer produces
col_idx = torch.arange(K, device=device)
block_idx = col_idx // INT4_BLOCK_SIZE # (K,)
scale_full = weights_scaling_factor[:, block_idx] # (N, K)
q_ref = torch.round(W / (scale_full + 1e-12)).clamp(-8, 7).to(torch.int8)

# Pack with the provided function, then unpack with the UUT
packed = pack_int4_in_uint8(W, weights_scaling_factor) # (N//2, K), uint8
assert packed.dtype == torch.uint8 and packed.shape == (N // 2, K)

q_unpacked = unpack_uint8_to_int4_weight_2d(packed, weights_scaling_factor)
assert q_unpacked.dtype == torch.int8 and q_unpacked.shape == (N, K)

# Integer path should be exact
torch.testing.assert_close(q_unpacked, q_ref, rtol=0, atol=0)


@pytest.mark.parametrize("bias_opt", [None, "with_bias"])
@pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16])
def test_fake_quant_int4_linear_matches_fp_reference(bias_opt, input_dtype):
"""Use provided pack_int4_in_uint8 with weights_scaling_factor=amax/7.
Compare op output to a separately-computed dequant reference, and sanity-check vs FP32.
"""
device = "cuda"
torch.cuda.manual_seed(0)

B, K, N = 3, 512, 128
assert K % INT4_BLOCK_SIZE == 0 and N % 2 == 0

x = torch.randn(B, K, device=device, dtype=input_dtype)
W = torch.randn(N, K, device=device, dtype=torch.float32) * 1.5

Wv = W.view(N, K // INT4_BLOCK_SIZE, INT4_BLOCK_SIZE)
amax_blocks = Wv.abs().amax(dim=-1).to(torch.float32) # (N, K//128)
weights_scaling_factor = (amax_blocks / 7.0).to(torch.float32)

packed = pack_int4_in_uint8(W, weights_scaling_factor) # (N//2, K), uint8
assert packed.dtype == torch.uint8
assert packed.shape == (N // 2, K)

bias = None
if bias_opt == "with_bias":
bias = (torch.randn(N, device=device, dtype=input_dtype) * 0.1).contiguous()

s_in = torch.tensor(1.0, device=device, dtype=input_dtype)
out_int4 = torch.ops.auto_deploy.torch_fake_quant_int4_linear(
x, # [..., K]
packed, # [N//2, K], uint8
bias, # [N] or None
[s_in], # input_scale: [pre_quant_scale]
[weights_scaling_factor], # weight_scale: [amax/7]
[], # input_zp
[], # weight_zp
)

# a separate FP reference path that mirrors the op
q_unpacked = unpack_uint8_to_int4_weight_2d(packed, weights_scaling_factor).to(
torch.float32
) # (N, K)

# Mirror op’s casting order: amax_2d cast to input dtype before compute scale_blocks
amax_2d = (weights_scaling_factor * 7.0).to(x.dtype) # (N, K//128)
scale_blocks = (7.0 / (amax_2d + 1e-12)).to(torch.float32) # (N, K//128)
scale_full = scale_blocks.repeat_interleave(INT4_BLOCK_SIZE, dim=1) # (N, K)

# Dequant + same linear op as the op
w_deq = (q_unpacked / scale_full).to(x.dtype)
x_scaled = (x * s_in).to(x.dtype)
out_ref_deq = torch.ops.auto_deploy.torch_linear_simple.default(x_scaled, w_deq, bias)

torch.testing.assert_close(out_int4, out_ref_deq, rtol=1e-5, atol=1e-5)
# Sanity: closeness vs original FP32 weights (quantization error budget)
out_fp32 = torch.nn.functional.linear(
x.to(torch.float32),
W,
bias.to(torch.float32) if bias is not None else None,
).to(out_int4.dtype)
cos = F.cosine_similarity(out_fp32.reshape(-1), out_int4.reshape(-1), dim=0)
assert cos > 0.98