Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions python/sglang/bench_one_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
set_gpu_proc_affinity,
suppress_other_loggers,
)

from rpdTracerControl import rpdTracerControl

@dataclasses.dataclass
class BenchArgs:
Expand All @@ -89,6 +89,8 @@ class BenchArgs:
log_decode_step: int = 0
profile: bool = False
profile_filename_prefix: str = "profile"
enable_prefill_prof: bool = False
enable_decode_prof: bool = False

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -123,6 +125,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Prefix of the profiling file names. The full profiling result file(s) be "
'"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
)
parser.add_argument(
"--enable-decode-prof",
action='store_true',
help="enable decode profiler.")
parser.add_argument(
"--enable-prefill-prof",
action='store_true',
help="enable prefill profiler.")

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
Expand Down Expand Up @@ -327,6 +337,7 @@ def synchronize(device):


def latency_test_run_once(
is_warm_up, enable_prefill_prof, enable_decode_prof, tp_rank,
run_name,
model_runner,
rank_print,
Expand Down Expand Up @@ -373,8 +384,14 @@ def latency_test_run_once(
# Prefill
synchronize(device)
tic = time.time()
if enable_prefill_prof and not is_warm_up and tp_rank == 0:
print("Start profile Prefill")
prefill_profile = rpdTracerControl()
prefill_profile.start()
next_token_ids, _, batch = extend(reqs, model_runner)
synchronize(device)
if enable_prefill_prof and not is_warm_up and tp_rank == 0:
prefill_profile.stop()
prefill_latency = time.time() - tic
tot_latency += prefill_latency
throughput = input_len * batch_size / prefill_latency
Expand All @@ -386,9 +403,15 @@ def latency_test_run_once(

# Decode
decode_latencies = []
if enable_decode_prof and not is_warm_up and tp_rank == 0:
print("Start profile Decode")
# Create first instance (this loads the profiler and creates the file)
decode_profile = rpdTracerControl()
decode_profile.start()
for i in range(output_len - 1):
synchronize(device)
tic = time.time()

next_token_ids, _ = decode(next_token_ids, batch, model_runner)
synchronize(device)
latency = time.time() - tic
Expand All @@ -399,7 +422,8 @@ def latency_test_run_once(
rank_print(
f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)

if enable_decode_prof and not is_warm_up and tp_rank == 0:
decode_profile.stop()
if profile:
profiler.stop()
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
Expand Down Expand Up @@ -452,6 +476,10 @@ def latency_test(
# Warm up
rank_print("Warmup ...")
latency_test_run_once(
True,
bench_args.enable_prefill_prof,
bench_args.enable_decode_prof,
tp_rank,
bench_args.run_name,
model_runner,
rank_print,
Expand All @@ -474,6 +502,10 @@ def latency_test(
):
reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
ret = latency_test_run_once(
False,
bench_args.enable_prefill_prof,
bench_args.enable_decode_prof,
tp_rank,
bench_args.run_name,
model_runner,
rank_print,
Expand Down Expand Up @@ -501,6 +533,12 @@ def latency_test(

def main(server_args, bench_args):
server_args.cuda_graph_max_bs = max(bench_args.batch_size)
if bench_args.enable_prefill_prof or bench_args.enable_decode_prof:
# Optionally call this class method before creating first instance
rpdTracerControl.setFilename(name = "trace.rpd", append=False)

# Create first instance (this loads the profiler and creates the file)
profile = rpdTracerControl()

_set_envs_and_config(server_args)

Expand Down
39 changes: 17 additions & 22 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs

if torch.cuda.is_available():
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
Expand All @@ -30,9 +30,7 @@
_is_hip = is_hip()

if _is_hip:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight
from aiter import ck_moe

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -104,14 +102,14 @@ def create_weights(
set_weight_attrs(w2_weight, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
if _is_hip and get_bool_env_var("CK_MOE"):
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
permute_weight(layer.w13_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
shuffle_weight(layer.w2_weight.data, (16, 16)),
permute_weight(layer.w2_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
Expand All @@ -133,7 +131,6 @@ def apply(
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return self.forward(
x=x,
Expand All @@ -150,7 +147,6 @@ def apply(
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)

def forward_cuda(
Expand All @@ -169,7 +165,6 @@ def forward_cuda(
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
Expand All @@ -181,20 +176,23 @@ def forward_cuda(
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)

if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
if _is_hip and get_bool_env_var("CK_MOE"):
assert not no_combine, "unsupported"
return ck_moe_2stages(
return ck_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
),
None,
None,
None,
None,
32,
None,
activation,
)
else:
return fused_experts(
Expand Down Expand Up @@ -286,7 +284,6 @@ def __init__(
use_presharded_weights: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
):
super().__init__()

Expand All @@ -296,7 +293,6 @@ def __init__(
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
self.routed_scaling_factor = routed_scaling_factor
self.top_k = top_k
self.num_experts = num_experts
assert intermediate_size % self.tp_size == 0
Expand Down Expand Up @@ -525,7 +521,7 @@ def weight_loader(
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0

# this is needed for compressed-tensors only
Expand Down Expand Up @@ -567,7 +563,7 @@ def weight_loader(
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 0.5

self._load_per_channel_weight_scale(
Expand All @@ -590,7 +586,7 @@ def weight_loader(
)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0

self._load_per_tensor_weight_scale(
Expand Down Expand Up @@ -641,7 +637,6 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
correction_bias=self.correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
)

if self.reduce_results and self.tp_size > 1:
Expand Down
Loading
Loading