diff --git a/src/srtctl/benchmarks/sa_bench.py b/src/srtctl/benchmarks/sa_bench.py
index 12298c86..5f220393 100644
--- a/src/srtctl/benchmarks/sa_bench.py
+++ b/src/srtctl/benchmarks/sa_bench.py
@@ -99,5 +99,7 @@ def build_command(
         str(b.random_range_ratio) if b.random_range_ratio is not None else "0.8",
         str(b.num_prompts_mult) if b.num_prompts_mult is not None else "10",
         str(b.num_warmup_mult) if b.num_warmup_mult is not None else "2",
+        b.custom_tokenizer or "",
+        str(b.use_chat_template).lower(),
     ]
     return cmd
diff --git a/src/srtctl/benchmarks/scripts/sa-bench/backend_request_func.py b/src/srtctl/benchmarks/scripts/sa-bench/backend_request_func.py
index 8f037abb..1a715ee5 100644
--- a/src/srtctl/benchmarks/scripts/sa-bench/backend_request_func.py
+++ b/src/srtctl/benchmarks/scripts/sa-bench/backend_request_func.py
@@ -565,10 +565,52 @@ def _fix_v5_tokenizer_components(tokenizer, model_name_or_path):
         backend.decoder = raw.decoder
 
 
+def _load_glm_moe_dsa_tokenizer(pretrained_model_name_or_path: str) -> "PreTrainedTokenizerFast":
+    """Load GLM-Moe-Dsa / GLM-5 tokenizer directly from tokenizer.json.
+
+    Works around incompatibilities when the checkpoint was saved with
+    transformers 5.x (TokenizersBackend / list-style extra_special_tokens).
+    """
+    import json
+    from pathlib import Path
+
+    from tokenizers import Tokenizer as RustTokenizer
+    from transformers import PreTrainedTokenizerFast
+
+    _SAFE_CONFIG_KEYS = (
+        "pad_token", "pad_token_id", "eos_token", "eos_token_id",
+        "bos_token", "bos_token_id", "unk_token", "unk_token_id",
+        "model_max_length", "padding_side", "truncation_side",
+    )
+
+    path = Path(pretrained_model_name_or_path)
+    tokenizer_json = path / "tokenizer.json"
+    if not tokenizer_json.exists():
+        raise FileNotFoundError(
+            f"Expected tokenizer.json at {tokenizer_json}. "
+            "GlmMoeDsaTokenizer loads from tokenizer.json only."
+        )
+
+    rust_tok = RustTokenizer.from_file(str(tokenizer_json))
+    init_kwargs = {}
+    config_path = path / "tokenizer_config.json"
+    if config_path.exists():
+        with open(config_path, encoding="utf-8") as f:
+            config = json.load(f)
+        for key in _SAFE_CONFIG_KEYS:
+            if key in config:
+                init_kwargs[key] = config[key]
+        if "extra_special_tokens" in config:
+            init_kwargs["additional_special_tokens"] = config["extra_special_tokens"]
+
+    return PreTrainedTokenizerFast(tokenizer_object=rust_tok, **init_kwargs)
+
+
 def get_tokenizer(
     pretrained_model_name_or_path: str,
     tokenizer_mode: str = "auto",
     trust_remote_code: bool = False,
+    custom_tokenizer: str | None = None,
     **kwargs,
 ) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
     if pretrained_model_name_or_path is not None and not os.path.exists(pretrained_model_name_or_path):
