diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index f8c67c8f4e44..dc5bb691ab7d 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -74,7 +74,7 @@ set_gpu_proc_affinity, suppress_other_loggers, ) - +from rpdTracerControl import rpdTracerControl @dataclasses.dataclass class BenchArgs: @@ -89,6 +89,8 @@ class BenchArgs: log_decode_step: int = 0 profile: bool = False profile_filename_prefix: str = "profile" + enable_prefill_prof: bool = False + enable_decode_prof: bool = False @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -123,6 +125,14 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Prefix of the profiling file names. The full profiling result file(s) be " '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"', ) + parser.add_argument( + "--enable-decode-prof", + action='store_true', + help="enable decode profiler.") + parser.add_argument( + "--enable-prefill-prof", + action='store_true', + help="enable prefill profiler.") @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -327,6 +337,7 @@ def synchronize(device): def latency_test_run_once( + is_warm_up, enable_prefill_prof, enable_decode_prof, tp_rank, run_name, model_runner, rank_print, @@ -373,8 +384,14 @@ def latency_test_run_once( # Prefill synchronize(device) tic = time.time() + if enable_prefill_prof and not is_warm_up and tp_rank == 0: + print("Start profile Prefill") + prefill_profile = rpdTracerControl() + prefill_profile.start() next_token_ids, _, batch = extend(reqs, model_runner) synchronize(device) + if enable_prefill_prof and not is_warm_up and tp_rank == 0: + prefill_profile.stop() prefill_latency = time.time() - tic tot_latency += prefill_latency throughput = input_len * batch_size / prefill_latency @@ -386,9 +403,15 @@ def latency_test_run_once( # Decode decode_latencies = [] + if enable_decode_prof and not is_warm_up and tp_rank == 0: + print("Start profile Decode") + # Create first instance (this loads the profiler and creates the file) + decode_profile = rpdTracerControl() + decode_profile.start() for i in range(output_len - 1): synchronize(device) tic = time.time() + next_token_ids, _ = decode(next_token_ids, batch, model_runner) synchronize(device) latency = time.time() - tic @@ -399,7 +422,8 @@ def latency_test_run_once( rank_print( f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" ) - + if enable_decode_prof and not is_warm_up and tp_rank == 0: + decode_profile.stop() if profile: profiler.stop() profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz" @@ -452,6 +476,10 @@ def latency_test( # Warm up rank_print("Warmup ...") latency_test_run_once( + True, + bench_args.enable_prefill_prof, + bench_args.enable_decode_prof, + tp_rank, bench_args.run_name, model_runner, rank_print, @@ -474,6 +502,10 @@ def latency_test( ): reqs = prepare_synthetic_inputs_for_latency_test(bs, il) ret = latency_test_run_once( + False, + bench_args.enable_prefill_prof, + bench_args.enable_decode_prof, + tp_rank, bench_args.run_name, model_runner, rank_print, @@ -501,6 +533,12 @@ def latency_test( def main(server_args, bench_args): server_args.cuda_graph_max_bs = max(bench_args.batch_size) + if bench_args.enable_prefill_prof or bench_args.enable_decode_prof: + # Optionally call this class method before creating first instance + rpdTracerControl.setFilename(name = "trace.rpd", append=False) + + # Create first instance (this loads the profiler and creates the file) + profile = rpdTracerControl() _set_envs_and_config(server_args) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 12194ba181e3..a33cf691fa52 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -18,7 +18,7 @@ QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs +from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs if torch.cuda.is_available(): from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts @@ -30,9 +30,7 @@ _is_hip = is_hip() if _is_hip: - from aiter import ActivationType - from aiter.fused_moe_bf16_asm import ck_moe_2stages - from aiter.ops.shuffle import shuffle_weight + from aiter import ck_moe logger = logging.getLogger(__name__) @@ -104,14 +102,14 @@ def create_weights( set_weight_attrs(w2_weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"): + if _is_hip and get_bool_env_var("CK_MOE"): layer.w13_weight = torch.nn.Parameter( - shuffle_weight(layer.w13_weight.data, (16, 16)), + permute_weight(layer.w13_weight.data), requires_grad=False, ) torch.cuda.empty_cache() layer.w2_weight = torch.nn.Parameter( - shuffle_weight(layer.w2_weight.data, (16, 16)), + permute_weight(layer.w2_weight.data), requires_grad=False, ) torch.cuda.empty_cache() @@ -133,7 +131,6 @@ def apply( apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: return self.forward( x=x, @@ -150,7 +147,6 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, inplace=inplace, no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, ) def forward_cuda( @@ -169,7 +165,6 @@ def forward_cuda( apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: topk_weights, topk_ids = select_experts( hidden_states=x, @@ -181,20 +176,23 @@ def forward_cuda( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, ) - if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"): + if _is_hip and get_bool_env_var("CK_MOE"): assert not no_combine, "unsupported" - return ck_moe_2stages( + return ck_moe( x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids, - activation=( - ActivationType.Silu if activation == "silu" else ActivationType.Gelu - ), + None, + None, + None, + None, + 32, + None, + activation, ) else: return fused_experts( @@ -286,7 +284,6 @@ def __init__( use_presharded_weights: bool = False, inplace: bool = True, no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, ): super().__init__() @@ -296,7 +293,6 @@ def __init__( self.tp_size = ( tp_size if tp_size is not None else get_tensor_model_parallel_world_size() ) - self.routed_scaling_factor = routed_scaling_factor self.top_k = top_k self.num_experts = num_experts assert intermediate_size % self.tp_size == 0 @@ -525,7 +521,7 @@ def weight_loader( # Case input scale: input_scale loading is only supported for fp8 if "input_scale" in weight_name: # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD) - if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"): + if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): loaded_weight = loaded_weight * 2.0 # this is needed for compressed-tensors only @@ -567,7 +563,7 @@ def weight_loader( quant_method = getattr(param, "quant_method", None) if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD) - if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"): + if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): loaded_weight = loaded_weight * 0.5 self._load_per_channel_weight_scale( @@ -590,7 +586,7 @@ def weight_loader( ) elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD) - if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"): + if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): loaded_weight = loaded_weight * 2.0 self._load_per_tensor_weight_scale( @@ -641,7 +637,6 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): correction_bias=self.correction_bias, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, - routed_scaling_factor=self.routed_scaling_factor, ) if self.reduce_results and self.tp_size > 1: diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 521ba7deb916..abb105af432e 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -8,6 +8,15 @@ from torch.nn import Module from torch.nn.parameter import Parameter +from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.utils import ( + all_close_1d, + convert_to_channelwise, + is_layer_skipped, + per_tensor_dequantize, + requantize_with_max_scale, +) + try: from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, @@ -18,12 +27,11 @@ except ImportError: MARLIN_FP8_AVAILABLE = False - def dummy_func(*args, **kwargs): - raise ImportError( - "marlin FP8 requires some operators from vllm. Please install vllm." - ) + def apply_fp8_marlin_linear(*args, **kwargs): + raise ImportError("vllm is not installed") - apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func + def prepare_fp8_layer_for_marlin(*args, **kwargs): + raise ImportError("vllm is not installed") from sglang.srt.distributed import get_tensor_model_parallel_world_size @@ -41,12 +49,7 @@ def dummy_func(*args, **kwargs): QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_kernel import ( - fp8_dtype, - is_fp8_fnuz, - per_token_group_quant_fp8, - scaled_fp8_quant, -) +from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, apply_w8a8_block_fp8_linear, @@ -54,41 +57,44 @@ def dummy_func(*args, **kwargs): input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) -from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod -from sglang.srt.layers.quantization.utils import ( - all_close_1d, - convert_to_channelwise, - is_layer_skipped, - per_tensor_dequantize, - requantize_with_max_scale, -) + +#from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( +# #all_close_1d, +# apply_fp8_linear, +# #convert_to_channelwise, +# #cutlass_fp8_supported, +# #per_tensor_dequantize, +# #requantize_with_max_scale, +#) + from sglang.srt.utils import ( get_bool_env_var, is_cuda, is_hip, - log_info_on_rank0, + permute_weight, print_warning_once, set_weight_attrs, ) -_is_hip = is_hip() -_is_cuda = is_cuda() - -_is_fp8_fnuz = is_fp8_fnuz() +ACTIVATION_SCHEMES = ["static", "dynamic"] -use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT") -use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE") +_is_hip = is_hip() if _is_hip: - from aiter import ActivationType, QuantType - from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages + from aiter.fused_moe_bf16_asm import asm_moe from aiter.ops.shuffle import shuffle_weight + from aiter import ActivationType + from aiter.fused_moe_bf16_asm import ( + ck_moe_2stages_win4, + ck_moe_2stages, + ) -if not _is_cuda: - from vllm._custom_ops import scaled_fp8_quant +_is_cuda = is_cuda() - -ACTIVATION_SCHEMES = ["static", "dynamic"] +if _is_cuda: + from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant +else: + from vllm import _custom_ops as vllm_ops logger = logging.getLogger(__name__) @@ -105,7 +111,10 @@ def __init__( ) -> None: self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: - log_info_on_rank0(logger, "Detected fp8 checkpoint.") + logger.warning( + "Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change." + ) if activation_scheme not in ACTIVATION_SCHEMES: raise ValueError(f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme @@ -235,7 +244,7 @@ def create_weights( f"{input_size_per_partition} is not divisible by " f"weight quantization block_k = {block_k}." ) - # Required by column parallel or enabling merged weights + # Required by collum parallel or enabling merged weights if ( tp_size > 1 and output_size // output_size_per_partition == tp_size ) or len(output_partition_sizes) > 1: @@ -248,6 +257,7 @@ def create_weights( ) layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.orig_dtype = params_dtype @@ -308,28 +318,31 @@ def create_weights( layer.register_parameter("input_scale", None) def process_weights_after_loading(self, layer: Module) -> None: + # Block quant doesn't need to process weights after loading if self.block_quant: # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_fp8_fnuz: + if _is_hip: # activation_scheme: dynamic weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=layer.weight, weight_scale=layer.weight_scale_inv, input_scale=None, ) - + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + layer.weight_scale_inv = torch.nn.Parameter( + weight_scale, requires_grad=False + ) layer.input_scale = None else: - weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data - layer.weight = torch.nn.Parameter(weight, requires_grad=False) - layer.weight_scale_inv = torch.nn.Parameter( - weight_scale, requires_grad=False - ) + layer.weight = torch.nn.Parameter( + layer.weight.data, requires_grad=False + ) + layer.weight_scale_inv = torch.nn.Parameter( + layer.weight_scale_inv.data, requires_grad=False + ) return - layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) - # If checkpoint not serialized fp8, quantize the weights. if not self.quant_config.is_checkpoint_fp8_serialized: if self.cutlass_fp8_supported or self.use_marlin: @@ -369,7 +382,7 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight weight_scale = layer.weight_scale # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_fp8_fnuz: + if _is_hip: weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=weight_scale, @@ -393,9 +406,12 @@ def process_weights_after_loading(self, layer: Module) -> None: ) if self.use_marlin: - prepare_fp8_layer_for_marlin(layer) - # Activations not quantized for marlin. - del layer.input_scale + try: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. + del layer.input_scale + except ImportError: + self.use_marlin = False def apply( self, @@ -405,15 +421,18 @@ def apply( ) -> torch.Tensor: if self.use_marlin: - return apply_fp8_marlin_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - ) + try: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + except ImportError: + self.use_marlin = False if self.block_quant: return apply_w8a8_block_fp8_linear( @@ -483,7 +502,11 @@ def create_weights( from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn + params_dtype = ( + torch.uint32 + if get_bool_env_var("USE_INT4_WEIGHT") + else torch.float8_e4m3fn + ) tp_size = get_tensor_model_parallel_world_size() if self.block_quant: block_n, block_k = ( @@ -491,7 +514,7 @@ def create_weights( self.quant_config.weight_block_size[1], ) # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. - # Required by column parallel or enabling merged weights + # Required by collum parallel or enabling merged weights if intermediate_size % block_n != 0: raise ValueError( f"The output_size of gate's and up's weight = " @@ -508,7 +531,7 @@ def create_weights( ) # WEIGHTS - if _is_hip and use_hip_int4: + if get_bool_env_var("USE_INT4_WEIGHT"): # INT4 MoE weight - INT32 packed w13_weight = torch.nn.Parameter( torch.empty( @@ -580,7 +603,9 @@ def create_weights( layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) - if _is_hip: # and use_aiter_moe: TODO: add check back after triton kernel + if ( + _is_hip + ): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling w13_weight_scale1 = torch.nn.Parameter( torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32), @@ -607,7 +632,7 @@ def create_weights( set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) - if _is_hip and use_hip_int4: + if get_bool_env_var("USE_INT4_WEIGHT"): extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} ) @@ -639,14 +664,14 @@ def create_weights( layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: - if _is_hip and use_hip_int4: + if get_bool_env_var("USE_INT4_WEIGHT"): self.process_weights_hip_int4(layer) return # Block quant doesn't need to process weights after loading if self.block_quant: # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_fp8_fnuz: + if _is_hip: # activation_scheme: dynamic w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=layer.w13_weight, @@ -670,19 +695,20 @@ def process_weights_after_loading(self, layer: Module) -> None: ) layer.w2_input_scale = None - if _is_hip and use_aiter_moe: - # Pre-shuffle weights - layer.w13_weight.data = shuffle_weight( - layer.w13_weight.contiguous(), (16, 16) - ) - layer.w2_weight.data = shuffle_weight( - layer.w2_weight.contiguous(), (16, 16) - ) + if get_bool_env_var("CK_MOE"): + # Pre-shuffle weights + layer.w13_weight.data = shuffle_weight( + layer.w13_weight.contiguous(), (16, 16) + ) + layer.w2_weight.data = shuffle_weight( + layer.w2_weight.contiguous(), (16, 16) + ) return # If checkpoint is fp16 or bfloat16, quantize in place. if not self.quant_config.is_checkpoint_fp8_serialized: - # If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW) + # If ROCm, use float8_e4m3fnuz instead (MI300x HW) + fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) @@ -695,12 +721,20 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False, ) for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) + if _is_cuda: + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + else: + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) @@ -736,7 +770,7 @@ def process_weights_after_loading(self, layer: Module) -> None: ) # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_fp8_fnuz: + if _is_hip: # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( @@ -777,10 +811,18 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_weight[expert_id][start : start + shard_size, :], layer.w13_weight_scale[expert_id][shard_id], ) - ( - layer.w13_weight[expert_id][start : start + shard_size, :], - _, - ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + if _is_cuda: + ( + layer.w13_weight[expert_id][start : start + shard_size, :], + _, + ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + else: + ( + layer.w13_weight[expert_id][start : start + shard_size, :], + _, + ) = vllm_ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id] + ) start += shard_size layer.w13_weight_scale = torch.nn.Parameter( @@ -792,16 +834,18 @@ def process_weights_after_loading(self, layer: Module) -> None: return def process_weights_hip_int4(self, layer: Module): - # TODO: and use_aiter_moe: add after triton kernel added + # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added # INT4-FP8 (INT4 MoE Weight, FP8 Compute) # Weight Permutation layer.w13_weight = torch.nn.Parameter( - shuffle_weight(layer.w13_weight.data, (16, 16)), + permute_weight(layer.w13_weight.data), + #shuffle_weight(layer.w13_weight.data, (16, 16)), requires_grad=False, ) torch.cuda.empty_cache() layer.w2_weight = torch.nn.Parameter( - shuffle_weight(layer.w2_weight.data, (16, 16)), + permute_weight(layer.w2_weight.data), + #shuffle_weight(layer.w2_weight.data, (16, 16)), requires_grad=False, ) torch.cuda.empty_cache() @@ -839,21 +883,23 @@ def process_weights_hip_scale_padding(self, layer: Module): padding_size, # Avoid circular import ) - if use_aiter_moe: + if get_bool_env_var("CK_MOE"): layer.w13_weight = torch.nn.Parameter( - shuffle_weight(layer.w13_weight.data, (16, 16)), + permute_weight(layer.w13_weight.data), + #shuffle_weight(layer.w13_weight.data, (16, 16)), requires_grad=False, ) torch.cuda.empty_cache() layer.w2_weight = torch.nn.Parameter( - shuffle_weight(layer.w2_weight.data, (16, 16)), + permute_weight(layer.w2_weight.data), + #shuffle_weight(layer.w2_weight.data, (16, 16)), requires_grad=False, ) torch.cuda.empty_cache() - # ROCm (use_aiter_moe): using column-wise scaling + # ROCm (CK_MOE): using column-wise scaling layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1) layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1) - elif get_bool_env_var("SGLANG_MOE_PADDING"): + elif get_bool_env_var("MOE_PADDING"): # If ROCm, apply weight padding (min. Mem channel contention) only if set layer.w13_weight = torch.nn.Parameter( F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), @@ -882,7 +928,6 @@ def apply( apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.topk import select_experts @@ -898,79 +943,28 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) - - if _is_hip: - ret = self.maybe_apply_hip_fused_experts( - layer, - x, - topk_weights, - topk_ids, - activation, - no_combine, - ) - if ret is not None: - return ret - - # Expert fusion with FP8 quantization - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace and not no_combine, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - w1_scale=( - layer.w13_weight_scale_inv - if self.block_quant - else layer.w13_weight_scale - ), - w2_scale=( - layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale - ), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - no_combine=no_combine, ) - def maybe_apply_hip_fused_experts( - self, - layer: torch.nn.Module, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - no_combine: bool = False, - ) -> Optional[torch.Tensor]: - if use_hip_int4: - # TODO: add triton kernel and add check use_aiter_moe + if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): + # TODO: add triton kernel and add check get_bool_env_var("CK_MOE") assert not no_combine, f"{no_combine=} is not supported." - return ck_moe_2stages( + return ck_moe_2stages_win4( x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids, - QuantType.per_Token, layer.w13_weight_scale1, layer.w2_weight_scale1, - activation=( - ActivationType.Silu if activation == "silu" else ActivationType.Gelu - ), + activation=ActivationType.Silu if activation=="silu" else ActivationType.Gelu, ) - - if use_aiter_moe: + if _is_hip and get_bool_env_var("CK_MOE"): + # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being. + #assert ( + # activation == "silu" + #), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE" assert not no_combine, f"{no_combine=} is not supported." if self.block_quant: - # TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being. - assert ( - activation == "silu" - ), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe" return asm_moe( x, layer.w13_weight, @@ -983,22 +977,44 @@ def maybe_apply_hip_fused_experts( expert_mask=None, ) else: + return ck_moe_2stages( x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids, - QuantType.per_Token, layer.w13_weight_scale1, layer.w2_weight_scale1, - activation=( - ActivationType.Silu - if activation == "silu" - else ActivationType.Gelu - ), + activation=ActivationType.Silu if activation=="silu" else ActivationType.Gelu, ) - return None + else: + # Expert fusion with FP8 quantization + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace and not no_combine, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + w1_scale=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_scale=( + layer.w2_weight_scale_inv + if self.block_quant + else layer.w2_weight_scale + ), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + no_combine=no_combine, + ) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index a8cde8e09c02..eff7b72a49af 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -563,11 +563,11 @@ def __init__( debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"] warnings.filterwarnings("ignore", category=FutureWarning) - if get_tensor_model_parallel_rank() == 0: - logger.info( - f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, " - f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B" - ) + # if get_tensor_model_parallel_rank() == 0: + # logger.info( + # f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, " + # f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B" + # ) def forward( self,