Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/srtctl/benchmarks/sa_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,7 @@ def build_command(
str(b.random_range_ratio) if b.random_range_ratio is not None else "0.8",
str(b.num_prompts_mult) if b.num_prompts_mult is not None else "10",
str(b.num_warmup_mult) if b.num_warmup_mult is not None else "2",
b.custom_tokenizer or "",
str(b.use_chat_template).lower(),
]
return cmd
75 changes: 67 additions & 8 deletions src/srtctl/benchmarks/scripts/sa-bench/backend_request_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,10 +565,52 @@ def _fix_v5_tokenizer_components(tokenizer, model_name_or_path):
backend.decoder = raw.decoder


def _load_glm_moe_dsa_tokenizer(pretrained_model_name_or_path: str) -> "PreTrainedTokenizerFast":
"""Load GLM-Moe-Dsa / GLM-5 tokenizer directly from tokenizer.json.

Works around incompatibilities when the checkpoint was saved with
transformers 5.x (TokenizersBackend / list-style extra_special_tokens).
"""
import json
from pathlib import Path

from tokenizers import Tokenizer as RustTokenizer
from transformers import PreTrainedTokenizerFast

_SAFE_CONFIG_KEYS = (
"pad_token", "pad_token_id", "eos_token", "eos_token_id",
"bos_token", "bos_token_id", "unk_token", "unk_token_id",
"model_max_length", "padding_side", "truncation_side",
)

path = Path(pretrained_model_name_or_path)
tokenizer_json = path / "tokenizer.json"
if not tokenizer_json.exists():
raise FileNotFoundError(
f"Expected tokenizer.json at {tokenizer_json}. "
"GlmMoeDsaTokenizer loads from tokenizer.json only."
)

rust_tok = RustTokenizer.from_file(str(tokenizer_json))
init_kwargs = {}
config_path = path / "tokenizer_config.json"
if config_path.exists():
with open(config_path, encoding="utf-8") as f:
config = json.load(f)
for key in _SAFE_CONFIG_KEYS:
if key in config:
init_kwargs[key] = config[key]
if "extra_special_tokens" in config:
init_kwargs["additional_special_tokens"] = config["extra_special_tokens"]

return PreTrainedTokenizerFast(tokenizer_object=rust_tok, **init_kwargs)


def get_tokenizer(
pretrained_model_name_or_path: str,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
custom_tokenizer: str | None = None,
**kwargs,
) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
if pretrained_model_name_or_path is not None and not os.path.exists(pretrained_model_name_or_path):
Expand All @@ -587,14 +629,31 @@ def get_tokenizer(
"to use mistral tokenizer mode."
) from e
return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path))
else:
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
_fix_v5_tokenizer_components(tokenizer, pretrained_model_name_or_path)
return tokenizer
if custom_tokenizer:
if custom_tokenizer == "glm_moe_dsa":
return _load_glm_moe_dsa_tokenizer(pretrained_model_name_or_path)
from importlib import import_module
try:
module_path, class_name = custom_tokenizer.rsplit('.', 1)
module = import_module(module_path)
tokenizer_class = getattr(module, class_name)
return tokenizer_class.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
except (ValueError, ImportError, AttributeError) as e:
raise ValueError(
f"Failed to load custom_tokenizer '{custom_tokenizer}'. "
"Expected 'glm_moe_dsa' or 'module.path.ClassName'.") from e

tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
_fix_v5_tokenizer_components(tokenizer, pretrained_model_name_or_path)
return tokenizer


ASYNC_REQUEST_FUNCS = {
Expand Down
20 changes: 18 additions & 2 deletions src/srtctl/benchmarks/scripts/sa-bench/bench.sh
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,20 @@ DECODE_GPUS=${11:-0}
RANDOM_RANGE_RATIO=${12:-0.8}
NUM_PROMPTS_MULT=${13:-10}
NUM_WARMUP_MULT=${14:-2}
CUSTOM_TOKENIZER=${15:-}
USE_CHAT_TEMPLATE=${16:-true}

# Build optional custom tokenizer args
CUSTOM_TOKENIZER_ARGS=()
if [ -n "$CUSTOM_TOKENIZER" ]; then
CUSTOM_TOKENIZER_ARGS=(--custom-tokenizer "$CUSTOM_TOKENIZER")
fi

# Build optional chat template args
CHAT_TEMPLATE_ARGS=()
if [ "$USE_CHAT_TEMPLATE" = "true" ]; then
CHAT_TEMPLATE_ARGS=(--use-chat-template)
fi

# Parse endpoint into host:port
HOST=$(echo "$ENDPOINT" | sed 's|http://||' | cut -d: -f1)
Expand Down Expand Up @@ -121,7 +135,8 @@ for concurrency in "${CONCURRENCY_LIST[@]}"; do
--request-rate 250 \
--percentile-metrics ttft,tpot,itl,e2el \
--max-concurrency "$concurrency" \
--trust-remote-code
--trust-remote-code \
"${CUSTOM_TOKENIZER_ARGS[@]}"

num_prompts=$((concurrency * NUM_PROMPTS_MULT))

Expand Down Expand Up @@ -151,7 +166,8 @@ for concurrency in "${CONCURRENCY_LIST[@]}"; do
--percentile-metrics ttft,tpot,itl,e2el \
--max-concurrency "$concurrency" \
--trust-remote-code \
--use-chat-template \
"${CHAT_TEMPLATE_ARGS[@]}" \
"${CUSTOM_TOKENIZER_ARGS[@]}" \
--save-result --result-dir "$result_dir" --result-filename "$result_filename"
set +x

Expand Down
9 changes: 9 additions & 0 deletions src/srtctl/benchmarks/scripts/sa-bench/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ def main(args: argparse.Namespace):
tokenizer_id,
tokenizer_mode=tokenizer_mode,
trust_remote_code=args.trust_remote_code,
custom_tokenizer=args.custom_tokenizer,
)

if args.dataset is not None:
Expand Down Expand Up @@ -1279,6 +1280,14 @@ def main(args: argparse.Namespace):
'"custom" will use --tokenizer to select the preregistered tokenizer.',
)

parser.add_argument(
"--custom-tokenizer",
type=str,
default=None,
help="Custom tokenizer to use (e.g., 'glm_moe_dsa' or 'module.path.ClassName'). "
"When set, overrides the default tokenizer loading.",
)

parser.add_argument(
"--served-model-name",
type=str,
Expand Down
2 changes: 2 additions & 0 deletions src/srtctl/core/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,8 @@ class BenchmarkConfig:
num_warmup_mult: int | None = None # Multiplier for warmup prompts = concurrency * mult (default: 2)
# Trace replay benchmark fields (uses aiperf with mooncake_trace dataset type)
trace_file: str | None = None # Path to trace JSONL file (container path, e.g., /traces/dataset.jsonl)
custom_tokenizer: str | None = None # Custom tokenizer class (e.g., "module.path.ClassName")
use_chat_template: bool = True # Pass --use-chat-template to benchmark (default: true)

def get_concurrency_list(self) -> list[int]:
if self.concurrencies is None:
Expand Down
Loading