Skip to content
8 changes: 5 additions & 3 deletions src/python/py/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def parse_hf_token(hf_token):
def set_io_dtype(precision, execution_provider, extra_options) -> ir.DataType:
int4_cpu = precision == "int4" and execution_provider == "cpu"
fp32_webgpu = execution_provider == "webgpu" and extra_options.get("use_webgpu_fp32", False)
bf16_cuda = precision == "int4" and execution_provider == "cuda" and extra_options.get("use_cuda_bf16", False)
bf16_cuda = precision == "int4" and execution_provider in {"cuda", "trt-rtx"} and extra_options.get("use_cuda_bf16", False)

if precision in {"int8", "fp32"} or int4_cpu or fp32_webgpu:
# FP32 precision
Expand Down Expand Up @@ -403,8 +403,10 @@ def get_args():
2 is fp16.
1 is fp32.
Default is 4 for the CPU EP and 0 for non-CPU EPs.
int4_block_size = 16/32/64/128/256: Specify the block size for int4 quantization.
int4_block_size = 16/32/64/128/256: Specify the block size for int4 quantization (MatMulNBits).
Default value is 32.
qmoe_block_size = 16/32/64/128/256: Specify the block size for QMoE expert weights quantization.
Default is 128 for trt-rtx, 32 for others. Supported EPs: cpu, webgpu, trt-rtx.
int4_is_symmetric = Quantize the weights symmetrically. Default is true.
If true, quantization is done to int4. If false, quantization is done to uint4.
int4_op_types_to_quantize = MatMul/Gather: Specify op types to target for int4 quantization.
Expand Down Expand Up @@ -469,7 +471,7 @@ def get_args():

args = parser.parse_args()
print(
"Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, BF16 CUDA, FP16 TRT-RTX, INT4 CPU, INT4 CUDA, INT4 DML, INT4 WebGPU"
"Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, BF16 CUDA, FP16 TRT-RTX, BF16 TRT-RTX, INT4 CPU, INT4 CUDA, INT4 DML, INT4 WebGPU"
)
return args

Expand Down
53 changes: 34 additions & 19 deletions src/python/py/models/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,19 +359,26 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
int4_algo_config = self.make_int4_algo_config(extra_options.get("int4_algo_config", "default"))
self.int4_block_size = extra_options.get("int4_block_size", 32)

# Validate that only CPU and WebGPU EPs support int4_block_size for QMoE
if self.ep not in ["cpu", "webgpu"] and "int4_block_size" in extra_options and moe_op_type == "QMoE":
# CPU, WebGPU, and TRT-RTX support block-wise quantization for QMoE.
# TRT-RTX defaults to 128; others default to 32 for consistency with MatMulNBits.
supported_blockwise_eps = ["cpu", "webgpu", "trt-rtx"]
default_qmoe_block_size = 128 if self.ep == "trt-rtx" else 32
self.qmoe_block_size = int(extra_options.get("qmoe_block_size", default_qmoe_block_size))

# Validate that unsupported EPs don't explicitly request block-wise quantization
if self.ep not in supported_blockwise_eps and "qmoe_block_size" in extra_options and moe_op_type == "QMoE":
raise ValueError(
f"The 'int4_block_size' option is not supported for {self.ep} execution provider with QMoE. "
"Block-wise quantization (block_size attribute) is only supported for CPU and WebGPU execution providers."
f"The 'qmoe_block_size' option is not supported for {self.ep} execution provider with QMoE. "
f"Block-wise quantization is only supported for: {', '.join(supported_blockwise_eps)}."
)

