Skip to content
20 changes: 1 addition & 19 deletions python/sglang/srt/managers/detokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from typing import Dict, List, Optional, Tuple, Union

import psutil
import pybase64
import setproctitle
import zmq

Expand Down Expand Up @@ -319,30 +318,13 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput):

return output_strs

def _extract_routed_experts(
self, recv_obj: BatchTokenIDOutput
) -> list[str | None] | None:
routed_experts = None
if recv_obj.routed_experts is not None:
routed_experts = [
(
pybase64.b64encode(routed_experts.numpy().tobytes()).decode("utf-8")
if routed_experts is not None
else None
)
for routed_experts in recv_obj.routed_experts
]
return routed_experts

def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
# If handling idle batch, set output_strs to [].
output_strs = (
self._decode_batch_token_id_output(recv_obj)
if len(recv_obj.rids) > 0
else []
)
routed_experts = self._extract_routed_experts(recv_obj)

return BatchStrOutput(
rids=recv_obj.rids,
http_worker_ipcs=recv_obj.http_worker_ipcs,
Expand Down Expand Up @@ -370,7 +352,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
output_token_entropy_val=recv_obj.output_token_entropy_val,
output_hidden_states=recv_obj.output_hidden_states,
routed_experts=routed_experts,
routed_experts=recv_obj.routed_experts,
customized_info=recv_obj.customized_info,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
Expand Down
7 changes: 6 additions & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

import fastapi
import pybase64
import uvloop
import zmq
import zmq.asyncio
Expand Down Expand Up @@ -1597,7 +1598,11 @@ def _handle_batch_output(
if getattr(recv_obj, "output_hidden_states", None):
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
if getattr(recv_obj, "routed_experts", None):
meta_info["routed_experts"] = recv_obj.routed_experts[i]
routed_experts_tensor = recv_obj.routed_experts[i]
if routed_experts_tensor is not None:
meta_info["routed_experts"] = pybase64.b64encode(
routed_experts_tensor.numpy().tobytes()
).decode("utf-8")
if getattr(recv_obj, "customized_info", None):
for k, v in recv_obj.customized_info.items():
meta_info[k] = v[i]
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/utils/numa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _is_numa_available() -> bool:
return False

if not shutil.which("numactl") and envs.SGLANG_NUMA_BIND_V2.get():
logger.warning(
logger.debug(
"numactl command not found, skipping NUMA node configuration for GPU. Install numactl (e.g., apt-get install numactl) to enable automatic NUMA binding."
)
return False
Expand Down
11 changes: 10 additions & 1 deletion python/sglang/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,16 @@ def encode_image_base64(image_path: Union[str, bytes]):
elif isinstance(image_path, bytes):
return pybase64.b64encode(image_path).decode("utf-8")
else:
# image_path is PIL.WebPImagePlugin.WebPImageFile
import torch

if isinstance(image_path, torch.Tensor):
# Convert GPU-decoded image tensor (C, H, W) uint8 to PIL Image
from PIL import Image

tensor = image_path.cpu() if image_path.device.type != "cpu" else image_path
image_path = Image.fromarray(tensor.permute(1, 2, 0).numpy())

# image_path is a PIL Image
image = image_path
buffered = BytesIO()
image.save(buffered, format="PNG")
Expand Down
1 change: 1 addition & 0 deletions test/registered/language/test_srt_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def setUpClass(cls):
model_path=DEFAULT_MODEL_NAME_FOR_TEST,
cuda_graph_max_bs=4,
mem_fraction_static=0.7,
log_level="info",
)
sgl.set_default_backend(cls.backend)

Expand Down
39 changes: 16 additions & 23 deletions test/registered/rl/test_return_routed_experts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import asyncio
import json
import logging
import unittest
from typing import List

import aiohttp
import requests
import torch
from torch.nn.utils.rnn import pad_sequence

from sglang.benchmark.utils import download_and_cache_hf_file
from sglang.srt.layers.moe.routed_experts_capturer import (
extract_routed_experts_from_meta_info,
)
Expand All @@ -21,23 +22,18 @@
popen_launch_server,
)

register_cuda_ci(est_time=360, suite="stage-c-test-4-gpu-h100")
register_cuda_ci(est_time=200, suite="stage-b-test-2-gpu-large")
register_amd_ci(
est_time=360,
suite="stage-c-test-4-gpu-amd",
disabled="TP=4 DP=4 routed expert mismatch >15% on AMD; needs TP/DP tuning + concurrency reduction",
est_time=200,
suite="stage-b-test-2-gpu-large-amd",
disabled="TP=2 DP=2 routed expert mismatch >15% on AMD; needs TP/DP tuning + concurrency reduction",
)

SHAREGPT_URL = (
"https://huggingface.co/datasets/anon8231489123/"
"ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
)
SHAREGPT_REPO_ID = "anon8231489123/ShareGPT_Vicuna_unfiltered"
SHAREGPT_FILENAME = "ShareGPT_V3_unfiltered_cleaned_split.json"
logger = logging.getLogger(__name__)


@unittest.skip(
"Flaky in CI, need to be fixed and re-enabled. See https://github.com/sgl-project/sglang/issues/21266"
)
class TestReturnRoutedExperts(CustomTestCase):
# modified from test_hicache.py
@classmethod
Expand All @@ -50,31 +46,28 @@ def setUpClass(cls):
"--disable-cuda-graph",
"--disable-radix-cache",
"--tp",
4,
2,
"--dp",
4,
2,
"--enable-dp-attention",
]
cls.reference_args = [
"--enable-return-routed-experts",
"--enable-deterministic-inference",
"--tp",
4,
2,
"--dp",
4,
2,
"--enable-dp-attention",
]
cls.sampling_args = {
"temperature": 0,
}
# prepare ShareGPT dataset
try:
response = requests.get(SHAREGPT_URL, timeout=60)
response.raise_for_status()
data = response.json()
print(f"Dataset size: {len(data)}")
except requests.exceptions.RequestException as e:
raise Exception(f"Failed to download ShareGPT dataset: {e}") from e
dataset_path = download_and_cache_hf_file(SHAREGPT_REPO_ID, SHAREGPT_FILENAME)
with open(dataset_path) as f:
data = json.load(f)
print(f"Dataset size: {len(data)}")
cls.texts = []
for s in data:
if "conversations" in s and len(s["conversations"]) > 0:
Expand Down
Loading