diff --git a/docs/advanced_features/quantization.md b/docs/advanced_features/quantization.md index 3610713e1a8d..90715a908ea7 100644 --- a/docs/advanced_features/quantization.md +++ b/docs/advanced_features/quantization.md @@ -353,6 +353,8 @@ python3 -m sglang.launch_server \ Our team is working on supporting more online quantization methods. SGLang will soon support methods including but not limited to `["awq", "gptq", "marlin", "gptq_marlin", "awq_marlin", "bitsandbytes", "gguf"]`. +### torchao online quantization method + SGLang also supports quantization methods based on [torchao](https://github.com/pytorch/ao). You can simply specify `--torchao-config` in the command line to support this feature. For example, if you want to enable `int4wo-128` for model `meta-llama/Meta-Llama-3.1-8B-Instruct`, you can launch the server with the following command: ```bash @@ -374,6 +376,12 @@ python3 -m sglang.launch_server \ --port 30000 --host 0.0.0.0 ``` +### `quark_int4fp8_moe` online quantization method + +SGLang running on AMD GPUs (CDNA3 or CDNA4 architecture) supports the quantization method `--quantization quark_int4fp8_moe`, that will replace [MoE layers](https://github.com/sgl-project/sglang/blob/v0.4.8/python/sglang/srt/layers/moe/fused_moe_triton/layer.py#L271) originally in high precision (bfloat16, float16 or float32) to use weights dynamically quantized to int4, that are upcasted to float8 during inference to run compute in float8 precision with activations dynamically quantized on the fly to float8. + +Other layers (e.g. projections in the attention layers) have their weights quantized online to float8 directly. 
+ ## Reference - [GPTQModel](https://github.com/ModelCloud/GPTQModel) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index aa10cb08d653..65d737e395aa 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -716,6 +716,7 @@ def _verify_quantization(self) -> None: "quark", "mxfp4", "auto-round", + "quark_int4fp8_moe", ] optimized_quantization_methods = [ "fp8", diff --git a/python/sglang/srt/layers/int4fp8_utils.py b/python/sglang/srt/layers/int4fp8_utils.py new file mode 100644 index 000000000000..92299b4c36fc --- /dev/null +++ b/python/sglang/srt/layers/int4fp8_utils.py @@ -0,0 +1,73 @@ +""" +Common utilities for quark. +""" + +import logging +from typing import Tuple + +import torch + +logger = logging.getLogger(__name__) + + +def quantize_fp8_scale_tensorwise(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + FP8_MAX = 448.0 + scale = w.abs().amax().float() / FP8_MAX + scaled = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn) + return scaled, scale + + +def quantize_int4_scale_columnwise( + w: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + S4_MAX = 7 + w_flat = w.reshape(-1, w.shape[-1]).float() + scale = w_flat.abs().amax(axis=-1) / S4_MAX + scaled = torch.round(w_flat / scale[:, None]).to(torch.int8).clamp(-S4_MAX, S4_MAX) + return scaled.reshape(w.shape), scale.reshape(w.shape[:-1]) + + +def pack_int4_to_int32(to_pack: torch.Tensor, reorder: bool = True) -> torch.Tensor: + if to_pack.ndim > 2: + raise ValueError( + "Pack: Only supports tensors with dimensions not greater than 2." 
+ ) + + if reorder: + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + else: + order_map = [0, 1, 2, 3, 4, 5, 6, 7] + pack_num = 8 + if to_pack.ndim == 2: + packed = torch.zeros( + to_pack.shape[0], + to_pack.shape[1] // pack_num, + dtype=torch.int32, + device=to_pack.device, + ) + new_c = to_pack.shape[1] // pack_num + for c in range(new_c): + for i in range(pack_num): + # Use -3 as an example, high_position is 11111111,cause bit_or generate errors, so we can't use int4 directly + packed_col = to_pack[:, c * pack_num + order_map[i]].to(torch.int32) + packed_col = packed_col & 0x0F + packed[:, c] = torch.bitwise_or( + packed[:, c], torch.bitwise_left_shift(packed_col, i * 4) + ) + elif to_pack.ndim == 0: + packed = to_pack.to(torch.int32) + else: + packed = torch.zeros( + to_pack.shape[0] // pack_num, dtype=torch.int32, device=to_pack.device + ) + new_c = to_pack.shape[0] // pack_num + for c in range(new_c): + for i in range(pack_num): + # Use -3 as an example, high_position is 11111111,cause bit_or generate errors, so we can't use int4 directly + packed_col = to_pack[c * pack_num + order_map[i]] + packed_col = packed_col & 0x0F + packed[c] = torch.bitwise_or( + packed[c], torch.bitwise_left_shift(packed_col, i * 4) + ) + + return packed.view(torch.uint32) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 428f3a261a94..1a33619a903b 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -66,6 +66,7 @@ "ModelOptFp4LinearMethod", "IPEXAWQLinearMethod", "PetitNvFp4LinearMethod", + "QuarkInt4Fp8LinearMethod", ] _is_cpu = is_cpu() diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index f4ec7d8c46a4..0e19019534f6 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -36,6 +36,7 @@ def override_quantization_method(self, *args, **kwargs): from 
sglang.srt.layers.quantization.petit import PetitNvFp4Config from sglang.srt.layers.quantization.qoq import QoQConfig from sglang.srt.layers.quantization.quark.quark import QuarkConfig +from sglang.srt.layers.quantization.quark_int4fp8_moe import QuarkInt4Fp8Config from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config @@ -68,6 +69,7 @@ def override_quantization_method(self, *args, **kwargs): "fbgemm_fp8": FBGEMMFp8Config, "quark": QuarkConfig, "auto-round": AutoRoundConfig, + "quark_int4fp8_moe": QuarkInt4Fp8Config, } diff --git a/python/sglang/srt/layers/quantization/quark_int4fp8_moe.py b/python/sglang/srt/layers/quantization/quark_int4fp8_moe.py new file mode 100644 index 000000000000..7fbbc61131d9 --- /dev/null +++ b/python/sglang/srt/layers/quantization/quark_int4fp8_moe.py @@ -0,0 +1,443 @@ +import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +import torch +from tqdm import tqdm +from tqdm.std import EMA + +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.layers.int4fp8_utils import ( + pack_int4_to_int32, + quantize_fp8_scale_tensorwise, + quantize_int4_scale_columnwise, +) +from sglang.srt.layers.moe import MoeRunnerConfig +from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod +from sglang.srt.utils import BAR_FORMAT, is_hip, set_weight_attrs + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import DispatchOutput + +_is_hip = is_hip() + + +if _is_hip: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + from aiter.ops.shuffle import shuffle_weight + + ON_GFX950 = "gfx950" in torch.cuda.get_device_properties("cuda").gcnArchName + +logger = logging.getLogger(__name__) + + 
+def tqdm_reset_no_print(tqdm_bar: tqdm, total=None): + tqdm_bar.n = 0 + if total is not None: + tqdm_bar.total = total + if tqdm_bar.disable: + return + tqdm_bar.last_print_n = 0 + tqdm_bar.last_print_t = tqdm_bar.start_t = tqdm_bar._time() + tqdm_bar._ema_dn = EMA(tqdm_bar.smoothing) + tqdm_bar._ema_dt = EMA(tqdm_bar.smoothing) + tqdm_bar._ema_miniters = EMA(tqdm_bar.smoothing) + + +class QuarkInt4Fp8Config(QuantizationConfig): + """Config class for Quark Quantization. + + - Weight: static, per-channel, symmetric + - Activation: dynamic, per-token, symmetric + """ + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ): + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + self.activation_scheme = activation_scheme + + if activation_scheme != "dynamic": + raise NotImplementedError( + "QuarkInt4Fp8Config only supports activation_scheme='dynamic'." + ) + + self.weight_block_size = None + + self.num_quant_layers = 0 + + tp_rank = get_tensor_model_parallel_rank() + + # The weight iterator already has a progress bar on rank=0, account for that. 
+ position = 1 + tqdm._get_free_pos() + self.online_quant_progress_bar = tqdm( + total=0, + desc=f"Online quark_int4fp8_moe quantization on rank={tp_rank}", + position=position, + bar_format=BAR_FORMAT, + mininterval=2.0, + ) + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_name(cls) -> str: + return "quark_int4fp8_moe" + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "QuarkInt4Fp8Config": + return cls() + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + # TODO: fix circular imports issues in sglang forcing us to import here instead of at + # the top of file. + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + + if isinstance(layer, LinearBase): + return Fp8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return QuarkInt4Fp8MoEMethod(self) + + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class QuarkInt4Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for INT4FP8. + + Supports loading BF16/FP16 checkpoints, quantizing down to INT4, and dequantizing to FP8 during inference. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config): + self.quant_config = quant_config + + self.online_quant_progress_bar = self.quant_config.online_quant_progress_bar + + self.tp_rank = get_tensor_model_parallel_rank() + + if not _is_hip: + raise NotImplementedError( + "The quark_int4fp8_moe online quantization scheme is only supported on AMD GPUs." 
+ ) + + def get_weight_loader(self, layer, original_weight_loader): + def online_int4_fp8_weight_loader( + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + ): + if shard_id in ["w1", "w3"]: + shard_size = self.w13_shard_size + else: + shard_size = self.w2_shard_size + + original_use_presharded_weights = layer.use_presharded_weights + + if not layer.use_presharded_weights: + # In case the model is not pre-sharded (most checkpoints on HF Hub), + # we shard the model here in order to run online quantization on + # already sharded weights. + # Some models, such as `lmzheng/grok-1`, are already sharded. + layer.use_presharded_weights = True + + if shard_id in ["w1", "w3"]: + shard_dim = 0 + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * self.tp_rank, shard_size + ) + else: + shard_dim = 1 + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * self.tp_rank, shard_size + ) + + # We want to run online quantization on-device for speed purposes. 
+ loaded_weight = loaded_weight.to(param.device) + + _, fp8_scale = quantize_fp8_scale_tensorwise(loaded_weight) + + int4_w, int4_scale = quantize_int4_scale_columnwise(loaded_weight) + + int4_w = pack_int4_to_int32(int4_w) + int4_scale /= fp8_scale + + if shard_id in ["w1", "w3"]: + if shard_id == "w1": + shard_slice = slice(0, shard_size) + idx = 0 + else: + shard_slice = slice(shard_size, 2 * shard_size) + idx = 1 + + assert param[expert_id][shard_slice].dtype == int4_w.dtype + + assert ( + layer.w13_int4_scale[expert_id][shard_slice].shape + == int4_scale.shape + ) + assert ( + layer.w13_int4_scale[expert_id][shard_slice].dtype + == int4_scale.dtype + ) + + layer.w13_int4_scale[expert_id][shard_slice].copy_(int4_scale) + + assert layer.w13_fp8_scale[expert_id][idx].shape == fp8_scale.shape + assert layer.w13_fp8_scale[expert_id][idx].dtype == fp8_scale.dtype + + layer.w13_fp8_scale[expert_id][idx].copy_(fp8_scale) + else: + assert param[expert_id].dtype == int4_w.dtype + assert param[expert_id].shape == int4_w.shape + + assert layer.w2_int4_scale[expert_id].shape == int4_scale.shape + assert layer.w2_int4_scale[expert_id].dtype == int4_scale.dtype + + layer.w2_int4_scale[expert_id].copy_(int4_scale) + + assert layer.w2_fp8_scale[expert_id].shape == fp8_scale.shape + assert layer.w2_fp8_scale[expert_id].dtype == fp8_scale.dtype + + layer.w2_fp8_scale[expert_id].copy_(fp8_scale) + + original_weight_loader( + param, + int4_w, + shard_id=shard_id, + weight_name=weight_name, + expert_id=expert_id, + ) + + # Reset `use_presharded_weights` as the same layer may load several different weights. 
+ layer.use_presharded_weights = original_use_presharded_weights + + self.online_quant_progress_bar.update(1) + + return online_int4_fp8_weight_loader + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # TODO: fix circular imports issues in sglang forcing us to import here instead of at + # the top of file. + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + # print("intermediate_size_per_partition", intermediate_size_per_partition) + # fused moe logic already hands TP logic. + self.w13_shard_size = intermediate_size_per_partition + self.w2_shard_size = intermediate_size_per_partition + + assert "weight_loader" in extra_weight_attrs + original_weight_loader = extra_weight_attrs.get("weight_loader") + + online_int4fp8_weight_loader = self.get_weight_loader( + layer, original_weight_loader + ) + extra_weight_attrs["weight_loader"] = online_int4fp8_weight_loader + + params_dtype = torch.uint32 + # WEIGHTS + # INT4 MoE weight - INT32 packed + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 8, + dtype=params_dtype, + ), + requires_grad=False, + ) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 8, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ w13_fp8_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_fp8_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_fp8_scale", w13_fp8_scale) + layer.register_parameter("w2_fp8_scale", w2_fp8_scale) + + if _is_hip: + w13_int4_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_int4_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_int4_scale", w13_int4_scale) + layer.register_parameter("w2_int4_scale", w2_int4_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + + set_weight_attrs(w13_fp8_scale, extra_weight_attrs) + set_weight_attrs(w2_fp8_scale, extra_weight_attrs) + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + set_weight_attrs(w13_int4_scale, extra_weight_attrs) + set_weight_attrs(w2_int4_scale, extra_weight_attrs) + + w13_input_scale = None + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = None + layer.register_parameter("w2_input_scale", w2_input_scale) + + # Loading from the checkpoint w1, w2, w3 times the number of experts. + total = self.online_quant_progress_bar.total + num_experts * 3 + tqdm_reset_no_print(self.online_quant_progress_bar, total=total) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if _is_hip and not ON_GFX950: + # CDNA3 does not support OCP FP8E4M3FN, but uses FP8E4M3FNUZ. + # CDNA4 supports OCP FP8E4M3FN. 
+ layer.w13_int4_scale *= 0.5 + layer.w2_int4_scale *= 0.5 + + layer.w13_fp8_scale *= 2.0 + layer.w2_fp8_scale *= 2.0 + + # TODO: and use_aiter_moe: add after triton kernel added + # INT4-FP8 (INT4 MoE Weight, FP8 Compute) + # Weight Permutation + layer.w13_weight = torch.nn.Parameter( + shuffle_weight(layer.w13_weight.data, (16, 16)), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + shuffle_weight(layer.w2_weight.data, (16, 16)), + requires_grad=False, + ) + torch.cuda.empty_cache() + + # INT4-FP8 : offset INT4 w13_int4_scale to single w13_fp8_scale + # Fp8 moe kernel needs single fp8 w13_fp8_scale for w13 per expert. + # We won't do requant each expert's fp8 weight (not direct available), + # instead we adjust half of INT4 w13_int4_scale numbers + assert layer.w13_fp8_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_fp8_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + max_w13_scale_fp8 = max_w13_scales[expert_id] + for shard_id in range(2): + if layer.w13_fp8_scale[expert_id][shard_id] != max_w13_scale_fp8: + int4_rescale = ( + layer.w13_fp8_scale[expert_id][shard_id] / max_w13_scale_fp8 + ) + layer.w13_int4_scale[expert_id][ + start : start + shard_size + ] *= int4_rescale + start += shard_size + + layer.w13_fp8_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + + # special hack to asm_moe, which takes (weight_int4_scale * weight_scale) as post GEMM scaling + # optimal design - shall apply per-column weight_int4_scale before GEMM, and weight_scale post + for expert_id in range(layer.num_experts): + layer.w13_int4_scale[expert_id] *= max_w13_scales[expert_id] + layer.w2_int4_scale[expert_id] *= layer.w2_fp8_scale[expert_id] + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply( + self, + layer: torch.nn.Module, + 
dispatch_output: "DispatchOutput", + ) -> torch.Tensor: + # TODO: fix circular imports issues in sglang forcing us to import here instead of at + # the top of file. + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + topk_output = dispatch_output.topk_output + moe_runner_config = self.moe_runner_config + + # TODO: add triton kernel and add check get_bool_env_var("CK_MOE") + assert ( + not moe_runner_config.no_combine + ), f"no_combine={moe_runner_config.no_combine} is not supported." + + output = fused_moe( + dispatch_output.hidden_states, + layer.w13_weight, + layer.w2_weight, + topk_output.topk_weights, + topk_output.topk_ids, + quant_type=QuantType.per_Token, + w1_scale=layer.w13_int4_scale, + w2_scale=layer.w2_int4_scale, + activation=( + ActivationType.Silu + if moe_runner_config.activation == "silu" + else ActivationType.Gelu + ), + ) + + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py index 29763bc103c0..f6cabe6dba18 100644 --- a/python/sglang/srt/model_loader/utils.py +++ b/python/sglang/srt/model_loader/utils.py @@ -87,7 +87,13 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. 
- mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"] + mixtral_supported = [ + "fp8", + "compressed-tensors", + "gptq_marlin", + "awq_marlin", + "quark_int4fp8_moe", + ] if ( model_config.quantization is not None diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index f483774f2347..73ed8721ac2f 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -44,7 +44,12 @@ ci_download_with_validation_and_retry, ci_validate_and_cleanup_local_snapshot, ) -from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once +from sglang.srt.utils import ( + BAR_FORMAT, + find_local_repo_dir, + log_info_on_rank0, + print_warning_once, +) from sglang.utils import is_in_ci try: @@ -608,13 +613,6 @@ def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[s return hf_weights_files -# explicitly use pure text format, with a newline at the end -# this makes it impossible to see the animation in the progress bar -# but will avoid messing up with ray or multiprocessing, which wraps -# each line of output with some prefix. 
-_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 - - def np_cache_weights_iterator( model_name_or_path: str, cache_dir: Optional[str], @@ -642,7 +640,8 @@ def np_cache_weights_iterator( hf_weights_files, desc="Loading np_cache checkpoint shards", disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + bar_format=BAR_FORMAT, + position=tqdm._get_free_pos(), ): state = torch.load(bin_file, map_location="cpu", weights_only=True) for name, param in state.items(): @@ -699,7 +698,8 @@ def safetensors_weights_iterator( hf_weights_files, desc="Loading safetensors checkpoint shards", disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + bar_format=BAR_FORMAT, + position=tqdm._get_free_pos(), ): if disable_mmap: with open(st_file, "rb") as f: @@ -811,7 +811,7 @@ def _load_file(st_file: str): total=len(hf_weights_files), desc="Multi-thread loading shards", disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + bar_format=BAR_FORMAT, ) else: futures_iter = concurrent.futures.as_completed(futures) @@ -853,7 +853,8 @@ def pt_weights_iterator( hf_weights_files, desc="Loading pt checkpoint shards", disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + bar_format=BAR_FORMAT, + position=tqdm._get_free_pos(), ): state = _load_pt_file(bin_file) yield from state.items() @@ -880,7 +881,7 @@ def multi_thread_pt_weights_iterator( total=len(hf_weights_files), desc="Multi-thread loading pt checkpoint shards", disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + bar_format=BAR_FORMAT, ) else: futures_iter = concurrent.futures.as_completed(futures) @@ -1033,7 +1034,8 @@ def runai_safetensors_weights_iterator( hf_weights_files, desc="Loading safetensors using Runai Model Streamer", disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + bar_format=BAR_FORMAT, + position=tqdm._get_free_pos(), ): streamer.stream_file(st_file) yield from streamer.get_tensors() diff --git a/python/sglang/srt/server_args.py 
b/python/sglang/srt/server_args.py index 261895b01c81..3ca37208ebbc 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -109,6 +109,7 @@ "auto-round", "compressed-tensors", # for Ktransformers "modelslim", # for NPU + "quark_int4fp8_moe", ] SPECULATIVE_DRAFT_MODEL_QUANTIZATION_CHOICES = [*QUANTIZATION_CHOICES, "unquant"] diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 8560246c686b..ac0ec98dfcfd 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -121,6 +121,12 @@ def is_hip() -> bool: builtins.FP8_E4M3_MAX = FP8_E4M3_MAX builtins.FP8_E4M3_MIN = FP8_E4M3_MIN +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + @lru_cache(maxsize=1) def is_cuda(): diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index add18b8ff717..7578f7939999 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -99,6 +99,8 @@ # TestFile("test_no_overlap_scheduler.py", 234), # Disabled temporarily and track in #7703 # TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701 # TestFile("test_wave_attention_backend.py", 150), # Disabled temporarily, see https://github.com/sgl-project/sglang/issues/11127 + # The time estimation for `test_int4fp8_moe.py` assumes `mistralai/Mixtral-8x7B-Instruct-v0.1` is already cached (running on 1xMI300X). 
+ TestFile("test_int4fp8_moe.py", 313), ], "per-commit-4-gpu-amd": [ TestFile("test_pp_single_node.py", 150), diff --git a/test/srt/test_int4fp8_moe.py b/test/srt/test_int4fp8_moe.py new file mode 100644 index 000000000000..bcc7d9d46be6 --- /dev/null +++ b/test/srt/test_int4fp8_moe.py @@ -0,0 +1,55 @@ +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestMixtralAccuracy(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "mistralai/Mixtral-8x7B-Instruct-v0.1" + cls.base_url = DEFAULT_URL_FOR_TEST + + other_args = [ + "--tp", + "2", + "--mem-fraction-static", + "0.9", + "--context-length", + "38768", + "--quantization", + "quark_int4fp8_moe", + # The default aiter attention backend raises segmentation faults and other errors - as quark_int4fp8_moe is not related to attention, let's just use triton here. + "--attention-backend", + "triton", + ] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=45 * 60, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=8, + data_path=None, + num_questions=1400, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.56)