self.quant_attrs = {
"int4": {
"accuracy_level": int(
extra_options.get("int4_accuracy_level", 4 if self.ep in ["cpu", "webgpu"] else 0)
),
"block_size": int(self.int4_block_size),
"qmoe_block_size": int(self.qmoe_block_size),
"qdq_block_size": int(self.int4_block_size),
"is_symmetric": extra_options.get("int4_is_symmetric", True),
"op_types_to_quantize": extra_options.get("int4_op_types_to_quantize", ("MatMul",)),
"nodes_to_exclude": extra_options.get("int4_nodes_to_exclude", []),
Expand All @@ -380,11 +387,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
"use_qdq": extra_options.get("use_qdq", False),
}

# Propagate block_size to MoE/QMoE op when supported and requested.
# QMoE on CPU/WebGPU supports block-wise quantization via the 'block_size' attribute.
# Propagate block_size to MoE/QMoE op when supported.
# QMoE on supported EPs uses block-wise quantization via the 'block_size' attribute.
# Ensure the attribute is set on the MoE op so runtime kernels can honor it.
if self.moe_attrs.get("op_type") == "QMoE" and self.ep in ["cpu", "webgpu"]:
self.moe_attrs["block_size"] = int(self.int4_block_size)
if self.moe_attrs.get("op_type") == "QMoE" and self.ep in supported_blockwise_eps:
self.moe_attrs["block_size"] = int(self.qmoe_block_size)
if self.quant_type is not None:
# Create quantized attributes from quantization config
self.quant_attrs["config"] = config.quantization_config
Expand Down Expand Up @@ -501,6 +508,7 @@ def is_gqa_supported(self) -> bool:
("webgpu", ir.DataType.FLOAT16),
("webgpu", ir.DataType.FLOAT),
("trt-rtx", ir.DataType.FLOAT16),
("trt-rtx", ir.DataType.BFLOAT16),
}
return (self.ep, self.io_dtype) in valid_gqa_configurations

Expand Down Expand Up @@ -703,7 +711,7 @@ def make_int4_algo_config(self, quant_method: str):
def to_int4(self) -> ir.Model:
quant = MatMulNBitsQuantizer(
model=ir.to_proto(self.model),
block_size=self.quant_attrs["int4"]["block_size"],
block_size=self.quant_attrs["int4"]["qdq_block_size"],
is_symmetric=self.quant_attrs["int4"]["is_symmetric"],
accuracy_level=self.quant_attrs["int4"]["accuracy_level"],
nodes_to_exclude=self.quant_attrs["int4"]["nodes_to_exclude"],
Expand Down Expand Up @@ -3227,11 +3235,17 @@ def make_qmoe_op(self, name, **kwargs):
kwargs.get("weight3", ""),
kwargs.get("scales3", ""),
kwargs.get("bias3", ""),
kwargs.get("zero_points1", ""),
kwargs.get("zero_points2", ""),
Comment thread
anujj marked this conversation as resolved.
kwargs.get("zero_points3", ""),
]

# TRT-RTX doesn't support zero_points inputs at all
Comment thread
anujj marked this conversation as resolved.
# For other EPs, always include as optional inputs (even empty strings)
if self.ep != "trt-rtx":
inputs.extend([
kwargs.get("zero_points1", ""),
kwargs.get("zero_points2", ""),
kwargs.get("zero_points3", ""),
])

output = f"{name}/output_0"

extra_kwargs = (
Expand Down Expand Up @@ -3264,21 +3278,21 @@ def make_qmoe_weights(self, weights):
dtype = torch.quint4x2 if self.moe_attrs["expert_weight_bits"] == 4 else torch.int8
qweight, scales = None, None

# For QMoE, only use block-wise quantization when explicitly requested
# via int4_block_size and when using CPU or WebGPU execution providers, since
# block_size is only supported for these EPs in the QMoE operator.
use_blockwise_quant = "int4_block_size" in self.extra_options and self.ep in ["cpu", "webgpu"]
# Use block-wise quantization for supported EPs when qmoe_block_size > 0.
# TRT-RTX defaults to 128; others default to 32.
supported_blockwise_eps = ["cpu", "webgpu", "trt-rtx"]
use_blockwise_quant = self.ep in supported_blockwise_eps and self.qmoe_block_size > 0

if use_blockwise_quant:
block_size = self.quant_attrs["int4"]["block_size"]
block_size = self.quant_attrs["int4"]["qmoe_block_size"]
try:
qweight, scales = self._symmetric_blockwise_quantize(weights, block_size)
self.moe_attrs["block_size"] = block_size
return qweight, scales.to(torch.float16)
except Exception as e:
raise RuntimeError(f"Block-wise quantization failed with block_size={block_size}: {e}")

# Use tensor-level quantization (default for QMoE)
# Use tensor-level quantization (default for QMoE on CPU/WebGPU when not explicitly requested)
self.moe_attrs["block_size"] = 0

# Existing tensor-level quantization implementation (fallback)
Expand Down Expand Up @@ -3354,6 +3368,7 @@ def _symmetric_blockwise_quantize(self, weights, block_size):

quantized_flat = quantized_int8.view(*original_shape[:-1], num_blocks * block_size)

# remove padding
if pad_size > 0:
quantized_flat = quantized_flat[..., :-pad_size]

Expand Down
28 changes: 20 additions & 8 deletions src/python/py/models/builders/gptoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
self.window_size = original_window_size

def make_moe(self, layer_id, mlp, root_input):
if self.ep in {"cpu", "cuda"}:
if self.ep in {"cpu", "cuda", "trt-rtx"}:
self.make_moe_fused(layer_id, mlp, root_input)
else:
self.make_moe_decomposed(layer_id, mlp, root_input)
Expand Down Expand Up @@ -639,24 +639,32 @@ def make_moe_fused(self, layer_id, mlp, root_input):
down_proj_qweight_tensor = torch.stack(down_proj_qweight_list, dim=0).to(torch.uint8)
down_proj_scales_tensor = torch.stack(down_proj_scales_list, dim=0)

# qweight tensors always use the same shape regardless of quantization method
# Determine shape based on Quark vs non-Quark
pack_size = 8 // self.moe_attrs["expert_weight_bits"]
if has_quark_experts:
hidden_size_padded = self.hidden_size
intermediate_size_padded = self.intermediate_size
else:
hidden_size_padded = gate_up_proj_qweight_list[0].shape[-1] * pack_size
intermediate_size_padded = down_proj_qweight_list[0].shape[-1] * pack_size

# Save qweight tensors
self.make_initializer(
gate_up_proj_qweight_tensor.view(self.moe_attrs["num_experts"], -1, self.hidden_size // pack_size),
gate_up_proj_qweight_tensor.view(self.moe_attrs["num_experts"], -1, hidden_size_padded // pack_size),
gate_up_proj_weight,
)
self.make_initializer(
down_proj_qweight_tensor.view(
self.moe_attrs["num_experts"], self.hidden_size, self.intermediate_size // pack_size
self.moe_attrs["num_experts"], self.hidden_size, intermediate_size_padded // pack_size
),
down_proj_weight,
)

# scales tensors have different shapes depending on quantization method
# Save scales tensors
self.make_initializer(gate_up_proj_scales_tensor, gate_up_proj_scales, to=self.io_dtype)
self.make_initializer(down_proj_scales_tensor, down_proj_scales, to=self.io_dtype)

# Save MoE biases as initializers
# Save biases (shared for all paths)
if has_quark_experts:
gate_up_bias = self.combine_quark_gate_up_biases_from_experts(mlp.experts)
down_bias = self.combine_quark_down_biases_from_experts(mlp.experts)
Expand All @@ -667,7 +675,11 @@ def make_moe_fused(self, layer_id, mlp, root_input):
self.make_initializer(gate_up_bias, gate_up_proj_bias, to=self.io_dtype)
self.make_initializer(down_bias, down_proj_bias, to=self.io_dtype)

# Single make_moe_op call with EP-based zero_points
# TRT-RTX doesn't support zero_points inputs
moe_name = f"{basename}/{op_type}"
use_zero_points = has_quark_experts and self.ep != "trt-rtx"

self.make_moe_op(
moe_name,
root_input=root_input,
Expand All @@ -678,8 +690,8 @@ def make_moe_fused(self, layer_id, mlp, root_input):
weight2=down_proj_weight,
scales2=down_proj_scales,
bias2=down_proj_bias,
zero_points1=gate_up_proj_zero_points if has_quark_experts else "",
zero_points2=down_proj_zero_points if has_quark_experts else "",
zero_points1=gate_up_proj_zero_points if use_zero_points else "",
zero_points2=down_proj_zero_points if use_zero_points else "",
)

# Assign output 0 of previous MoE as root input to next SkipLayerNorm
Expand Down
Loading