@@ -587,14 +629,31 @@ def get_tokenizer(
                 "to use mistral tokenizer mode."
             ) from e
         return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path))
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(
-            pretrained_model_name_or_path,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
-        _fix_v5_tokenizer_components(tokenizer, pretrained_model_name_or_path)
-        return tokenizer
+    if custom_tokenizer:
+        if custom_tokenizer == "glm_moe_dsa":
+            return _load_glm_moe_dsa_tokenizer(pretrained_model_name_or_path)
+        from importlib import import_module
+        try:
+            module_path, class_name = custom_tokenizer.rsplit('.', 1)
+            module = import_module(module_path)
+            tokenizer_class = getattr(module, class_name)
+            return tokenizer_class.from_pretrained(
+                pretrained_model_name_or_path,
+                trust_remote_code=trust_remote_code,
+                **kwargs,
+            )
+        except (ValueError, ImportError, AttributeError) as e:
+            raise ValueError(
+                f"Failed to load custom_tokenizer '{custom_tokenizer}'. "
+                "Expected 'glm_moe_dsa' or 'module.path.ClassName'.") from e
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        pretrained_model_name_or_path,
+        trust_remote_code=trust_remote_code,
+        **kwargs,
+    )
+    _fix_v5_tokenizer_components(tokenizer, pretrained_model_name_or_path)
+    return tokenizer
 
 
 ASYNC_REQUEST_FUNCS = {
diff --git a/src/srtctl/benchmarks/scripts/sa-bench/bench.sh b/src/srtctl/benchmarks/scripts/sa-bench/bench.sh
index 9df299e0..518a696a 100644
--- a/src/srtctl/benchmarks/scripts/sa-bench/bench.sh
+++ b/src/srtctl/benchmarks/scripts/sa-bench/bench.sh
@@ -62,6 +62,20 @@ DECODE_GPUS=${11:-0}
 RANDOM_RANGE_RATIO=${12:-0.8}
 NUM_PROMPTS_MULT=${13:-10}
 NUM_WARMUP_MULT=${14:-2}
+CUSTOM_TOKENIZER=${15:-}
+USE_CHAT_TEMPLATE=${16:-true}
+
+# Build optional custom tokenizer args
+CUSTOM_TOKENIZER_ARGS=()
+if [ -n "$CUSTOM_TOKENIZER" ]; then
+    CUSTOM_TOKENIZER_ARGS=(--custom-tokenizer "$CUSTOM_TOKENIZER")
+fi
+
+# Build optional chat template args
+CHAT_TEMPLATE_ARGS=()
+if [ "$USE_CHAT_TEMPLATE" = "true" ]; then
+    CHAT_TEMPLATE_ARGS=(--use-chat-template)
+fi
 
 # Parse endpoint into host:port
 HOST=$(echo "$ENDPOINT" | sed 's|http://||' | cut -d: -f1)
@@ -121,7 +135,8 @@ for concurrency in "${CONCURRENCY_LIST[@]}"; do
         --request-rate 250 \
         --percentile-metrics ttft,tpot,itl,e2el \
         --max-concurrency "$concurrency" \
-        --trust-remote-code
+        --trust-remote-code \
+        "${CUSTOM_TOKENIZER_ARGS[@]}"
 
     num_prompts=$((concurrency * NUM_PROMPTS_MULT))
 
@@ -151,7 +166,8 @@
         --percentile-metrics ttft,tpot,itl,e2el \
         --max-concurrency "$concurrency" \
         --trust-remote-code \
-        --use-chat-template \
+        "${CHAT_TEMPLATE_ARGS[@]}" \
+        "${CUSTOM_TOKENIZER_ARGS[@]}" \
         --save-result --result-dir "$result_dir" --result-filename "$result_filename"
 
     set +x
diff --git a/src/srtctl/benchmarks/scripts/sa-bench/benchmark_serving.py b/src/srtctl/benchmarks/scripts/sa-bench/benchmark_serving.py
index 4363ef6e..a5ea6490 100644
--- a/src/srtctl/benchmarks/scripts/sa-bench/benchmark_serving.py
+++ b/src/srtctl/benchmarks/scripts/sa-bench/benchmark_serving.py
@@ -837,6 +837,7 @@ def main(args: argparse.Namespace):
         tokenizer_id,
         tokenizer_mode=tokenizer_mode,
         trust_remote_code=args.trust_remote_code,
+        custom_tokenizer=args.custom_tokenizer,
     )
 
     if args.dataset is not None:
@@ -1279,6 +1280,14 @@
         '"custom" will use --tokenizer to select the preregistered tokenizer.',
     )
 
+    parser.add_argument(
+        "--custom-tokenizer",
+        type=str,
+        default=None,
+        help="Custom tokenizer to use (e.g., 'glm_moe_dsa' or 'module.path.ClassName'). "
+        "When set, overrides the default tokenizer loading.",
+    )
+
     parser.add_argument(
         "--served-model-name",
         type=str,
diff --git a/src/srtctl/core/schema.py b/src/srtctl/core/schema.py
index 159042d9..99305a04 100644
--- a/src/srtctl/core/schema.py
+++ b/src/srtctl/core/schema.py
@@ -544,6 +544,8 @@ class BenchmarkConfig:
     num_warmup_mult: int | None = None  # Multiplier for warmup prompts = concurrency * mult (default: 2)
     # Trace replay benchmark fields (uses aiperf with mooncake_trace dataset type)
     trace_file: str | None = None  # Path to trace JSONL file (container path, e.g., /traces/dataset.jsonl)
+    custom_tokenizer: str | None = None  # Custom tokenizer class (e.g., "module.path.ClassName")
+    use_chat_template: bool = True  # Pass --use-chat-template to benchmark (default: true)
 
     def get_concurrency_list(self) -> list[int]:
         if self.concurrencies is None: