Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
ef874c0
✨ feat(npu): add online MXFP8 quantization support for Ascend NPU (Pa…
TallMessiWu Mar 18, 2026
d2d19c6
✨ feat(diffusion): add online MXFP8 quantization support for Wan2.2 o…
TallMessiWu Mar 18, 2026
c838ade
:bug: fix(diffusion): fix npu method call error
TallMessiWu Mar 19, 2026
be3b684
:bug: fix(diffusion): fix MXFP8 quantization scale dimension mismatch…
TallMessiWu Mar 19, 2026
fd79b23
:recycle: refactor(mxfp8): split linear method into config and NPU la…
TallMessiWu Mar 20, 2026
df61b29
:twisted_rightwards_arrows: merge: sync from upstream
TallMessiWu Mar 20, 2026
490ad0b
:sparkles: feat(diffusion): add offline MXFP8 pre-quantized weight su…
TallMessiWu Mar 20, 2026
cc80690
:bug: fix(diffusion): correct MXFP8 weight dtype and scale shape
TallMessiWu Mar 23, 2026
b9aa785
✨ feat(wan22): Redesigned the wan_repack tool. Now support one-click …
TallMessiWu Mar 24, 2026
22bee9e
:recycle: refactor(mxfp8): hoist imports and replace print with logger
TallMessiWu Mar 24, 2026
a29bb3d
:pencil2: fix(diffusion/mxfp8): address review comments on ModelSlimM…
TallMessiWu Mar 25, 2026
3bbf703
:twisted_rightwards_arrows: chore(merge): sync upstream/main, keep MX…
TallMessiWu Mar 25, 2026
250fe65
:adhesive_bandage: fix(diffusion): register --quantization CLI arg to…
TallMessiWu Mar 25, 2026
e146b03
:bug: fix(mxfp8_npu): move weight to current NPU device before quanti…
TallMessiWu Mar 25, 2026
711bb8b
:rewind: revert(llm): remove LLM MXFP8 online quantization (Path B) f…
TallMessiWu Mar 25, 2026
1604d4e
:twisted_rightwards_arrows: chore(merge): sync upstream/main into junlin
TallMessiWu Mar 31, 2026
1101cf5
:adhesive_bandage: fix(loader): preserve --quantization flag priority…
TallMessiWu Mar 31, 2026
f1c652b
:sparkles: feat(npu/mxfp8): add W8A8 MXFP8 LLM support on Ascend NPU
TallMessiWu Apr 1, 2026
97c45b6
:recycle: refactor(npu/mxfp8): refactor code to align with vllm-ascend
TallMessiWu Apr 1, 2026
6026a18
🐛 fix(quantization/modelslim): resolve circular import in schemes/__i…
TallMessiWu Apr 2, 2026
3025e2d
:bug: fix(quantization/modelslim): fix no scheme found error.
TallMessiWu Apr 2, 2026
da92418
:bug: fix(llm/mxfp8): fix meaningless output issue
TallMessiWu Apr 3, 2026
29c04bc
:bug: fix(llm/mxfp8): fix meaningless output issue
TallMessiWu Apr 7, 2026
6693352
:sparkles: feat(npu-quant): add MXFP4 W4A8 online quantization for Qw…
TallMessiWu Apr 8, 2026
e9dec3c
:sparkles: feat(npu/w4a8): add W4A8 offline Dense scheme and guards
TallMessiWu Apr 16, 2026
066becd
:recycle: refactor(quant): rename mxfp4_npu to mxfp4_w4a8_npu and add…
TallMessiWu Apr 24, 2026
93d542a
:sparkles: feat(modelslim): add W4A8_MXFP offline scheme for LLM
TallMessiWu Apr 24, 2026
456bd14
:bug: fix(modelslim): fix W4A8_MXFP weight dtype to float8_e4m3fn
TallMessiWu Apr 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
ModelOptFp4Config,
)
from sglang.multimodal_gen.runtime.layers.quantization.modelslim import ModelSlimConfig
from sglang.multimodal_gen.runtime.layers.quantization.mxfp8_npu import MXFP8Config

QuantizationMethods = Literal["fp8", "modelopt_fp4", "modelslim"]
QuantizationMethods = Literal["fp8", "modelopt_fp4", "modelslim", "mxfp8"]

QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

Expand All @@ -20,6 +21,7 @@
"modelopt_fp4": ModelOptFp4Config,
"modelslim": ModelSlimConfig,
"fp8": Fp8Config,
"mxfp8": MXFP8Config,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def _get_scheme_from_parts(
return ModelSlimW4A4Int4(
quant_config=self.quant_description, prefix=layer_name
)
elif quant_type == "W8A8_MXFP8":
from sglang.multimodal_gen.runtime.layers.quantization.modelslim_mxfp8_scheme import (
ModelSlimMXFP8Scheme,
)

return ModelSlimMXFP8Scheme()
raise NotImplementedError("No modelslim compatible scheme was found.")

def get_scheme(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU.

Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights,
uint8 scales) and runs MXFP8 matmul at inference.
"""

from typing import List, Optional

import torch
import torch_npu

from sglang.multimodal_gen.runtime.models.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
)
from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme

# Number of consecutive input-dimension elements that share one weight scale.
# Used below to size ``weight_scale`` (input_size_per_partition // 32) and as
# the last entry of ``group_sizes`` passed to ``npu_quant_matmul``.
MXFP8_BLOCK_SIZE = 32
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For robustness across different versions of torch_npu and torch, it is better to use a fallback mechanism for the float8_e8m0fnu dtype, similar to the implementation in the SRT backend.

MXFP8_BLOCK_SIZE = 32
_FLOAT8_E8M0FNU_DTYPE = getattr(
    torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)
)



class ModelSlimMXFP8Scheme(ModelSlimLinearScheme):

def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: List[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register empty ``weight`` and ``weight_scale`` parameters on ``layer``.

    The tensors are placeholders to be filled by the weight loader:

    * ``weight``: float8_e4m3fn, shape ``[out, in]`` (per the msmodelslim
      export layout).
    * ``weight_scale``: uint8, shape ``[out, in // MXFP8_BLOCK_SIZE]``.

    ``input_size``, ``output_size`` and ``params_dtype`` are accepted for
    interface compatibility but are not used here; the storage dtypes are
    fixed by the pre-quantized checkpoint format.

    Args:
        layer: Module that receives the two parameters.
        input_size_per_partition: Input features held by this partition.
        output_partition_sizes: Output widths whose sum is this partition's
            output feature count.
        **extra_weight_attrs: Only ``"weight_loader"`` is read.
    """
    loader = extra_weight_attrs.get("weight_loader")
    out_features = sum(output_partition_sizes)
    # NOTE(review): floor division drops a trailing partial block; this
    # presumably relies on the input width being a multiple of 32 -- confirm.
    num_scale_groups = input_size_per_partition // MXFP8_BLOCK_SIZE

    weight_storage = torch.empty(
        (out_features, input_size_per_partition),
        dtype=torch.float8_e4m3fn,
    )
    layer.register_parameter(
        "weight",
        ModelWeightParameter(
            data=weight_storage,
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        ),
    )

    # The name must stay "weight_scale" (not "weight_scale_inv" as in
    # mxfp8_npu.py): the weight loader matches parameter names to checkpoint
    # keys, and msmodelslim checkpoints use "<layer>.weight_scale".
    scale_storage = torch.empty(
        (out_features, num_scale_groups),
        dtype=torch.uint8,
    )
    layer.register_parameter(
        "weight_scale",
        GroupQuantScaleParameter(
            data=scale_storage,
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        ),
    )

def process_weights_after_loading(self, layer: torch.nn.Module):
    """Finalize the loaded tensors as plain, frozen ``torch.nn.Parameter``s.

    The weight is re-wrapped unchanged (it is already stored in its final
    dtype). The scale is regrouped so that its trailing dimension holds
    pairs: ``[out, n] -> [out, n // 2, 2]``.

    Args:
        layer: Module whose ``weight`` and ``weight_scale`` were loaded.
    """
    # No cast needed for the weight; only replace the parameter wrapper.
    layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

    flat_scale = layer.weight_scale.data
    rows = flat_scale.shape[0]
    paired_scale = flat_scale.reshape(rows, -1, 2)
    layer.weight_scale = torch.nn.Parameter(paired_scale, requires_grad=False)
Comment on lines +65 to +73
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Pre-transposing the weight and scale tensors during model loading avoids the overhead of performing transposes on every forward pass. This optimization is already present in the SRT implementation of this scheme.

    def process_weights_after_loading(self, layer: torch.nn.Module):
        # weight is already float8_e4m3fn, no cast needed
        weight = layer.weight.data
        # Pre-transpose weight and scale to [in, out] for npu_quant_matmul.
        # Use .data assignment without .contiguous() to preserve the transpose
        # view strides — npu_quant_matmul reads strides correctly.
        layer.weight = torch.nn.Parameter(weight.transpose(0, 1), requires_grad=False)

        # Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2]
        weight_scale = layer.weight_scale.data
        weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2)
        layer.weight_scale = torch.nn.Parameter(weight_scale.transpose(0, 1), requires_grad=False)


# Forward pass of the pre-quantized MXFP8 linear layer.
#
# Steps performed below: (1) make sure the activation is fp16/bf16,
# (2) flatten it to 2D, (3) quantise it dynamically with
# npu_dynamic_mx_quant, (4) multiply with the stored weight and weight_scale
# through npu_quant_matmul.
#
# Args:
#   layer: module holding ``weight`` and ``weight_scale``.
#   x: activation tensor; only its last dimension is kept when flattening.
#   bias: optional bias, cast to float32 before being passed to the kernel.
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

original_dtype = x.dtype
if original_dtype not in (torch.float16, torch.bfloat16):
# npu_dynamic_mx_quant only accepts fp16/bf16 activations
# Note: original_dtype is overwritten, so the matmul output for such
# inputs is requested as bfloat16 rather than the caller's dtype.
x = x.to(torch.bfloat16)
original_dtype = torch.bfloat16

# npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size].
# Diffusion transformer inputs are typically 3D [batch, seq, hidden] or
# higher. Flattening to 2D merges all leading dimensions into a single
# token axis so the NPU kernel can compute per-token MXFP8 scales, then
# we restore the original shape from the output.
input_shape = x.shape
x_2d = x.reshape(-1, x.shape[-1])

# Dynamic MXFP8 activation quantisation
qx, input_scale = torch_npu.npu_dynamic_mx_quant(
x_2d, dst_type=torch_npu.float8_e4m3fn
)

# MXFP8 matmul
# weight and weight_scale are passed as transposed views (dims 0 and 1
# swapped); the views are created on every call.
output = torch_npu.npu_quant_matmul(
qx,
layer.weight.transpose(0, 1),
layer.weight_scale.transpose(0, 1),
scale_dtype=torch_npu.float8_e8m0fnu,
pertoken_scale=input_scale,
pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
bias=bias.to(torch.float32) if bias is not None else None,
output_dtype=original_dtype,
group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
)
Comment on lines +102 to +112
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Use the pre-transposed weights and the robust dtype fallback in the matmul call.

Suggested change
output = torch_npu.npu_quant_matmul(
qx,
layer.weight.transpose(0, 1),
layer.weight_scale.transpose(0, 1),
scale_dtype=torch_npu.float8_e8m0fnu,
pertoken_scale=input_scale,
pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
bias=bias.to(torch.float32) if bias is not None else None,
output_dtype=original_dtype,
group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
)
# MXFP8 matmul (weight & scale already transposed at load time)
output = torch_npu.npu_quant_matmul(
qx,
layer.weight,
layer.weight_scale,
scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
pertoken_scale=input_scale,
pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
bias=bias.to(torch.float32) if bias is not None else None,
output_dtype=original_dtype,
group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
)


# Restore original shape
output_shape = list(input_shape[:-1]) + [output.shape[-1]]
output = output.reshape(output_shape)

return output
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Online MXFP8 quantization for Diffusion models on Ascend NPU.

Provides ``MXFP8Config`` (registered as ``"mxfp8"``) and
``NPUMXFP8DiffusionLinearMethod`` which quantise FP16/BF16 weights to MXFP8
at load time and use ``npu_dynamic_mx_quant`` + ``npu_quant_matmul`` for
inference, mirroring the LLM-side ``NPUMXFP8LinearMethod``.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

import torch
import torch_npu
from torch.nn.parameter import Parameter

from sglang.multimodal_gen.runtime.layers.linear import LinearBase, LinearMethodBase
from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.multimodal_gen.runtime.models.parameter import ModelWeightParameter
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

# Module-level logger created through the project's init_logger helper.
logger = init_logger(__name__)

# Block size used for MXFP8 quantisation; passed as the last entry of
# ``group_sizes`` to ``npu_quant_matmul`` in this module.
MXFP8_BLOCK_SIZE = 32
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Add a fallback for the float8_e8m0fnu dtype to ensure compatibility across different environment versions.

MXFP8_BLOCK_SIZE = 32
_FLOAT8_E8M0FNU_DTYPE = getattr(
    torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)
)



class MXFP8Config(QuantizationConfig):
    """Quantization config that selects online MXFP8 on Ascend NPU (Diffusion).

    The config carries no options of its own. Its job is to report the method
    name ``"mxfp8"`` and to hand out ``NPUMXFP8DiffusionLinearMethod`` for
    linear layers.
    """

    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method."""
        return "mxfp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return 0: the target is NPU, not CUDA."""
        return 0

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return an empty list; no config file names are declared."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MXFP8Config":
        """Build a config instance; the contents of ``config`` are ignored."""
        return cls()

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Return the MXFP8 linear method for ``LinearBase`` layers, else None."""
        if not isinstance(layer, LinearBase):
            return None
        return NPUMXFP8DiffusionLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return an empty list; no scaled activation names are declared."""
        return []


class NPUMXFP8DiffusionLinearMethod(LinearMethodBase):
"""Ascend NPU MXFP8 linear method for Diffusion models.

Online mode: loads FP16/BF16 weights → quantises to MXFP8 at load time.
Inference: dynamic MXFP8 activation quant + MXFP8 matmul (block_size=32).
"""

# Store the MXFP8Config that created this linear method so it stays
# reachable as ``self.quant_config``.
def __init__(self, quant_config: MXFP8Config):
self.quant_config = quant_config

def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: List[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register an unquantised ``weight`` placeholder on ``layer``.

    The weight is allocated in ``params_dtype`` with shape ``[out, in]``;
    quantisation is deferred to ``process_weights_after_loading``. Several
    bookkeeping attributes are also recorded on the layer.

    Args:
        layer: Module that receives the parameter and the attributes.
        input_size_per_partition: Input features held by this partition.
        output_partition_sizes: Output widths; their sum is the partition's
            output feature count.
        input_size: Not used here.
        output_size: Not used here.
        params_dtype: Dtype of the placeholder; also stored as ``orig_dtype``.
        **extra_weight_attrs: Only ``"weight_loader"`` is read.
    """
    out_features = sum(output_partition_sizes)
    loader = extra_weight_attrs.get("weight_loader")

    # Record the partition layout and original dtype on the layer itself.
    bookkeeping = (
        ("logical_widths", output_partition_sizes),
        ("input_size_per_partition", input_size_per_partition),
        ("output_size_per_partition", out_features),
        ("orig_dtype", params_dtype),
    )
    for attr_name, attr_value in bookkeeping:
        setattr(layer, attr_name, attr_value)

    # Placeholder in the checkpoint dtype; it is replaced by the quantised
    # tensor once loading has finished.
    placeholder = torch.empty(
        (out_features, input_size_per_partition),
        dtype=params_dtype,
    )
    layer.register_parameter(
        "weight",
        ModelWeightParameter(
            data=placeholder,
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        ),
    )

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Quantise the loaded weight to MXFP8 and store it with its scale.

    After this call ``layer.weight`` holds the quantised tensor and
    ``layer.weight_scale_inv`` holds the scale returned by
    ``npu_dynamic_mx_quant``; both are frozen parameters.

    Args:
        layer: Module whose ``weight`` has been loaded.
    """
    source = layer.weight.data

    # Anything other than fp16/bf16 is converted to bf16 before quantising.
    if source.dtype != torch.float16 and source.dtype != torch.bfloat16:
        source = source.to(torch.bfloat16)

    # Conditional move instead of an assert: `dit_cpu_offload` defaults to
    # True in ServerArgs, so fsdp_load puts every parameter back on CPU after
    # loading, even when the target device is NPU. npu_dynamic_mx_quant needs
    # an NPU tensor, so transfer here. The quantised weights then stay on NPU
    # for inference; if the model must still be offloaded afterwards (e.g. a
    # very large model on a small NPU), a higher-level offload pass can move
    # them back.
    if not source.is_npu:
        target_device = f"npu:{torch.npu.current_device()}"
        source = source.to(target_device)

    # Online MXFP8 quantisation of the weight (block_size=32).
    quantized, block_scale = torch_npu.npu_dynamic_mx_quant(
        source,
        dst_type=torch_npu.float8_e4m3fn,
    )
    layer.weight = Parameter(quantized, requires_grad=False)
    layer.weight_scale_inv = Parameter(block_scale, requires_grad=False)
Comment on lines +108 to +131
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Pre-transpose the weights and scales during model loading to improve inference performance. Since this is online quantization, using .contiguous() after transpose is safe and recommended for NPU matmul efficiency.

Suggested change
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight_fp = layer.weight.data
if weight_fp.dtype not in (torch.float16, torch.bfloat16):
weight_fp = weight_fp.to(torch.bfloat16)
# Move weight to NPU if needed. We intentionally use a conditional
# move rather than an assert because `dit_cpu_offload` defaults to
# True in ServerArgs, which causes fsdp_load to move every parameter
# back to CPU after loading (even when the target device is NPU).
# npu_dynamic_mx_quant requires an NPU tensor, so we must transfer
# here. The quantized fp8 weights produced below will remain on NPU
# for inference; if the model still needs to be offloaded after
# quantization (e.g. very large model on a small NPU), a higher-level
# offload pass can move them back afterwards.
if not weight_fp.is_npu:
weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}")
# Online MXFP8 quantisation of weights (block_size=32)
qw, w_scale = torch_npu.npu_dynamic_mx_quant(
weight_fp, dst_type=torch_npu.float8_e4m3fn
)
layer.weight = Parameter(qw, requires_grad=False)
layer.weight_scale_inv = Parameter(w_scale, requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight_fp = layer.weight.data
if weight_fp.dtype not in (torch.float16, torch.bfloat16):
weight_fp = weight_fp.to(torch.bfloat16)
# Move weight to NPU if needed.
if not weight_fp.is_npu:
weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}")
# Online MXFP8 quantisation of weights (block_size=32)
qw, w_scale = torch_npu.npu_dynamic_mx_quant(
weight_fp, dst_type=torch_npu.float8_e4m3fn
)
# Pre-transpose to [in, out] for npu_quant_matmul (avoid per-call transpose)
layer.weight = Parameter(qw.transpose(0, 1).contiguous(), requires_grad=False)
layer.weight_scale_inv = Parameter(w_scale.transpose(0, 1).contiguous(), requires_grad=False)


# Forward pass of the online-quantised MXFP8 linear layer.
#
# The activation is converted to bf16 if it is not fp16/bf16, flattened to 2D,
# quantised dynamically with npu_dynamic_mx_quant, and multiplied with
# ``layer.weight`` / ``layer.weight_scale_inv`` through npu_quant_matmul.
#
# Args:
#   layer: module holding ``weight`` and ``weight_scale_inv``.
#   x: activation tensor; only its last dimension is kept when flattening.
#   bias: optional bias, cast to float32 before being passed to the kernel.
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
original_dtype = x.dtype
if original_dtype not in (torch.float16, torch.bfloat16):
# original_dtype is overwritten, so the matmul output for such inputs is
# requested as bfloat16 rather than the caller's dtype.
x = x.to(torch.bfloat16)
original_dtype = torch.bfloat16

# Flatten to 2D [tokens, hidden] so npu_dynamic_mx_quant returns 3D scale
input_shape = x.shape
x_2d = x.reshape(-1, x.shape[-1])

# Dynamic MXFP8 activation quantisation
qx, input_scale = torch_npu.npu_dynamic_mx_quant(
x_2d, dst_type=torch_npu.float8_e4m3fn
)

# MXFP8 matmul
# weight and weight_scale_inv are passed as transposed views (dims 0 and 1
# swapped); the views are created on every call.
output = torch_npu.npu_quant_matmul(
qx,
layer.weight.transpose(0, 1),
layer.weight_scale_inv.transpose(0, 1),
scale_dtype=torch_npu.float8_e8m0fnu,
pertoken_scale=input_scale,
pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
bias=bias.to(torch.float32) if bias is not None else None,
output_dtype=original_dtype,
group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
)
Comment on lines +154 to +164
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the matmul call to use the pre-transposed parameters and the robust dtype fallback.

Suggested change
output = torch_npu.npu_quant_matmul(
qx,
layer.weight.transpose(0, 1),
layer.weight_scale_inv.transpose(0, 1),
scale_dtype=torch_npu.float8_e8m0fnu,
pertoken_scale=input_scale,
pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
bias=bias.to(torch.float32) if bias is not None else None,
output_dtype=original_dtype,
group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
)
# MXFP8 matmul (weight & scale already transposed at load time)
output = torch_npu.npu_quant_matmul(
qx,
layer.weight,
layer.weight_scale_inv,
scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
pertoken_scale=input_scale,
pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
bias=bias.to(torch.float32) if bias is not None else None,
output_dtype=original_dtype,
group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
)


# Restore original shape (replace last dim with output features)
output_shape = list(input_shape[:-1]) + [output.shape[-1]]
output = output.reshape(output_shape)

return output
1 change: 1 addition & 0 deletions python/sglang/multimodal_gen/runtime/loader/fsdp_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def load_model_from_full_model_state_dict(
"bias",
"norm_q",
"norm_k",
"weight_scale",
]
for new_param_name in unused_keys:
meta_sharded_param = meta_sd.get(new_param_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,17 @@ def _resolve_quant_config(
) -> Optional[QuantizationConfig]:
"""
resolve quant config from checkpoints' metadata
priority: model config.json -> safetensors metadata -> format-specific fallback
priority: explicit --quantization flag -> model config.json -> safetensors metadata -> format-specific fallback
"""
# priority: explicit --quantization flag (e.g. mxfp8, mxfp4, modelslim)
if server_args.quantization is not None:
from sglang.multimodal_gen.runtime.layers.quantization import (
get_quantization_config,
)

quant_cls = get_quantization_config(server_args.quantization)
return quant_cls.from_config({})

quant_config = get_quant_config(hf_config, component_model_path)
if quant_config is None and server_args.transformer_weights_path:
for safetensors_file in safetensors_list:
Expand Down
12 changes: 12 additions & 0 deletions python/sglang/multimodal_gen/runtime/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ class ServerArgs:

disable_autocast: bool | None = None

# Explicit quantization method override (e.g. "mxfp8", "fp8", "modelslim").
# When set, the transformer loader will use this instead of auto-detection.
quantization: str | None = None

# Quantization / Nunchaku SVDQuant configuration
nunchaku_config: NunchakuSVDQuantArgs | NunchakuConfig | None = field(
default_factory=NunchakuSVDQuantArgs, repr=False
Expand Down Expand Up @@ -771,6 +775,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="Disable autocast for denoising loop and vae decoding in pipeline sampling",
)

parser.add_argument(
"--quantization",
type=str,
default=None,
help='Quantization method override (e.g. "mxfp8", "fp8", "modelslim"). '
"When set, the transformer loader will use this instead of auto-detection.",
)

# Nunchaku SVDQuant quantization parameters
NunchakuSVDQuantArgs.add_cli_args(parser)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@


def find_quant_modelslim_config(model_config, component_model_path):
# Try exact name first, then glob for variant filenames (e.g. after repack)
quant_config_file = Path(component_model_path, "quant_model_description.json")
if not quant_config_file.is_file():
candidates = sorted(
Path(component_model_path).glob("quant_model_description*.json")
)
quant_config_file = candidates[0] if candidates else None

quant_cfg = None
if quant_config_file.is_file():
if quant_config_file is not None and Path(quant_config_file).is_file():
with open(quant_config_file) as f:
quant_cfg = json.load(f)
# This field is required for flagless model loading but is not present in
Expand Down
Loading
Loading