From 2e5b079de25adb9d2f3c6e7a9ea1d3c5703f60d1 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Wed, 4 Mar 2026 18:35:29 +0000 Subject: [PATCH 001/112] Gemma4 init. --- python/sglang/bench_one_batch.py | 4 +- python/sglang/jit_kernel/kvcache.py | 1 + python/sglang/srt/configs/model_config.py | 10 + .../sglang/srt/grpc/sglang_scheduler_pb2.py | 134 +++ .../sglang/srt/grpc/sglang_scheduler_pb2.pyi | 632 ++++++++++++ .../srt/grpc/sglang_scheduler_pb2_grpc.py | 368 +++++++ .../srt/layers/attention/triton_backend.py | 46 +- python/sglang/srt/layers/layernorm.py | 65 +- .../srt/layers/rotary_embedding/base.py | 9 + .../layers/rotary_embedding/rope_variant.py | 63 ++ python/sglang/srt/managers/tp_worker.py | 1 + python/sglang/srt/mem_cache/common.py | 1 + python/sglang/srt/mem_cache/memory_pool.py | 17 +- .../sglang/srt/mem_cache/swa_memory_pool.py | 4 + .../sglang/srt/model_executor/model_runner.py | 2 + .../model_runner_kv_cache_mixin.py | 7 +- python/sglang/srt/models/gemma3_causal.py | 8 +- python/sglang/srt/models/gemma4_causal.py | 914 ++++++++++++++++++ scripts/playground/reference_hf.py | 10 +- 19 files changed, 2263 insertions(+), 33 deletions(-) create mode 100644 python/sglang/srt/grpc/sglang_scheduler_pb2.py create mode 100644 python/sglang/srt/grpc/sglang_scheduler_pb2.pyi create mode 100644 python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py create mode 100644 python/sglang/srt/models/gemma4_causal.py diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 8cf0aee1ba22..4c912cfaf696 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -291,8 +291,8 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts): if custom_prompts else [ "The capital of France is", - "The capital of the United Kindom is", - "Today is a sunny day and I like", + # "The capital of the United Kindom is", + # "Today is a sunny day and I like", ] ) input_ids = [tokenizer.encode(p) for p in prompts] 
diff --git a/python/sglang/jit_kernel/kvcache.py b/python/sglang/jit_kernel/kvcache.py index 46a14612b6ff..065ca6eb4918 100644 --- a/python/sglang/jit_kernel/kvcache.py +++ b/python/sglang/jit_kernel/kvcache.py @@ -65,6 +65,7 @@ def store_cache( v_cache (torch.Tensor): Value cache tensor of shape (num_pages, H * D). indices (torch.Tensor): Indices tensor of shape (batch_size,). """ + # print(f"store_cache called with k.shape={k.shape}, v.shape={v.shape}, k_cache.shape={k_cache.shape}, v_cache.shape={v_cache.shape}, indices.shape={indices.shape}, row_bytes={row_bytes}, num_split={num_split}") row_bytes = row_bytes or k.shape[-1] * k.element_size() module = _jit_kvcache_module(row_bytes) if num_split <= 0: diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index c5389ee699e5..ab2f61bac530 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -359,6 +359,7 @@ def _derive_hybrid_model(self): self.is_hybrid_swa_compress = self.hf_config.architectures[0] in [ "MiMoV2FlashForCausalLM", "MiMoV2MTP", + "Gemma4ForCausalLM", ] def _derive_context_length(self, context_length: int): @@ -1407,6 +1408,7 @@ def is_hybrid_swa_model(model_architectures: List[str]): "MiMoV2MTP", "Step3p5ForCausalLM", "Step3p5MTP", + "Gemma4ForCausalLM", } return any(arch in hybrid_swa_archs for arch in model_architectures) @@ -1457,6 +1459,14 @@ def get_hybrid_layer_ids( elif "Step3p5MTP" in model_architectures: swa_attention_layer_ids = [0] full_attention_layer_ids = [] + elif "Gemma4ForCausalLM" in model_architectures: + layer_types = getattr(hf_text_config, "layer_types", None) + swa_attention_layer_ids = [ + i for i, x in enumerate(layer_types) if x == "sliding_attention" + ] + full_attention_layer_ids = [ + i for i, x in enumerate(layer_types) if x == "full_attention" + ] else: swa_attention_layer_ids = None full_attention_layer_ids = None diff --git 
a/python/sglang/srt/grpc/sglang_scheduler_pb2.py b/python/sglang/srt/grpc/sglang_scheduler_pb2.py new file mode 100644 index 000000000000..e99981e3702b --- /dev/null +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: sglang_scheduler.proto +# Protobuf Python Version: 6.31.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 31, + 1, + '', + 'sglang_scheduler.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 +from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xd0\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e 
\x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\t\n\x01n\x18\x11 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x12 \x01(\x05\x12\x12\n\nignore_eos\x18\x13 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x14 \x01(\x08\x12\x1c\n\x0fstream_interval\x18\x15 \x01(\x05H\x02\x88\x01\x01\x12H\n\nlogit_bias\x18\x16 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x17 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokensB\x12\n\x10_stream_interval\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe2\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x0e\n\x06stream\x18\x11 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 
\x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\x95\x02\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12<\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x08 \x01(\r\"\x9b\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12<\n\x0einput_logprobs\x18\n \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x0b \x01(\rB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 
\x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"u\n\x0eOutputLogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"\x9e\x01\n\rInputLogProbs\x12@\n\x0etoken_logprobs\x18\x01 \x03(\x0b\x32(.sglang.grpc.scheduler.InputTokenLogProb\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"1\n\x11InputTokenLogProb\x12\x12\n\x05value\x18\x01 \x01(\x02H\x00\x88\x01\x01\x42\x08\n\x06_value\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 
\x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\x14\n\x12HealthCheckRequest\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x15\n\x13GetModelInfoRequest\"\xac\x03\n\x14GetModelInfoResponse\x12\x12\n\nmodel_path\x18\x01 \x01(\t\x12\x16\n\x0etokenizer_path\x18\x02 \x01(\t\x12\x15\n\ris_generation\x18\x03 \x01(\x08\x12!\n\x19preferred_sampling_params\x18\x04 \x01(\t\x12\x16\n\x0eweight_version\x18\x05 \x01(\t\x12\x19\n\x11served_model_name\x18\x06 
\x01(\t\x12\x1a\n\x12max_context_length\x18\x07 \x01(\x05\x12\x12\n\nvocab_size\x18\x08 \x01(\x05\x12\x17\n\x0fsupports_vision\x18\t \x01(\x08\x12\x12\n\nmodel_type\x18\n \x01(\t\x12\x15\n\reos_token_ids\x18\x0b \x03(\x05\x12\x14\n\x0cpad_token_id\x18\x0c \x01(\x05\x12\x14\n\x0c\x62os_token_id\x18\r \x01(\x05\x12\x19\n\x11max_req_input_len\x18\x0e \x01(\x05\x12\x15\n\rarchitectures\x18\x0f \x03(\t\x12\x15\n\rid2label_json\x18\x10 \x01(\t\x12\x12\n\nnum_labels\x18\x11 \x01(\x05\"\x16\n\x14GetServerInfoRequest\"\xb7\x02\n\x15GetServerInfoResponse\x12,\n\x0bserver_args\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12/\n\x0escheduler_info\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x17\n\x0f\x61\x63tive_requests\x18\x03 \x01(\x05\x12\x11\n\tis_paused\x18\x04 \x01(\x08\x12\x1e\n\x16last_receive_timestamp\x18\x05 \x01(\x01\x12\x16\n\x0euptime_seconds\x18\x06 \x01(\x01\x12\x16\n\x0esglang_version\x18\x07 \x01(\t\x12\x13\n\x0bserver_type\x18\x08 \x01(\t\x12.\n\nstart_time\x18\t \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"D\n\x0fGetLoadsRequest\x12\x14\n\x07\x64p_rank\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x0f\n\x07include\x18\x02 \x03(\tB\n\n\x08_dp_rank\"\xbe\x01\n\x10GetLoadsResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x15\n\rdp_rank_count\x18\x03 \x01(\x05\x12\x33\n\x05loads\x18\x04 \x03(\x0b\x32$.sglang.grpc.scheduler.SchedulerLoad\x12:\n\taggregate\x18\x05 \x01(\x0b\x32\'.sglang.grpc.scheduler.AggregateMetrics\"\x99\x05\n\rSchedulerLoad\x12\x0f\n\x07\x64p_rank\x18\x01 \x01(\x05\x12\x18\n\x10num_running_reqs\x18\x02 \x01(\x05\x12\x18\n\x10num_waiting_reqs\x18\x03 \x01(\x05\x12\x16\n\x0enum_total_reqs\x18\x04 \x01(\x05\x12\x17\n\x0fnum_used_tokens\x18\x05 \x01(\x05\x12\x1c\n\x14max_total_num_tokens\x18\x06 \x01(\x05\x12\x13\n\x0btoken_usage\x18\x07 \x01(\x01\x12\x16\n\x0egen_throughput\x18\x08 \x01(\x01\x12\x16\n\x0e\x63\x61\x63he_hit_rate\x18\t \x01(\x01\x12\x13\n\x0butilization\x18\n 
\x01(\x01\x12\x1c\n\x14max_running_requests\x18\x0b \x01(\x05\x12\x39\n\x06memory\x18\x0c \x01(\x0b\x32$.sglang.grpc.scheduler.MemoryMetricsH\x00\x88\x01\x01\x12\x43\n\x0bspeculative\x18\r \x01(\x0b\x32).sglang.grpc.scheduler.SpeculativeMetricsH\x01\x88\x01\x01\x12\x35\n\x04lora\x18\x0e \x01(\x0b\x32\".sglang.grpc.scheduler.LoRAMetricsH\x02\x88\x01\x01\x12I\n\x0e\x64isaggregation\x18\x0f \x01(\x0b\x32,.sglang.grpc.scheduler.DisaggregationMetricsH\x03\x88\x01\x01\x12\x38\n\x06queues\x18\x10 \x01(\x0b\x32#.sglang.grpc.scheduler.QueueMetricsH\x04\x88\x01\x01\x42\t\n\x07_memoryB\x0e\n\x0c_speculativeB\x07\n\x05_loraB\x11\n\x0f_disaggregationB\t\n\x07_queues\"a\n\rMemoryMetrics\x12\x11\n\tweight_gb\x18\x01 \x01(\x01\x12\x13\n\x0bkv_cache_gb\x18\x02 \x01(\x01\x12\x10\n\x08graph_gb\x18\x03 \x01(\x01\x12\x16\n\x0etoken_capacity\x18\x04 \x01(\x05\"@\n\x12SpeculativeMetrics\x12\x15\n\raccept_length\x18\x01 \x01(\x01\x12\x13\n\x0b\x61\x63\x63\x65pt_rate\x18\x02 \x01(\x01\"K\n\x0bLoRAMetrics\x12\x12\n\nslots_used\x18\x01 \x01(\x05\x12\x13\n\x0bslots_total\x18\x02 \x01(\x05\x12\x13\n\x0butilization\x18\x03 \x01(\x01\"\x9c\x02\n\x15\x44isaggregationMetrics\x12\x0c\n\x04mode\x18\x01 \x01(\t\x12#\n\x1bprefill_prealloc_queue_reqs\x18\x02 \x01(\x05\x12#\n\x1bprefill_inflight_queue_reqs\x18\x03 \x01(\x05\x12\"\n\x1a\x64\x65\x63ode_prealloc_queue_reqs\x18\x04 \x01(\x05\x12\"\n\x1a\x64\x65\x63ode_transfer_queue_reqs\x18\x05 \x01(\x05\x12#\n\x1b\x64\x65\x63ode_retracted_queue_reqs\x18\x06 \x01(\x05\x12\x1e\n\x16kv_transfer_speed_gb_s\x18\x07 \x01(\x01\x12\x1e\n\x16kv_transfer_latency_ms\x18\x08 \x01(\x01\"S\n\x0cQueueMetrics\x12\x0f\n\x07waiting\x18\x01 \x01(\x05\x12\x0f\n\x07grammar\x18\x02 \x01(\x05\x12\x0e\n\x06paused\x18\x03 \x01(\x05\x12\x11\n\tretracted\x18\x04 \x01(\x05\"\xa8\x01\n\x10\x41ggregateMetrics\x12\x1a\n\x12total_running_reqs\x18\x01 \x01(\x05\x12\x1a\n\x12total_waiting_reqs\x18\x02 \x01(\x05\x12\x12\n\ntotal_reqs\x18\x03 
\x01(\x05\x12\x17\n\x0f\x61vg_token_usage\x18\x04 \x01(\x01\x12\x16\n\x0e\x61vg_throughput\x18\x05 \x01(\x01\x12\x17\n\x0f\x61vg_utilization\x18\x06 \x01(\x01\x32\xb0\x05\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponse\x12g\n\x0cGetModelInfo\x12*.sglang.grpc.scheduler.GetModelInfoRequest\x1a+.sglang.grpc.scheduler.GetModelInfoResponse\x12j\n\rGetServerInfo\x12+.sglang.grpc.scheduler.GetServerInfoRequest\x1a,.sglang.grpc.scheduler.GetServerInfoResponse\x12[\n\x08GetLoads\x12&.sglang.grpc.scheduler.GetLoadsRequest\x1a\'.sglang.grpc.scheduler.GetLoadsResponseb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sglang_scheduler_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._loaded_options = None + _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_options = b'8\001' + _globals['_SAMPLINGPARAMS']._serialized_start=113 + _globals['_SAMPLINGPARAMS']._serialized_end=833 + _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_start=732 + _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_end=780 + _globals['_DISAGGREGATEDPARAMS']._serialized_start=835 + _globals['_DISAGGREGATEDPARAMS']._serialized_end=928 + _globals['_GENERATEREQUEST']._serialized_start=931 + _globals['_GENERATEREQUEST']._serialized_end=1541 + _globals['_TOKENIZEDINPUT']._serialized_start=1543 + _globals['_TOKENIZEDINPUT']._serialized_end=1601 + _globals['_MULTIMODALINPUTS']._serialized_start=1604 + 
_globals['_MULTIMODALINPUTS']._serialized_end=1815 + _globals['_GENERATERESPONSE']._serialized_start=1818 + _globals['_GENERATERESPONSE']._serialized_end=2045 + _globals['_GENERATESTREAMCHUNK']._serialized_start=2048 + _globals['_GENERATESTREAMCHUNK']._serialized_end=2325 + _globals['_GENERATECOMPLETE']._serialized_start=2328 + _globals['_GENERATECOMPLETE']._serialized_end=2739 + _globals['_GENERATEERROR']._serialized_start=2741 + _globals['_GENERATEERROR']._serialized_end=2816 + _globals['_OUTPUTLOGPROBS']._serialized_start=2818 + _globals['_OUTPUTLOGPROBS']._serialized_end=2935 + _globals['_INPUTLOGPROBS']._serialized_start=2938 + _globals['_INPUTLOGPROBS']._serialized_end=3096 + _globals['_INPUTTOKENLOGPROB']._serialized_start=3098 + _globals['_INPUTTOKENLOGPROB']._serialized_end=3147 + _globals['_TOPLOGPROBS']._serialized_start=3149 + _globals['_TOPLOGPROBS']._serialized_end=3197 + _globals['_HIDDENSTATES']._serialized_start=3199 + _globals['_HIDDENSTATES']._serialized_end=3262 + _globals['_EMBEDREQUEST']._serialized_start=3265 + _globals['_EMBEDREQUEST']._serialized_end=3595 + _globals['_EMBEDRESPONSE']._serialized_start=3598 + _globals['_EMBEDRESPONSE']._serialized_end=3755 + _globals['_EMBEDCOMPLETE']._serialized_start=3758 + _globals['_EMBEDCOMPLETE']._serialized_end=3921 + _globals['_EMBEDDING']._serialized_start=3923 + _globals['_EMBEDDING']._serialized_end=3965 + _globals['_EMBEDERROR']._serialized_start=3967 + _globals['_EMBEDERROR']._serialized_end=4027 + _globals['_HEALTHCHECKREQUEST']._serialized_start=4029 + _globals['_HEALTHCHECKREQUEST']._serialized_end=4049 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=4051 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=4106 + _globals['_ABORTREQUEST']._serialized_start=4108 + _globals['_ABORTREQUEST']._serialized_end=4158 + _globals['_ABORTRESPONSE']._serialized_start=4160 + _globals['_ABORTRESPONSE']._serialized_end=4209 + _globals['_LOADLORAREQUEST']._serialized_start=4211 + 
_globals['_LOADLORAREQUEST']._serialized_end=4284 + _globals['_LOADLORARESPONSE']._serialized_start=4286 + _globals['_LOADLORARESPONSE']._serialized_end=4358 + _globals['_UNLOADLORAREQUEST']._serialized_start=4360 + _globals['_UNLOADLORAREQUEST']._serialized_end=4399 + _globals['_UNLOADLORARESPONSE']._serialized_start=4401 + _globals['_UNLOADLORARESPONSE']._serialized_end=4455 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4457 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4576 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4578 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4635 + _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4637 + _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4682 + _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4684 + _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4750 + _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4752 + _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4817 + _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4819 + _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4879 + _globals['_GETMODELINFOREQUEST']._serialized_start=4881 + _globals['_GETMODELINFOREQUEST']._serialized_end=4902 + _globals['_GETMODELINFORESPONSE']._serialized_start=4905 + _globals['_GETMODELINFORESPONSE']._serialized_end=5333 + _globals['_GETSERVERINFOREQUEST']._serialized_start=5335 + _globals['_GETSERVERINFOREQUEST']._serialized_end=5357 + _globals['_GETSERVERINFORESPONSE']._serialized_start=5360 + _globals['_GETSERVERINFORESPONSE']._serialized_end=5671 + _globals['_GETLOADSREQUEST']._serialized_start=5673 + _globals['_GETLOADSREQUEST']._serialized_end=5741 + _globals['_GETLOADSRESPONSE']._serialized_start=5744 + _globals['_GETLOADSRESPONSE']._serialized_end=5934 + _globals['_SCHEDULERLOAD']._serialized_start=5937 + _globals['_SCHEDULERLOAD']._serialized_end=6602 + _globals['_MEMORYMETRICS']._serialized_start=6604 + _globals['_MEMORYMETRICS']._serialized_end=6701 + 
_globals['_SPECULATIVEMETRICS']._serialized_start=6703 + _globals['_SPECULATIVEMETRICS']._serialized_end=6767 + _globals['_LORAMETRICS']._serialized_start=6769 + _globals['_LORAMETRICS']._serialized_end=6844 + _globals['_DISAGGREGATIONMETRICS']._serialized_start=6847 + _globals['_DISAGGREGATIONMETRICS']._serialized_end=7131 + _globals['_QUEUEMETRICS']._serialized_start=7133 + _globals['_QUEUEMETRICS']._serialized_end=7216 + _globals['_AGGREGATEMETRICS']._serialized_start=7219 + _globals['_AGGREGATEMETRICS']._serialized_end=7387 + _globals['_SGLANGSCHEDULER']._serialized_start=7390 + _globals['_SGLANGSCHEDULER']._serialized_end=8078 +# @@protoc_insertion_point(module_scope) diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi new file mode 100644 index 000000000000..8d3e979aa4ad --- /dev/null +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi @@ -0,0 +1,632 @@ +import datetime + +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from google.protobuf import struct_pb2 as _struct_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from collections.abc import Iterable as _Iterable, Mapping as _Mapping +from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class SamplingParams(_message.Message): + __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "n", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params") + class LogitBiasEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: 
_ClassVar[int] + key: str + value: float + def __init__(self, key: _Optional[str] = ..., value: _Optional[float] = ...) -> None: ... + TEMPERATURE_FIELD_NUMBER: _ClassVar[int] + TOP_P_FIELD_NUMBER: _ClassVar[int] + TOP_K_FIELD_NUMBER: _ClassVar[int] + MIN_P_FIELD_NUMBER: _ClassVar[int] + FREQUENCY_PENALTY_FIELD_NUMBER: _ClassVar[int] + PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int] + REPETITION_PENALTY_FIELD_NUMBER: _ClassVar[int] + MAX_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int] + STOP_FIELD_NUMBER: _ClassVar[int] + STOP_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] + SKIP_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int] + SPACES_BETWEEN_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int] + REGEX_FIELD_NUMBER: _ClassVar[int] + JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int] + EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int] + STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int] + N_FIELD_NUMBER: _ClassVar[int] + MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int] + IGNORE_EOS_FIELD_NUMBER: _ClassVar[int] + NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int] + STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int] + LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int] + CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int] + temperature: float + top_p: float + top_k: int + min_p: float + frequency_penalty: float + presence_penalty: float + repetition_penalty: float + max_new_tokens: int + stop: _containers.RepeatedScalarFieldContainer[str] + stop_token_ids: _containers.RepeatedScalarFieldContainer[int] + skip_special_tokens: bool + spaces_between_special_tokens: bool + regex: str + json_schema: str + ebnf_grammar: str + structural_tag: str + n: int + min_new_tokens: int + ignore_eos: bool + no_stop_trim: bool + stream_interval: int + logit_bias: _containers.ScalarMap[str, float] + custom_params: _struct_pb2.Struct + def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: 
_Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., n: _Optional[int] = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class DisaggregatedParams(_message.Message): + __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room") + BOOTSTRAP_HOST_FIELD_NUMBER: _ClassVar[int] + BOOTSTRAP_PORT_FIELD_NUMBER: _ClassVar[int] + BOOTSTRAP_ROOM_FIELD_NUMBER: _ClassVar[int] + bootstrap_host: str + bootstrap_port: int + bootstrap_room: int + def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ... 
+ +class GenerateRequest(_message.Message): + __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "stream") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + TOKENIZED_FIELD_NUMBER: _ClassVar[int] + MM_INPUTS_FIELD_NUMBER: _ClassVar[int] + SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int] + RETURN_LOGPROB_FIELD_NUMBER: _ClassVar[int] + LOGPROB_START_LEN_FIELD_NUMBER: _ClassVar[int] + TOP_LOGPROBS_NUM_FIELD_NUMBER: _ClassVar[int] + TOKEN_IDS_LOGPROB_FIELD_NUMBER: _ClassVar[int] + RETURN_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] + DISAGGREGATED_PARAMS_FIELD_NUMBER: _ClassVar[int] + CUSTOM_LOGIT_PROCESSOR_FIELD_NUMBER: _ClassVar[int] + TIMESTAMP_FIELD_NUMBER: _ClassVar[int] + LOG_METRICS_FIELD_NUMBER: _ClassVar[int] + INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int] + LORA_ID_FIELD_NUMBER: _ClassVar[int] + DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int] + STREAM_FIELD_NUMBER: _ClassVar[int] + request_id: str + tokenized: TokenizedInput + mm_inputs: MultimodalInputs + sampling_params: SamplingParams + return_logprob: bool + logprob_start_len: int + top_logprobs_num: int + token_ids_logprob: _containers.RepeatedScalarFieldContainer[int] + return_hidden_states: bool + disaggregated_params: DisaggregatedParams + custom_logit_processor: str + timestamp: _timestamp_pb2.Timestamp + log_metrics: bool + input_embeds: _containers.RepeatedScalarFieldContainer[float] + lora_id: str + data_parallel_rank: int + stream: bool + def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., 
top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., stream: bool = ...) -> None: ... + +class TokenizedInput(_message.Message): + __slots__ = ("original_text", "input_ids") + ORIGINAL_TEXT_FIELD_NUMBER: _ClassVar[int] + INPUT_IDS_FIELD_NUMBER: _ClassVar[int] + original_text: str + input_ids: _containers.RepeatedScalarFieldContainer[int] + def __init__(self, original_text: _Optional[str] = ..., input_ids: _Optional[_Iterable[int]] = ...) -> None: ... + +class MultimodalInputs(_message.Message): + __slots__ = ("image_urls", "video_urls", "audio_urls", "processed_features", "image_data", "video_data", "audio_data", "modalities") + IMAGE_URLS_FIELD_NUMBER: _ClassVar[int] + VIDEO_URLS_FIELD_NUMBER: _ClassVar[int] + AUDIO_URLS_FIELD_NUMBER: _ClassVar[int] + PROCESSED_FEATURES_FIELD_NUMBER: _ClassVar[int] + IMAGE_DATA_FIELD_NUMBER: _ClassVar[int] + VIDEO_DATA_FIELD_NUMBER: _ClassVar[int] + AUDIO_DATA_FIELD_NUMBER: _ClassVar[int] + MODALITIES_FIELD_NUMBER: _ClassVar[int] + image_urls: _containers.RepeatedScalarFieldContainer[str] + video_urls: _containers.RepeatedScalarFieldContainer[str] + audio_urls: _containers.RepeatedScalarFieldContainer[str] + processed_features: _struct_pb2.Struct + image_data: _containers.RepeatedScalarFieldContainer[bytes] + video_data: _containers.RepeatedScalarFieldContainer[bytes] + audio_data: _containers.RepeatedScalarFieldContainer[bytes] + modalities: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, image_urls: _Optional[_Iterable[str]] = ..., video_urls: _Optional[_Iterable[str]] = ..., 
audio_urls: _Optional[_Iterable[str]] = ..., processed_features: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., image_data: _Optional[_Iterable[bytes]] = ..., video_data: _Optional[_Iterable[bytes]] = ..., audio_data: _Optional[_Iterable[bytes]] = ..., modalities: _Optional[_Iterable[str]] = ...) -> None: ... + +class GenerateResponse(_message.Message): + __slots__ = ("request_id", "chunk", "complete", "error") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + CHUNK_FIELD_NUMBER: _ClassVar[int] + COMPLETE_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + request_id: str + chunk: GenerateStreamChunk + complete: GenerateComplete + error: GenerateError + def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ... + +class GenerateStreamChunk(_message.Message): + __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index") + TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] + PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] + COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int] + CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] + OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] + INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + INDEX_FIELD_NUMBER: _ClassVar[int] + token_ids: _containers.RepeatedScalarFieldContainer[int] + prompt_tokens: int + completion_tokens: int + cached_tokens: int + output_logprobs: OutputLogProbs + hidden_states: _containers.RepeatedScalarFieldContainer[float] + input_logprobs: InputLogProbs + index: int + def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., 
hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ... + +class GenerateComplete(_message.Message): + __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index") + OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int] + FINISH_REASON_FIELD_NUMBER: _ClassVar[int] + PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] + COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int] + CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] + OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] + MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int] + MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int] + INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + INDEX_FIELD_NUMBER: _ClassVar[int] + output_ids: _containers.RepeatedScalarFieldContainer[int] + finish_reason: str + prompt_tokens: int + completion_tokens: int + cached_tokens: int + output_logprobs: OutputLogProbs + all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates] + matched_token_id: int + matched_stop_str: str + input_logprobs: InputLogProbs + index: int + def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ... 
+ +class GenerateError(_message.Message): + __slots__ = ("message", "http_status_code", "details") + MESSAGE_FIELD_NUMBER: _ClassVar[int] + HTTP_STATUS_CODE_FIELD_NUMBER: _ClassVar[int] + DETAILS_FIELD_NUMBER: _ClassVar[int] + message: str + http_status_code: str + details: str + def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ... + +class OutputLogProbs(_message.Message): + __slots__ = ("token_logprobs", "token_ids", "top_logprobs") + TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] + TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + token_logprobs: _containers.RepeatedScalarFieldContainer[float] + token_ids: _containers.RepeatedScalarFieldContainer[int] + top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs] + def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ... + +class InputLogProbs(_message.Message): + __slots__ = ("token_logprobs", "token_ids", "top_logprobs") + TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] + TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + token_logprobs: _containers.RepeatedCompositeFieldContainer[InputTokenLogProb] + token_ids: _containers.RepeatedScalarFieldContainer[int] + top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs] + def __init__(self, token_logprobs: _Optional[_Iterable[_Union[InputTokenLogProb, _Mapping]]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ... + +class InputTokenLogProb(_message.Message): + __slots__ = ("value",) + VALUE_FIELD_NUMBER: _ClassVar[int] + value: float + def __init__(self, value: _Optional[float] = ...) -> None: ... 
+ +class TopLogProbs(_message.Message): + __slots__ = ("values", "token_ids") + VALUES_FIELD_NUMBER: _ClassVar[int] + TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] + values: _containers.RepeatedScalarFieldContainer[float] + token_ids: _containers.RepeatedScalarFieldContainer[int] + def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ... + +class HiddenStates(_message.Message): + __slots__ = ("values", "layer", "position") + VALUES_FIELD_NUMBER: _ClassVar[int] + LAYER_FIELD_NUMBER: _ClassVar[int] + POSITION_FIELD_NUMBER: _ClassVar[int] + values: _containers.RepeatedScalarFieldContainer[float] + layer: int + position: int + def __init__(self, values: _Optional[_Iterable[float]] = ..., layer: _Optional[int] = ..., position: _Optional[int] = ...) -> None: ... + +class EmbedRequest(_message.Message): + __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "log_metrics", "token_type_ids", "data_parallel_rank", "is_cross_encoder", "texts") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + TOKENIZED_FIELD_NUMBER: _ClassVar[int] + MM_INPUTS_FIELD_NUMBER: _ClassVar[int] + SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int] + LOG_METRICS_FIELD_NUMBER: _ClassVar[int] + TOKEN_TYPE_IDS_FIELD_NUMBER: _ClassVar[int] + DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int] + IS_CROSS_ENCODER_FIELD_NUMBER: _ClassVar[int] + TEXTS_FIELD_NUMBER: _ClassVar[int] + request_id: str + tokenized: TokenizedInput + mm_inputs: MultimodalInputs + sampling_params: SamplingParams + log_metrics: bool + token_type_ids: _containers.RepeatedScalarFieldContainer[int] + data_parallel_rank: int + is_cross_encoder: bool + texts: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., log_metrics: bool = ..., 
token_type_ids: _Optional[_Iterable[int]] = ..., data_parallel_rank: _Optional[int] = ..., is_cross_encoder: bool = ..., texts: _Optional[_Iterable[str]] = ...) -> None: ... + +class EmbedResponse(_message.Message): + __slots__ = ("request_id", "complete", "error") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + COMPLETE_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + request_id: str + complete: EmbedComplete + error: EmbedError + def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ... + +class EmbedComplete(_message.Message): + __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings") + EMBEDDING_FIELD_NUMBER: _ClassVar[int] + PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] + CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] + EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int] + BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int] + embedding: _containers.RepeatedScalarFieldContainer[float] + prompt_tokens: int + cached_tokens: int + embedding_dim: int + batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding] + def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ... + +class Embedding(_message.Message): + __slots__ = ("values", "index") + VALUES_FIELD_NUMBER: _ClassVar[int] + INDEX_FIELD_NUMBER: _ClassVar[int] + values: _containers.RepeatedScalarFieldContainer[float] + index: int + def __init__(self, values: _Optional[_Iterable[float]] = ..., index: _Optional[int] = ...) -> None: ... 
+ +class EmbedError(_message.Message): + __slots__ = ("message", "code", "details") + MESSAGE_FIELD_NUMBER: _ClassVar[int] + CODE_FIELD_NUMBER: _ClassVar[int] + DETAILS_FIELD_NUMBER: _ClassVar[int] + message: str + code: str + details: str + def __init__(self, message: _Optional[str] = ..., code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ... + +class HealthCheckRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class HealthCheckResponse(_message.Message): + __slots__ = ("healthy", "message") + HEALTHY_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + healthy: bool + message: str + def __init__(self, healthy: bool = ..., message: _Optional[str] = ...) -> None: ... + +class AbortRequest(_message.Message): + __slots__ = ("request_id", "reason") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + REASON_FIELD_NUMBER: _ClassVar[int] + request_id: str + reason: str + def __init__(self, request_id: _Optional[str] = ..., reason: _Optional[str] = ...) -> None: ... + +class AbortResponse(_message.Message): + __slots__ = ("success", "message") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + success: bool + message: str + def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... + +class LoadLoRARequest(_message.Message): + __slots__ = ("adapter_id", "adapter_path", "rank") + ADAPTER_ID_FIELD_NUMBER: _ClassVar[int] + ADAPTER_PATH_FIELD_NUMBER: _ClassVar[int] + RANK_FIELD_NUMBER: _ClassVar[int] + adapter_id: str + adapter_path: str + rank: int + def __init__(self, adapter_id: _Optional[str] = ..., adapter_path: _Optional[str] = ..., rank: _Optional[int] = ...) -> None: ... 
+ +class LoadLoRAResponse(_message.Message): + __slots__ = ("success", "adapter_id", "message") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + ADAPTER_ID_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + success: bool + adapter_id: str + message: str + def __init__(self, success: bool = ..., adapter_id: _Optional[str] = ..., message: _Optional[str] = ...) -> None: ... + +class UnloadLoRARequest(_message.Message): + __slots__ = ("adapter_id",) + ADAPTER_ID_FIELD_NUMBER: _ClassVar[int] + adapter_id: str + def __init__(self, adapter_id: _Optional[str] = ...) -> None: ... + +class UnloadLoRAResponse(_message.Message): + __slots__ = ("success", "message") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + success: bool + message: str + def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... + +class UpdateWeightsRequest(_message.Message): + __slots__ = ("disk_path", "tensor_data", "remote_url", "weight_name") + DISK_PATH_FIELD_NUMBER: _ClassVar[int] + TENSOR_DATA_FIELD_NUMBER: _ClassVar[int] + REMOTE_URL_FIELD_NUMBER: _ClassVar[int] + WEIGHT_NAME_FIELD_NUMBER: _ClassVar[int] + disk_path: str + tensor_data: bytes + remote_url: str + weight_name: str + def __init__(self, disk_path: _Optional[str] = ..., tensor_data: _Optional[bytes] = ..., remote_url: _Optional[str] = ..., weight_name: _Optional[str] = ...) -> None: ... + +class UpdateWeightsResponse(_message.Message): + __slots__ = ("success", "message") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + success: bool + message: str + def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... + +class GetInternalStateRequest(_message.Message): + __slots__ = ("state_keys",) + STATE_KEYS_FIELD_NUMBER: _ClassVar[int] + state_keys: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, state_keys: _Optional[_Iterable[str]] = ...) -> None: ... 
+ +class GetInternalStateResponse(_message.Message): + __slots__ = ("state",) + STATE_FIELD_NUMBER: _ClassVar[int] + state: _struct_pb2.Struct + def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class SetInternalStateRequest(_message.Message): + __slots__ = ("state",) + STATE_FIELD_NUMBER: _ClassVar[int] + state: _struct_pb2.Struct + def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class SetInternalStateResponse(_message.Message): + __slots__ = ("success", "message") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + success: bool + message: str + def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... + +class GetModelInfoRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class GetModelInfoResponse(_message.Message): + __slots__ = ("model_path", "tokenizer_path", "is_generation", "preferred_sampling_params", "weight_version", "served_model_name", "max_context_length", "vocab_size", "supports_vision", "model_type", "eos_token_ids", "pad_token_id", "bos_token_id", "max_req_input_len", "architectures", "id2label_json", "num_labels") + MODEL_PATH_FIELD_NUMBER: _ClassVar[int] + TOKENIZER_PATH_FIELD_NUMBER: _ClassVar[int] + IS_GENERATION_FIELD_NUMBER: _ClassVar[int] + PREFERRED_SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int] + WEIGHT_VERSION_FIELD_NUMBER: _ClassVar[int] + SERVED_MODEL_NAME_FIELD_NUMBER: _ClassVar[int] + MAX_CONTEXT_LENGTH_FIELD_NUMBER: _ClassVar[int] + VOCAB_SIZE_FIELD_NUMBER: _ClassVar[int] + SUPPORTS_VISION_FIELD_NUMBER: _ClassVar[int] + MODEL_TYPE_FIELD_NUMBER: _ClassVar[int] + EOS_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] + PAD_TOKEN_ID_FIELD_NUMBER: _ClassVar[int] + BOS_TOKEN_ID_FIELD_NUMBER: _ClassVar[int] + MAX_REQ_INPUT_LEN_FIELD_NUMBER: _ClassVar[int] + ARCHITECTURES_FIELD_NUMBER: _ClassVar[int] + ID2LABEL_JSON_FIELD_NUMBER: _ClassVar[int] + NUM_LABELS_FIELD_NUMBER: 
_ClassVar[int] + model_path: str + tokenizer_path: str + is_generation: bool + preferred_sampling_params: str + weight_version: str + served_model_name: str + max_context_length: int + vocab_size: int + supports_vision: bool + model_type: str + eos_token_ids: _containers.RepeatedScalarFieldContainer[int] + pad_token_id: int + bos_token_id: int + max_req_input_len: int + architectures: _containers.RepeatedScalarFieldContainer[str] + id2label_json: str + num_labels: int + def __init__(self, model_path: _Optional[str] = ..., tokenizer_path: _Optional[str] = ..., is_generation: bool = ..., preferred_sampling_params: _Optional[str] = ..., weight_version: _Optional[str] = ..., served_model_name: _Optional[str] = ..., max_context_length: _Optional[int] = ..., vocab_size: _Optional[int] = ..., supports_vision: bool = ..., model_type: _Optional[str] = ..., eos_token_ids: _Optional[_Iterable[int]] = ..., pad_token_id: _Optional[int] = ..., bos_token_id: _Optional[int] = ..., max_req_input_len: _Optional[int] = ..., architectures: _Optional[_Iterable[str]] = ..., id2label_json: _Optional[str] = ..., num_labels: _Optional[int] = ...) -> None: ... + +class GetServerInfoRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... 
+ +class GetServerInfoResponse(_message.Message): + __slots__ = ("server_args", "scheduler_info", "active_requests", "is_paused", "last_receive_timestamp", "uptime_seconds", "sglang_version", "server_type", "start_time") + SERVER_ARGS_FIELD_NUMBER: _ClassVar[int] + SCHEDULER_INFO_FIELD_NUMBER: _ClassVar[int] + ACTIVE_REQUESTS_FIELD_NUMBER: _ClassVar[int] + IS_PAUSED_FIELD_NUMBER: _ClassVar[int] + LAST_RECEIVE_TIMESTAMP_FIELD_NUMBER: _ClassVar[int] + UPTIME_SECONDS_FIELD_NUMBER: _ClassVar[int] + SGLANG_VERSION_FIELD_NUMBER: _ClassVar[int] + SERVER_TYPE_FIELD_NUMBER: _ClassVar[int] + START_TIME_FIELD_NUMBER: _ClassVar[int] + server_args: _struct_pb2.Struct + scheduler_info: _struct_pb2.Struct + active_requests: int + is_paused: bool + last_receive_timestamp: float + uptime_seconds: float + sglang_version: str + server_type: str + start_time: _timestamp_pb2.Timestamp + def __init__(self, server_args: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., scheduler_info: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., active_requests: _Optional[int] = ..., is_paused: bool = ..., last_receive_timestamp: _Optional[float] = ..., uptime_seconds: _Optional[float] = ..., sglang_version: _Optional[str] = ..., server_type: _Optional[str] = ..., start_time: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ... + +class GetLoadsRequest(_message.Message): + __slots__ = ("dp_rank", "include") + DP_RANK_FIELD_NUMBER: _ClassVar[int] + INCLUDE_FIELD_NUMBER: _ClassVar[int] + dp_rank: int + include: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, dp_rank: _Optional[int] = ..., include: _Optional[_Iterable[str]] = ...) -> None: ... 
+ +class GetLoadsResponse(_message.Message): + __slots__ = ("timestamp", "version", "dp_rank_count", "loads", "aggregate") + TIMESTAMP_FIELD_NUMBER: _ClassVar[int] + VERSION_FIELD_NUMBER: _ClassVar[int] + DP_RANK_COUNT_FIELD_NUMBER: _ClassVar[int] + LOADS_FIELD_NUMBER: _ClassVar[int] + AGGREGATE_FIELD_NUMBER: _ClassVar[int] + timestamp: str + version: str + dp_rank_count: int + loads: _containers.RepeatedCompositeFieldContainer[SchedulerLoad] + aggregate: AggregateMetrics + def __init__(self, timestamp: _Optional[str] = ..., version: _Optional[str] = ..., dp_rank_count: _Optional[int] = ..., loads: _Optional[_Iterable[_Union[SchedulerLoad, _Mapping]]] = ..., aggregate: _Optional[_Union[AggregateMetrics, _Mapping]] = ...) -> None: ... + +class SchedulerLoad(_message.Message): + __slots__ = ("dp_rank", "num_running_reqs", "num_waiting_reqs", "num_total_reqs", "num_used_tokens", "max_total_num_tokens", "token_usage", "gen_throughput", "cache_hit_rate", "utilization", "max_running_requests", "memory", "speculative", "lora", "disaggregation", "queues") + DP_RANK_FIELD_NUMBER: _ClassVar[int] + NUM_RUNNING_REQS_FIELD_NUMBER: _ClassVar[int] + NUM_WAITING_REQS_FIELD_NUMBER: _ClassVar[int] + NUM_TOTAL_REQS_FIELD_NUMBER: _ClassVar[int] + NUM_USED_TOKENS_FIELD_NUMBER: _ClassVar[int] + MAX_TOTAL_NUM_TOKENS_FIELD_NUMBER: _ClassVar[int] + TOKEN_USAGE_FIELD_NUMBER: _ClassVar[int] + GEN_THROUGHPUT_FIELD_NUMBER: _ClassVar[int] + CACHE_HIT_RATE_FIELD_NUMBER: _ClassVar[int] + UTILIZATION_FIELD_NUMBER: _ClassVar[int] + MAX_RUNNING_REQUESTS_FIELD_NUMBER: _ClassVar[int] + MEMORY_FIELD_NUMBER: _ClassVar[int] + SPECULATIVE_FIELD_NUMBER: _ClassVar[int] + LORA_FIELD_NUMBER: _ClassVar[int] + DISAGGREGATION_FIELD_NUMBER: _ClassVar[int] + QUEUES_FIELD_NUMBER: _ClassVar[int] + dp_rank: int + num_running_reqs: int + num_waiting_reqs: int + num_total_reqs: int + num_used_tokens: int + max_total_num_tokens: int + token_usage: float + gen_throughput: float + cache_hit_rate: float + utilization: 
float + max_running_requests: int + memory: MemoryMetrics + speculative: SpeculativeMetrics + lora: LoRAMetrics + disaggregation: DisaggregationMetrics + queues: QueueMetrics + def __init__(self, dp_rank: _Optional[int] = ..., num_running_reqs: _Optional[int] = ..., num_waiting_reqs: _Optional[int] = ..., num_total_reqs: _Optional[int] = ..., num_used_tokens: _Optional[int] = ..., max_total_num_tokens: _Optional[int] = ..., token_usage: _Optional[float] = ..., gen_throughput: _Optional[float] = ..., cache_hit_rate: _Optional[float] = ..., utilization: _Optional[float] = ..., max_running_requests: _Optional[int] = ..., memory: _Optional[_Union[MemoryMetrics, _Mapping]] = ..., speculative: _Optional[_Union[SpeculativeMetrics, _Mapping]] = ..., lora: _Optional[_Union[LoRAMetrics, _Mapping]] = ..., disaggregation: _Optional[_Union[DisaggregationMetrics, _Mapping]] = ..., queues: _Optional[_Union[QueueMetrics, _Mapping]] = ...) -> None: ... + +class MemoryMetrics(_message.Message): + __slots__ = ("weight_gb", "kv_cache_gb", "graph_gb", "token_capacity") + WEIGHT_GB_FIELD_NUMBER: _ClassVar[int] + KV_CACHE_GB_FIELD_NUMBER: _ClassVar[int] + GRAPH_GB_FIELD_NUMBER: _ClassVar[int] + TOKEN_CAPACITY_FIELD_NUMBER: _ClassVar[int] + weight_gb: float + kv_cache_gb: float + graph_gb: float + token_capacity: int + def __init__(self, weight_gb: _Optional[float] = ..., kv_cache_gb: _Optional[float] = ..., graph_gb: _Optional[float] = ..., token_capacity: _Optional[int] = ...) -> None: ... + +class SpeculativeMetrics(_message.Message): + __slots__ = ("accept_length", "accept_rate") + ACCEPT_LENGTH_FIELD_NUMBER: _ClassVar[int] + ACCEPT_RATE_FIELD_NUMBER: _ClassVar[int] + accept_length: float + accept_rate: float + def __init__(self, accept_length: _Optional[float] = ..., accept_rate: _Optional[float] = ...) -> None: ... 
+ +class LoRAMetrics(_message.Message): + __slots__ = ("slots_used", "slots_total", "utilization") + SLOTS_USED_FIELD_NUMBER: _ClassVar[int] + SLOTS_TOTAL_FIELD_NUMBER: _ClassVar[int] + UTILIZATION_FIELD_NUMBER: _ClassVar[int] + slots_used: int + slots_total: int + utilization: float + def __init__(self, slots_used: _Optional[int] = ..., slots_total: _Optional[int] = ..., utilization: _Optional[float] = ...) -> None: ... + +class DisaggregationMetrics(_message.Message): + __slots__ = ("mode", "prefill_prealloc_queue_reqs", "prefill_inflight_queue_reqs", "decode_prealloc_queue_reqs", "decode_transfer_queue_reqs", "decode_retracted_queue_reqs", "kv_transfer_speed_gb_s", "kv_transfer_latency_ms") + MODE_FIELD_NUMBER: _ClassVar[int] + PREFILL_PREALLOC_QUEUE_REQS_FIELD_NUMBER: _ClassVar[int] + PREFILL_INFLIGHT_QUEUE_REQS_FIELD_NUMBER: _ClassVar[int] + DECODE_PREALLOC_QUEUE_REQS_FIELD_NUMBER: _ClassVar[int] + DECODE_TRANSFER_QUEUE_REQS_FIELD_NUMBER: _ClassVar[int] + DECODE_RETRACTED_QUEUE_REQS_FIELD_NUMBER: _ClassVar[int] + KV_TRANSFER_SPEED_GB_S_FIELD_NUMBER: _ClassVar[int] + KV_TRANSFER_LATENCY_MS_FIELD_NUMBER: _ClassVar[int] + mode: str + prefill_prealloc_queue_reqs: int + prefill_inflight_queue_reqs: int + decode_prealloc_queue_reqs: int + decode_transfer_queue_reqs: int + decode_retracted_queue_reqs: int + kv_transfer_speed_gb_s: float + kv_transfer_latency_ms: float + def __init__(self, mode: _Optional[str] = ..., prefill_prealloc_queue_reqs: _Optional[int] = ..., prefill_inflight_queue_reqs: _Optional[int] = ..., decode_prealloc_queue_reqs: _Optional[int] = ..., decode_transfer_queue_reqs: _Optional[int] = ..., decode_retracted_queue_reqs: _Optional[int] = ..., kv_transfer_speed_gb_s: _Optional[float] = ..., kv_transfer_latency_ms: _Optional[float] = ...) -> None: ... 
+ +class QueueMetrics(_message.Message): + __slots__ = ("waiting", "grammar", "paused", "retracted") + WAITING_FIELD_NUMBER: _ClassVar[int] + GRAMMAR_FIELD_NUMBER: _ClassVar[int] + PAUSED_FIELD_NUMBER: _ClassVar[int] + RETRACTED_FIELD_NUMBER: _ClassVar[int] + waiting: int + grammar: int + paused: int + retracted: int + def __init__(self, waiting: _Optional[int] = ..., grammar: _Optional[int] = ..., paused: _Optional[int] = ..., retracted: _Optional[int] = ...) -> None: ... + +class AggregateMetrics(_message.Message): + __slots__ = ("total_running_reqs", "total_waiting_reqs", "total_reqs", "avg_token_usage", "avg_throughput", "avg_utilization") + TOTAL_RUNNING_REQS_FIELD_NUMBER: _ClassVar[int] + TOTAL_WAITING_REQS_FIELD_NUMBER: _ClassVar[int] + TOTAL_REQS_FIELD_NUMBER: _ClassVar[int] + AVG_TOKEN_USAGE_FIELD_NUMBER: _ClassVar[int] + AVG_THROUGHPUT_FIELD_NUMBER: _ClassVar[int] + AVG_UTILIZATION_FIELD_NUMBER: _ClassVar[int] + total_running_reqs: int + total_waiting_reqs: int + total_reqs: int + avg_token_usage: float + avg_throughput: float + avg_utilization: float + def __init__(self, total_running_reqs: _Optional[int] = ..., total_waiting_reqs: _Optional[int] = ..., total_reqs: _Optional[int] = ..., avg_token_usage: _Optional[float] = ..., avg_throughput: _Optional[float] = ..., avg_utilization: _Optional[float] = ...) -> None: ... diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py b/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py new file mode 100644 index 000000000000..99bf78bb4864 --- /dev/null +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py @@ -0,0 +1,368 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from . 
import sglang_scheduler_pb2 as sglang__scheduler__pb2 + +GRPC_GENERATED_VERSION = '1.75.1' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in sglang_scheduler_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class SglangSchedulerStub(object): + """Service definition for SGLang scheduler communication + This protocol bridges the Rust router and Python scheduler + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Generate = channel.unary_stream( + '/sglang.grpc.scheduler.SglangScheduler/Generate', + request_serializer=sglang__scheduler__pb2.GenerateRequest.SerializeToString, + response_deserializer=sglang__scheduler__pb2.GenerateResponse.FromString, + _registered_method=True) + self.Embed = channel.unary_unary( + '/sglang.grpc.scheduler.SglangScheduler/Embed', + request_serializer=sglang__scheduler__pb2.EmbedRequest.SerializeToString, + response_deserializer=sglang__scheduler__pb2.EmbedResponse.FromString, + _registered_method=True) + self.HealthCheck = channel.unary_unary( + '/sglang.grpc.scheduler.SglangScheduler/HealthCheck', + request_serializer=sglang__scheduler__pb2.HealthCheckRequest.SerializeToString, + response_deserializer=sglang__scheduler__pb2.HealthCheckResponse.FromString, + _registered_method=True) + self.Abort = channel.unary_unary( + '/sglang.grpc.scheduler.SglangScheduler/Abort', + request_serializer=sglang__scheduler__pb2.AbortRequest.SerializeToString, + 
response_deserializer=sglang__scheduler__pb2.AbortResponse.FromString, + _registered_method=True) + self.GetModelInfo = channel.unary_unary( + '/sglang.grpc.scheduler.SglangScheduler/GetModelInfo', + request_serializer=sglang__scheduler__pb2.GetModelInfoRequest.SerializeToString, + response_deserializer=sglang__scheduler__pb2.GetModelInfoResponse.FromString, + _registered_method=True) + self.GetServerInfo = channel.unary_unary( + '/sglang.grpc.scheduler.SglangScheduler/GetServerInfo', + request_serializer=sglang__scheduler__pb2.GetServerInfoRequest.SerializeToString, + response_deserializer=sglang__scheduler__pb2.GetServerInfoResponse.FromString, + _registered_method=True) + self.GetLoads = channel.unary_unary( + '/sglang.grpc.scheduler.SglangScheduler/GetLoads', + request_serializer=sglang__scheduler__pb2.GetLoadsRequest.SerializeToString, + response_deserializer=sglang__scheduler__pb2.GetLoadsResponse.FromString, + _registered_method=True) + + +class SglangSchedulerServicer(object): + """Service definition for SGLang scheduler communication + This protocol bridges the Rust router and Python scheduler + """ + + def Generate(self, request, context): + """Submit a generation request (supports streaming) + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Embed(self, request, context): + """Submit an embedding request + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def HealthCheck(self, request, context): + """Health check and metrics + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Abort(self, request, context): + """Abort a running request + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + 
context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetModelInfo(self, request, context): + """Get model information + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetServerInfo(self, request, context): + """Get server information + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetLoads(self, request, context): + """Get comprehensive load metrics + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_SglangSchedulerServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Generate': grpc.unary_stream_rpc_method_handler( + servicer.Generate, + request_deserializer=sglang__scheduler__pb2.GenerateRequest.FromString, + response_serializer=sglang__scheduler__pb2.GenerateResponse.SerializeToString, + ), + 'Embed': grpc.unary_unary_rpc_method_handler( + servicer.Embed, + request_deserializer=sglang__scheduler__pb2.EmbedRequest.FromString, + response_serializer=sglang__scheduler__pb2.EmbedResponse.SerializeToString, + ), + 'HealthCheck': grpc.unary_unary_rpc_method_handler( + servicer.HealthCheck, + request_deserializer=sglang__scheduler__pb2.HealthCheckRequest.FromString, + response_serializer=sglang__scheduler__pb2.HealthCheckResponse.SerializeToString, + ), + 'Abort': grpc.unary_unary_rpc_method_handler( + servicer.Abort, + request_deserializer=sglang__scheduler__pb2.AbortRequest.FromString, + response_serializer=sglang__scheduler__pb2.AbortResponse.SerializeToString, + ), + 'GetModelInfo': grpc.unary_unary_rpc_method_handler( + servicer.GetModelInfo, + request_deserializer=sglang__scheduler__pb2.GetModelInfoRequest.FromString, + 
response_serializer=sglang__scheduler__pb2.GetModelInfoResponse.SerializeToString, + ), + 'GetServerInfo': grpc.unary_unary_rpc_method_handler( + servicer.GetServerInfo, + request_deserializer=sglang__scheduler__pb2.GetServerInfoRequest.FromString, + response_serializer=sglang__scheduler__pb2.GetServerInfoResponse.SerializeToString, + ), + 'GetLoads': grpc.unary_unary_rpc_method_handler( + servicer.GetLoads, + request_deserializer=sglang__scheduler__pb2.GetLoadsRequest.FromString, + response_serializer=sglang__scheduler__pb2.GetLoadsResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class SglangScheduler(object): + """Service definition for SGLang scheduler communication + This protocol bridges the Rust router and Python scheduler + """ + + @staticmethod + def Generate(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/sglang.grpc.scheduler.SglangScheduler/Generate', + sglang__scheduler__pb2.GenerateRequest.SerializeToString, + sglang__scheduler__pb2.GenerateResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Embed(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/sglang.grpc.scheduler.SglangScheduler/Embed', + 
sglang__scheduler__pb2.EmbedRequest.SerializeToString, + sglang__scheduler__pb2.EmbedResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def HealthCheck(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/sglang.grpc.scheduler.SglangScheduler/HealthCheck', + sglang__scheduler__pb2.HealthCheckRequest.SerializeToString, + sglang__scheduler__pb2.HealthCheckResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Abort(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/sglang.grpc.scheduler.SglangScheduler/Abort', + sglang__scheduler__pb2.AbortRequest.SerializeToString, + sglang__scheduler__pb2.AbortResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetModelInfo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/sglang.grpc.scheduler.SglangScheduler/GetModelInfo', + sglang__scheduler__pb2.GetModelInfoRequest.SerializeToString, + sglang__scheduler__pb2.GetModelInfoResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + 
timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetServerInfo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/sglang.grpc.scheduler.SglangScheduler/GetServerInfo', + sglang__scheduler__pb2.GetServerInfoRequest.SerializeToString, + sglang__scheduler__pb2.GetServerInfoResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetLoads(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/sglang.grpc.scheduler.SglangScheduler/GetLoads', + sglang__scheduler__pb2.GetLoadsRequest.SerializeToString, + sglang__scheduler__pb2.GetLoadsResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index d5c47d2fa67e..fdc41288357e 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -819,26 +819,32 @@ def forward_extend( else: o = torch.empty_like(q) - # Save KV cache first (must do this before unified kernel) - if save_kv_cache: - if ( - self.use_mla or layer.k_scale is None - ): # Triton MLA currently doesn't support quantized kv cache - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - forward_batch.out_cache_loc, - k, - v, - ) - else: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - forward_batch.out_cache_loc, - 
k.clone(), # cloned to protect k,v from in-place mutation in set_kv_buffer - v.clone(), - layer.k_scale, - layer.v_scale, - ) + if k is None and v is None: + k, v = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + # print(layer.layer_id, k.cpu(), v.cpu()) + elif k is None or v is None: + raise ValueError("Both k and v should be None or not None") + else: + # Save KV cache first (must do this before unified kernel) + if save_kv_cache: + if ( + self.use_mla or layer.k_scale is None + ): # Triton MLA currently doesn't support quantized kv cache + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + forward_batch.out_cache_loc, + k, + v, + ) + else: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + forward_batch.out_cache_loc, + k.clone(), # cloned to protect k,v from in-place mutation in set_kv_buffer + v.clone(), + layer.k_scale, + layer.v_scale, + ) logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 383f58399f5c..416b5cde3114 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -121,6 +121,7 @@ def forward_cuda( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + return self.forward_native(x, residual, post_residual_addition) if x.numel() == 0: return x if self.variance_size_override is not None: @@ -476,7 +477,7 @@ def forward_cuda( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - return self._forward_impl(x, residual, post_residual_addition) + return self.forward_native(x, residual, post_residual_addition) def forward_cpu( self, @@ -554,3 +555,65 @@ def forward_npu(self, x): def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" + + 
+class Gemma4RMSNorm(nn.Module): + def __init__( + self, + dim: int, + eps: float = 1e-6, + scale_shift: float = 1.0, + with_scale: bool = True, + ): + super().__init__() + self.with_scale = with_scale + + if self.with_scale: + self.weight = nn.Parameter(torch.zeros(dim)) + else: + self.register_buffer("weight", torch.tensor(1.0), persistent=False) + + self.eps = eps + self.scale_shift = scale_shift + + def __repr__(self): + dim = self.weight.shape[-1] if self.weight.shape else None + return ( + f"{self.__class__.__name__}(dim={dim}, eps={self.eps}, " + f"with_scale={self.with_scale}, scale_shift={self.scale_shift})" + ) + + def _norm(self, x): + mean_squared = x.pow(2).mean(-1, keepdim=True) + self.eps + # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to addess compiler differences between Torch and JAX + return x * torch.pow(mean_squared, -0.5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + normed_output = self._norm(x.float()) + if self.with_scale: + normed_output = normed_output * (self.weight.float() + self.scale_shift) + return normed_output.type_as(x) + + + +class RMSNormWithoutScale(MultiPlatformOp): + def __init__(self, hidden_size: int, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward_native(self, x): + orig_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + return x.to(orig_dtype) + + def forward_cuda(self, x): + return self.forward_native(x) + + def extra_repr(self): + return f"{self.hidden_size}, eps={self.eps}" \ No newline at end of file diff --git a/python/sglang/srt/layers/rotary_embedding/base.py b/python/sglang/srt/layers/rotary_embedding/base.py index fa3068b3dd03..03f111ee68ad 100644 --- a/python/sglang/srt/layers/rotary_embedding/base.py +++ b/python/sglang/srt/layers/rotary_embedding/base.py @@ -97,6 +97,15 @@ 
def __init__( ) self.position_cos, self.position_sin = None, None + def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: + # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) + # is expensive, so avoid calling it if possible + if ( + self.cos_sin_cache.device != query.device + or self.cos_sin_cache.dtype != query.dtype + ): + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: """Compute the inverse frequency.""" # NOTE(woosuk): To exactly match the HF implementation, we need to diff --git a/python/sglang/srt/layers/rotary_embedding/rope_variant.py b/python/sglang/srt/layers/rotary_embedding/rope_variant.py index 28aaae598bc8..9dd539f40137 100644 --- a/python/sglang/srt/layers/rotary_embedding/rope_variant.py +++ b/python/sglang/srt/layers/rotary_embedding/rope_variant.py @@ -866,3 +866,66 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: sin = freqs.sin() cache = torch.cat((cos, sin), dim=-1) return cache + + + +class Gemma4RotaryEmbedding(RotaryEmbedding): + """Gemma4-specific RoPE with cross-mixing. + + Instead of rotating the first `rotary_dim` dimensions contiguously, + splits the head into two halves and applies rotation across both. 
+ + For a head_dim of D and rotary_dim of R: + - Standard RoPE rotates: [0, R) + - Gemma4 RoPE rotates: [0, R/2) cross-mixed with [D/2, D/2 + R/2) + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + # Store angles before calling super().__init__ + # rotary_dim is already scaled by partial_rotary_factor in get_rope + # For Gemma4: head_size=512, partial_rotary_factor=0.25 -> rotary_dim=128 + self.rope_angles = rotary_dim // 2 # Number of rotation angles per half + self.nope_angles = (head_size // 2) - self.rope_angles # Non-rotated per half + + super().__init__( + head_size, + head_size, + max_position_embeddings, + base, + is_neox_style, + dtype, + ) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + """Compute frequencies only for the rotated dimensions. + + Non-rotated dims are padded with 0.0 to produce identity rotation. + """ + freq_exponents = ( + torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float) + / self.head_size + ) + inv_freq = 1.0 / (base ** freq_exponents) + + # Zero-pad for non-rotated dims (identity rotation: cos=1, sin=0) + if self.nope_angles > 0: + inv_freq = torch.cat([ + inv_freq, + torch.zeros(self.nope_angles, dtype=torch.float), + ]) + return inv_freq + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", rope_angles={self.rope_angles}, nope_angles={self.nope_angles}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 8f55006399e1..9c7d5e2e16f0 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -474,6 +474,7 @@ def forward_batch_generation( pp_proxy_tensors=pp_proxy_tensors, 
skip_attn_backend_init=skip_attn_backend_init, ) + # print(out) logits_output, can_run_cuda_graph = out.logits_output, out.can_run_graph batch_result = GenerationBatchResult( logits_output=logits_output, diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 5a759ed11bcd..62dfc2f2eebc 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -347,6 +347,7 @@ def alloc_for_extend( prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True) extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True) + # print("req_pool_idx for each request:", [i for i, r in enumerate(batch.reqs) if r.req_pool_idx is not None]) # Allocate req slots req_pool_indices = alloc_req_slots( batch.req_to_token_pool, batch.reqs, batch.tree_cache diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 16b1410c3090..dbccf1aa38ed 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -98,6 +98,13 @@ def _set_kv_buffer_impl( alt_stream: Optional[torch.cuda.Stream] = None, same_kv_dim: bool = True, ) -> None: + # print("dtype of k_cache: ", k_cache.dtype) + # print("dtype of v_cache: ", v_cache.dtype) + # print("dtype of k: ", k.dtype) + # print("dtype of v: ", v.dtype) + # print("dtype store_dtype: ", store_dtype) + # print("row_dim: ", row_dim, "store_dtype.itemsize: ", store_dtype.itemsize) + # print("shape of k: ", k.shape, "shape of v: ", v.shape, "shape of k_cache: ", k_cache.shape, "shape of v_cache: ", v_cache.shape) row_bytes = row_dim * store_dtype.itemsize if (_is_cuda or _is_hip) and same_kv_dim and can_use_store_cache(row_bytes): return store_cache( @@ -119,6 +126,10 @@ def _set_kv_buffer_impl( v_cache[indices] = v current_stream.wait_stream(alt_stream) else: # fallback to naive implementation + # if k_cache.shape[-1] != k.shape[-1]: + # k_cache[indices, ..., :k.shape[-1]] = k 
+ # v_cache[indices, ..., :v.shape[-1]] = v + # else: k_cache[indices] = k v_cache[indices] = v @@ -754,6 +765,7 @@ def __init__( ) self.head_num = swa_head_num if swa_head_num is not None else head_num self.head_dim = swa_head_dim if swa_head_dim is not None else head_dim + print("head_num: ", self.head_num, "head_dim: ", self.head_dim, "swa_head_num: ", swa_head_num, "swa_head_dim: ", swa_head_dim, "head_num: ", head_num, "head_dim: ", head_dim) self.v_head_dim = ( swa_v_head_dim if swa_v_head_dim is not None @@ -832,8 +844,10 @@ def _create_buffers(self): if self.enable_custom_mem_pool else nullcontext() ): + print(f"Allocating KV cache buffers with size {self.size}, page_size {self.page_size}, head_num {self.head_num}, head_dim {self.head_dim}, v_head_dim {self.v_head_dim}, dtype {self.store_dtype}, device {self.device}") # [size, head_num, head_dim] for each layer # The padded slot 0 is used for writing dummy outputs from padded tokens. + # adjust for global self.k_buffer = [ torch.zeros( (self.size + self.page_size, self.head_num, self.head_dim), @@ -977,13 +991,14 @@ def get_kv_buffer(self, layer_id: int): def set_kv_buffer( self, - layer: RadixAttention, + layer: Optional[RadixAttention], loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, k_scale: Optional[float] = None, v_scale: Optional[float] = None, layer_id_override: Optional[int] = None, + row_dim: Optional[int] = None, ): if layer_id_override is not None: layer_id = layer_id_override diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index 0faf201cbd48..3b1d0f1c74ed 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -40,9 +40,12 @@ def __init__( self.dtype = dtype self.head_num = head_num self.head_dim = head_dim + # self.global_head_dim = head_dim * 2 self.device = device self.swa_layer_nums = len(swa_attention_layer_ids) self.full_layer_nums = 
len(full_attention_layer_ids) + print(f"SWA layer nums: {self.swa_layer_nums}, Full layer nums: {self.full_layer_nums}") + self.start_layer = 0 self.page_size = page_size self.swa_loc = None @@ -157,6 +160,7 @@ def set_kv_buffer( layer_id = layer.layer_id layer_id_pool, is_swa_layer = self.layers_mapping[layer_id] + if is_swa_layer: if self.swa_loc is not None: loc = self.swa_loc diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e42dbc556230..20c8766cd53f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1037,6 +1037,8 @@ def load_model(self): logger.info( f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}" ) + + # Note(pyc): gemma4 has different swa def self.dtype = self.model_config.dtype diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 577508b1fc23..cc4b7339e2c5 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -548,8 +548,8 @@ def _init_pools(self: ModelRunner): // get_attention_tp_size(), ), "swa_head_dim": self.model_config.hf_text_config.swa_head_dim, - "swa_v_head_dim": self.model_config.hf_text_config.swa_v_head_dim, - "v_head_dim": self.model_config.hf_text_config.v_head_dim, + "swa_v_head_dim": self.model_config.hf_text_config.swa_head_dim, + "v_head_dim": self.model_config.hf_text_config.head_dim, } self.token_to_kv_pool = SWAKVPool( size=self.full_max_total_num_tokens, @@ -619,6 +619,8 @@ def _init_pools(self: ModelRunner): ), ) else: + # self.max_total_num_tokens = self.max_total_num_tokens // 2 if global_head_dim is not None else self.max_total_num_tokens + # print(f"global_head_dim: {global_head_dim}, head_dim: {self.model_config.head_dim}, head_num: 
{self.model_config.get_total_num_kv_heads()}, max_total_num_tokens: {self.max_total_num_tokens}") self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, page_size=self.page_size, @@ -626,6 +628,7 @@ def _init_pools(self: ModelRunner): head_num=self.model_config.get_num_kv_heads( get_attention_tp_size() ), + # head_dim=self.model_config.head_dim if global_head_dim is None else global_head_dim, head_dim=self.model_config.head_dim, layer_num=self.num_effective_layers, device=self.device, diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index 17c535d73d3f..be3c0d6289a8 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -92,13 +92,19 @@ def __init__( ) if hidden_activation != "gelu_pytorch_tanh": raise ValueError( - "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + f"{self.__class__.__name__} uses `gelu_pytorch_tanh` as the hidden activation " "function. Please set `hidden_activation` to " "`gelu_pytorch_tanh`." ) self.act_fn = GeluAndMul() + self.prefix = prefix def forward(self, x: torch.Tensor) -> torch.Tensor: + # if "layers.0.mlp" in self.prefix: + # print("---start", self.prefix) + # for p in self.gate_up_proj.parameters(): + # print(p) + # print("---end") gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py new file mode 100644 index 000000000000..4fc2bb78e2c2 --- /dev/null +++ b/python/sglang/srt/models/gemma4_causal.py @@ -0,0 +1,914 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import copy +from typing import Iterable, Optional, Set, Tuple + +import einops +import torch +import torch.nn.functional as F +from torch import nn +from transformers import ( + ROPE_INIT_FUNCTIONS, + Gemma4TextConfig, + PretrainedConfig, + PreTrainedModel, +) + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import GeluAndMul +from sglang.srt.layers.layernorm import Gemma3RMSNorm, Gemma4RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.utils import add_prefix, make_layers +from sglang.srt.models.gemma3_causal import Gemma3MLP +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.server_args import get_global_server_args +from sglang.srt.layers.rotary_embedding import get_rope +from 
sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.layers.layernorm import RMSNorm, GemmaRMSNorm, RMSNormWithoutScale +from sglang.srt.models.gemma3_causal import Gemma3TextScaledWordEmbedding + + +# Aligned with HF's implementation, using sliding window inclusive with the last token +# SGLang assumes exclusive +def get_attention_sliding_window_size(config): + return config.sliding_window - 1 + + +Gemma4MLP = Gemma3MLP + + +class Gemma4PerLayerEmbedding(nn.Module): + """Per-Layer Embedding (PLE) system for Gemma 4. + + Gemma 4 uses a secondary embedding stream that provides layer-specific + token embeddings. These are combined with the main hidden states via + a gating mechanism in each decoder layer. + + The PLE embedding stores embeddings for all layers packed together: + (vocab_size, hidden_size_per_layer_input * num_hidden_layers) + """ + + def __init__( + self, + vocab_size_per_layer_input: int, + hidden_size_per_layer_input: int, + hidden_size: int, + num_hidden_layers: int, + rms_norm_eps: float, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.vocab_size = vocab_size_per_layer_input + self.hidden_size_per_layer = hidden_size_per_layer_input + self.hidden_size = hidden_size + self.num_layers = num_hidden_layers + + # Packed embedding: (vocab_size, hidden_size_per_layer * num_layers) + # We store embeddings for ALL layers together + total_embed_dim = hidden_size_per_layer_input * num_hidden_layers + self.embed_tokens_per_layer = VocabParallelEmbedding( + vocab_size_per_layer_input, + total_embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens_per_layer", + ) + + # Projection from PLE space to hidden space + # (hidden_size_per_layer * num_layers, hidden_size) + self.per_layer_model_projection = nn.Linear( + total_embed_dim, + hidden_size, + bias=False, + ) + + # Normalization for PLE output + # JAX uses 
scale_plus_one=False for this norm (x * scale, not x * (1+scale)) + self.per_layer_projection_norm = RMSNorm( + self.hidden_size_per_layer, + eps=rms_norm_eps, + ) + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + """Compute per-layer embeddings and project to hidden size. + + Args: + input_ids: Token IDs (batch_size, seq_len) + + Returns: + Per-layer input tensor (batch_size, seq_len, hidden_size) + """ + # Get packed per-layer embeddings + per_layer_embeds = self.embed_tokens_per_layer(input_ids) + + # Apply normalization (reshape to apply per-layer, then reshape back) + # Original shape: (batch, seq, hidden_size_per_layer * num_layers) + batch_size, seq_len, _ = per_layer_embeds.shape + per_layer_embeds = per_layer_embeds.view( + batch_size, seq_len, self.num_layers, self.hidden_size_per_layer + ) + per_layer_embeds = self.per_layer_projection_norm(per_layer_embeds) + per_layer_embeds = per_layer_embeds.view( + batch_size, seq_len, -1 + ) + + # Project to hidden size + per_layer_input = self.per_layer_model_projection(per_layer_embeds) + return per_layer_input + + +class Gemma4MoEBLock(nn.Module): + def __init__( + self, + layer_id: int, + config: Gemma4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.layer_id = layer_id + self.activation = config.hidden_act + + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=True, + layer_id=layer_id, + ) + self.top_k = config.num_experts_per_tok + experts_type = get_moe_impl_class(quant_config) + + self.experts = experts_type( + num_experts=config.num_local_experts + + get_global_server_args().ep_num_redundant_experts, + top_k=config.num_experts_per_tok, + layer_id=layer_id, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + activation=self.activation, + gemm1_alpha=self.gemm1_alpha, + 
gemm1_clamp_limit=self.gemm1_clamp_limit, + with_bias=True, + prefix=add_prefix("experts", prefix), + ) + + self.router = ReplicatedLinear( + config.hidden_size, + config.num_local_experts, + bias=True, + quant_config=None, + prefix=add_prefix("gate", prefix), + params_dtype=config.torch_dtype, + ) + + +class Gemma4Attention(nn.Module): + def __init__( + self, + layer_id: int, + config: Gemma4TextConfig, + head_dim: int, + max_position_embeddings: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.layer_id = layer_id + self.config = config + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + hidden_size = config.hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + + self.q_norm = Gemma4RMSNorm( + self.head_dim, + eps=config.rms_norm_eps, + ) + self.k_norm = Gemma4RMSNorm( + self.head_dim, + eps=config.rms_norm_eps, + ) + self.v_norm = Gemma4RMSNorm( + self.head_dim, + eps=config.rms_norm_eps, + scale_shift=0.0, + with_scale=False + ) + + # Determine if layer uses sliding window based on 
pattern + layer_type = config.layer_types[layer_id] + self.is_sliding = layer_type == "sliding_attention" + self.sliding_window = config.sliding_window if self.is_sliding else None + + # Initialize the rotary embedding based on layer type. + # Gemma 4 uses different RoPE parameters for sliding vs full attention. + if layer_type in config.rope_parameters: + rope_parameters = dict(config.rope_parameters[layer_type]) + # Fix: Use global_partial_rotary_factor for full_attention layers + # JAX reference uses global_rope_proportion=0.25 for global attention + if layer_type == "full_attention": + global_prf = getattr(config, "global_partial_rotary_factor", 0.25) + rope_parameters["partial_rotary_factor"] = global_prf + else: + # Fallback for older config format + rope_parameters = dict( + rope_type="default", + rope_theta=getattr(config, "rope_theta", 10000.0), + ) + + + # Check if this is a KV shared layer + first_kv_shared_layer_idx = ( + config.num_hidden_layers - config.num_kv_shared_layers + ) + self.is_kv_shared_layer = layer_id >= first_kv_shared_layer_idx + + # KV sharing logic for Gemma 4 + # kv_sharing_target_layer_name = None + num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0) + if num_kv_shared_layers > 0: + first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers + if self.layer_id >= first_kv_shared_layer_idx: + # Find the last non-shared layer of the same type (sliding/full) + prev_layers = config.layer_types[:first_kv_shared_layer_idx] + current_layer_type = config.layer_types[self.layer_id] + self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index( + current_layer_type + ) + # print(f"layer {layer_id} rope_parameters: ", rope_parameters, self.head_dim) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_parameters.get("rope_theta", 10000.0), + rope_scaling={"rope_type": rope_parameters.get("rope_type", "default")}, + 
partial_rotary_factor=rope_parameters.get("partial_rotary_factor", 1.0), + is_neox_style=True, + ) + + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + 1, # scaling factor + num_kv_heads=self.num_kv_heads, + layer_id=( + self.kv_shared_layer_index if self.is_kv_shared_layer else self.layer_id + ), + logit_cap=getattr( + config, "attn_logit_softcapping", 0.0 + ), + sliding_window_size=self.sliding_window, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs, + ): + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + + # Check if we should use shared KV cache + if self.is_kv_shared_layer and self.kv_shared_layer_index is not None: + # For KV shared layers, we skip K/V computation and normalization + # The RadixAttention will handle retrieving shared KV from cache + k = None + v = None + else: + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + + v = v.unflatten(-1, (self.num_kv_heads, self.head_dim)) + v = self.v_norm(v) + + # Apply rotary embedding + if k is not None: + k = k.flatten(-2, -1) + # print(f"positions: {positions.shape}, q.shape: {q.shape}, k.shape: {k.shape}, self.head_dim: {self.head_dim}") + q, k = self.rotary_emb(positions, q, k) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + else: + # For shared KV layers, create a dummy key for rotary embedding and discard it + dummy_k = torch.zeros_like( + q[:, : self.kv_size] + ) # Create dummy key with same shape as needed + q, _ = self.rotary_emb(positions, q, dummy_k) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + # print(f"attn positions: {positions.shape}, q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}, self.head_dim: 
{self.head_dim}") + attn_output = self.attn(q, k, v, forward_batch=forward_batch, + save_kv_cache=not self.is_kv_shared_layer) + # print(attn_output.shape) + if attn_output.dim() == 3: + attn_output = attn_output.flatten(-2, -1) + output, _ = self.o_proj(attn_output) + + return output + + +class Gemma4DecoderLayer(nn.Module): + def __init__( + self, + layer_id: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.hidden_size_per_layer_input = getattr( + config, "hidden_size_per_layer_input", 0 + ) + + self.layer_id = layer_id + + # Gemma 4 uses different head dimensions for sliding vs full attention + layer_type = config.layer_types[layer_id] + self.is_full_attention = layer_type == "full_attention" + if self.is_full_attention: + head_dim = config.head_dim # following sglang naming + else: + head_dim = getattr(config, "swa_head_dim", config.head_dim) + + self.self_attn = Gemma4Attention( + layer_id=layer_id, + config=config, + max_position_embeddings=config.max_position_embeddings, + head_dim=head_dim, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + + # Get intermediate_size for this layer + # Gemma4 may have variable intermediate_size per layer (e.g., 6144 for layers 0-14, 12288 for layers 15+) + # if hasattr(config, 'intermediate_sizes') and config.intermediate_sizes is not None: + # layer_intermediate_size = config.intermediate_sizes[self.layer_id] + # else: + # layer_intermediate_size = config.intermediate_size + + first_kv_shared_layer_idx = config.num_hidden_layers - getattr( + config, "num_kv_shared_layers", 0 + ) + is_kv_shared_layer = self.layer_id >= first_kv_shared_layer_idx > 0 + use_double_wide_mlp = ( + getattr(config, "use_double_wide_mlp", False) and is_kv_shared_layer + ) + layer_intermediate_size = config.intermediate_size * ( + 2 if use_double_wide_mlp else 1 + ) + + self.mlp = 
Gemma4MLP( + hidden_size=self.hidden_size, + intermediate_size=layer_intermediate_size, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + + self.input_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + # Per-Layer Embedding (PLE) components — present in each decoder layer + if self.hidden_size_per_layer_input > 0: + # Gate: projects hidden_states → per-layer dim for gating + self.per_layer_input_gate = ReplicatedLinear( + self.hidden_size, + self.hidden_size_per_layer_input, + bias=False, + quant_config=quant_config, + prefix=add_prefix("per_layer_input_gate", prefix), + ) + # Projection: projects gated per-layer input back → hidden size + self.per_layer_projection = ReplicatedLinear( + self.hidden_size_per_layer_input, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("per_layer_projection", prefix) + ) + self.post_per_layer_input_norm = Gemma4RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.per_layer_input_gate = None + self.per_layer_projection = None + self.post_per_layer_input_norm = None + + # Layer scalar for full-attention layers only + if self.is_full_attention: + self.register_buffer("layer_scalar", torch.ones(1), persistent=True) + self.prefix = prefix + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_input: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs, + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + # Gemma4 residual pattern following JAX implementation: + # 1. 
input_norm(x) -> attn -> post_attn_norm -> ADD residual + # 2. pre_ff_norm -> mlp -> post_ff_norm -> ADD residual + residual = hidden_states + + # Apply input layernorm + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = hidden_states + residual + + if ( + per_layer_input is not None + and self.per_layer_input_gate is not None + and self.per_layer_projection is not None + and self.post_per_layer_input_norm is not None + ): + gate, _ = self.per_layer_input_gate(hidden_states) + # PLE uses gelu activation for the gate + # Note: GeluAndMul expects concatenated [gate, up] but here we + # only have a single projection. Use F.gelu directly. 
+ gate = torch.nn.functional.gelu(gate, approximate="tanh") + gated_per_layer = gate * per_layer_input + per_layer_contribution, _ = self.per_layer_projection(gated_per_layer) + per_layer_contribution = self.post_per_layer_input_norm( + per_layer_contribution + ) + hidden_states = hidden_states + per_layer_contribution + + # Apply layer scalar for full-attention layers + if self.is_full_attention and hasattr(self, 'layer_scalar'): + hidden_states = hidden_states * self.layer_scalar + return hidden_states, None + + +Gemma4TextScaledWordEmbedding = Gemma3TextScaledWordEmbedding + + +class Gemma4TextModel(PreTrainedModel): + def __init__( + self, + config: Gemma4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config) + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + self.padding_idx = getattr(config, "pad_token_id", None) + + self.embed_tokens = Gemma4TextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + embed_scale=self.config.hidden_size**0.5, # embeded normalizer + ) + + # Per-layer input embeddings + self.hidden_size = config.hidden_size + self.hidden_size_per_layer_input = getattr( + config, "hidden_size_per_layer_input", 0 + ) + self.vocab_size_per_layer_input = getattr( + config, "vocab_size_per_layer_input", config.vocab_size + ) + + if self.hidden_size_per_layer_input > 0: + self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding( + self.vocab_size_per_layer_input, + config.num_hidden_layers * self.hidden_size_per_layer_input, + self.padding_idx, + embed_scale=self.hidden_size_per_layer_input**0.5, + ) + + # Scaled embedding factor (from config, not hardcoded) + # self.embed_scale_per_layer = torch.tensor( + # self.hidden_size_per_layer_input**0.5, + # ) + + self.per_layer_model_projection = ColumnParallelLinear( + self.hidden_size, + config.num_hidden_layers * self.hidden_size_per_layer_input, + 
bias=False, + quant_config=quant_config, + prefix=add_prefix("per_layer_model_projection", prefix), + ) + + self.per_layer_projection_norm = RMSNorm( + self.hidden_size_per_layer_input, + config.rms_norm_eps, + ) + self.per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)) + self.per_layer_projection_scale = torch.tensor( + config.hidden_size**-0.5, + ) + else: + self.embed_tokens_per_layer = None + self.per_layer_model_projection = None + self.per_layer_projection_norm = None + self.per_layer_input_scale = None + self.per_layer_projection_scale = None + + self.layers = make_layers( + config.num_hidden_layers, + lambda idx, prefix: Gemma4DecoderLayer( + layer_id=idx, + config=config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=add_prefix("layers", prefix), + ) + + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # self.per_layer_projection_scale = torch.tensor( + # config.hidden_size**-0.5, + # ) + # self.register_buffer( + # "per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False + # ) + # self.register_buffer( + # "normalizer", + # torch.tensor(config.hidden_size**0.5), + # persistent=False, + # ) + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + return self.embed_tokens + + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor: + if self.embed_tokens_per_layer is None: + return None + + # Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may + # be smaller than the main vocab_size). Following Gemma3n pattern. 
+ per_layer_inputs_mask = torch.logical_and( + input_ids >= 0, + input_ids < self.vocab_size_per_layer_input, + ) + per_layer_inputs_tokens = torch.where( + per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids) + ) + + # Get packed per-layer embeddings: (num_tokens, total_ple_dim) + per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens) + + # Apply embed_scale (sqrt of per-layer hidden dim) + # Alreayd done in embedding layer + # per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer + + # Reshape to (num_tokens, num_layers, hidden_size_per_layer_input) + per_layer_embeds = per_layer_embeds.reshape( + *input_ids.shape, + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + return per_layer_embeds + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Project inputs_embeds and combine with per_layer_inputs. + + Following HF/Gemma3n reference: + 1. Project inputs_embeds: hidden_size → total_ple_dim + 2. Scale by hidden_size^{-0.5} (Gemma4ScaledLinear w_scale) + 3. Reshape to (num_tokens, num_layers, per_layer_dim) + 4. Normalize with per_layer_projection_norm + 5. 
Combine: (projection + per_layer_inputs) * 1/sqrt(2) + """ + if self.per_layer_model_projection is None: + return None + + # Project from hidden_size to total_ple_dim + per_layer_projection, _ = self.per_layer_model_projection(inputs_embeds) + + # Apply w_scale (HF: Gemma4ScaledLinear with w_scale=hidden_size^{-0.5}) + per_layer_projection = ( + per_layer_projection * self.per_layer_projection_scale + ) + + # Reshape to (num_tokens, num_layers, hidden_size_per_layer_input) + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + + # Normalize + per_layer_projection = self.per_layer_projection_norm( + per_layer_projection + ) + + if per_layer_inputs is None: + return per_layer_projection + + # Combine: (projection + per_layer_inputs) * scale + return ( + per_layer_projection + per_layer_inputs + ) * self.per_layer_input_scale + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + if (input_ids is None) ^ (input_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if input_ids is not None: + input_embeds = self.embed_tokens(input_ids) + per_layer_embeds = self.get_per_layer_inputs(input_ids) + per_layer_inputs = self.project_per_layer_inputs(input_embeds, per_layer_embeds) + + hidden_states = input_embeds + + for layer_idx, layer in enumerate(self.layers): + if per_layer_inputs is not None: + per_layer_input = per_layer_inputs[:, layer_idx, :] + else: + per_layer_input = None + layer_outputs = layer( + positions=positions, + hidden_states=hidden_states, + per_layer_input=per_layer_input, + forward_batch=forward_batch, + **kwargs, + ) + hidden_states = layer_outputs[0] + residual = layer_outputs[1] if len(layer_outputs) > 1 else None + 
+ if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Gemma4ForCausalLM(PreTrainedModel): + config_class = Gemma4TextConfig + base_model_prefix = "language_model" + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_rep"} + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # Gemma does not apply LoRA to the embedding layer. 
+ embedding_modules = {} + embedding_padding_modules = [] + supports_lora = False + + def __init__( + self, + config: Gemma4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config) + self.config = config + self.quant_config = quant_config + self.model = Gemma4TextModel( + config=config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + self.logits_processor = LogitsProcessor(config) + + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + + def get_attention_sliding_window_size(self): + return get_attention_sliding_window_size(self.config) + + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> LogitsProcessor: + hidden_states = self.model( + input_ids, positions, forward_batch, input_embeds, per_layer_inputs, **kwargs + ) + + return self.logits_processor( + input_ids, hidden_states, self.model.embed_tokens, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + params_dict.update(dict(self.named_buffers())) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "audio" in name or "vision" in name: + 
continue + + if ".language_model" in name: + name = name.replace(".language_model", "") + + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # if name not in params_dict: + # continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name and self.config.tie_word_embeddings: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + print( + "Some weights are not initialized from checkpoints: %s", unloaded_params + ) + return loaded_params + + +EntryClass = Gemma4ForCausalLM + + + diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py index 538c31f7713d..6887f658b165 100644 --- a/scripts/playground/reference_hf.py +++ b/scripts/playground/reference_hf.py @@ -124,8 +124,8 @@ def normal_text(args): prompts = [ "The capital of France is", - "The capital of the United Kindom is", - "Today is a sunny day and I like", + # "The capital of the United Kindom is", + # "Today is a sunny day and I like", ] max_new_tokens = args.max_new_tokens @@ -164,10 +164,8 @@ def synthetic_tokens(args): for p in prompts: input_ids = p for i in range(output_len + 
1): - prefill_logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[ - 0 - ][-1] - + output = m.forward(torch.tensor([input_ids], device="cuda"), output_hidden_states=True).logits[0][-1] + prefill_logits = output if i == 0: print("prefill logits", prefill_logits) else: From fe2524129bfcfd2963b5c84358f397a9ced210bb Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Wed, 4 Mar 2026 21:24:26 +0000 Subject: [PATCH 002/112] format and cleanup --- python/sglang/jit_kernel/kvcache.py | 1 - python/sglang/srt/layers/layernorm.py | 1 + python/sglang/srt/managers/tp_worker.py | 1 - python/sglang/srt/mem_cache/common.py | 1 - python/sglang/srt/mem_cache/memory_pool.py | 13 +------------ python/sglang/srt/mem_cache/swa_memory_pool.py | 2 -- python/sglang/srt/models/gemma3_causal.py | 5 ----- python/sglang/srt/models/gemma4_causal.py | 13 ------------- 8 files changed, 2 insertions(+), 35 deletions(-) diff --git a/python/sglang/jit_kernel/kvcache.py b/python/sglang/jit_kernel/kvcache.py index 065ca6eb4918..46a14612b6ff 100644 --- a/python/sglang/jit_kernel/kvcache.py +++ b/python/sglang/jit_kernel/kvcache.py @@ -65,7 +65,6 @@ def store_cache( v_cache (torch.Tensor): Value cache tensor of shape (num_pages, H * D). indices (torch.Tensor): Indices tensor of shape (batch_size,). 
""" - # print(f"store_cache called with k.shape={k.shape}, v.shape={v.shape}, k_cache.shape={k_cache.shape}, v_cache.shape={v_cache.shape}, indices.shape={indices.shape}, row_bytes={row_bytes}, num_split={num_split}") row_bytes = row_bytes or k.shape[-1] * k.element_size() module = _jit_kvcache_module(row_bytes) if num_split <= 0: diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 416b5cde3114..2f0016c53437 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -121,6 +121,7 @@ def forward_cuda( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # # TODO: fix cuda: having some shape issue with sgl kernel return self.forward_native(x, residual, post_residual_addition) if x.numel() == 0: return x diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 9c7d5e2e16f0..8f55006399e1 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -474,7 +474,6 @@ def forward_batch_generation( pp_proxy_tensors=pp_proxy_tensors, skip_attn_backend_init=skip_attn_backend_init, ) - # print(out) logits_output, can_run_cuda_graph = out.logits_output, out.can_run_graph batch_result = GenerationBatchResult( logits_output=logits_output, diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 62dfc2f2eebc..5a759ed11bcd 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -347,7 +347,6 @@ def alloc_for_extend( prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True) extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True) - # print("req_pool_idx for each request:", [i for i, r in enumerate(batch.reqs) if r.req_pool_idx is not None]) # Allocate req slots req_pool_indices = alloc_req_slots( 
batch.req_to_token_pool, batch.reqs, batch.tree_cache diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index dbccf1aa38ed..5856da31ea2c 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -98,15 +98,8 @@ def _set_kv_buffer_impl( alt_stream: Optional[torch.cuda.Stream] = None, same_kv_dim: bool = True, ) -> None: - # print("dtype of k_cache: ", k_cache.dtype) - # print("dtype of v_cache: ", v_cache.dtype) - # print("dtype of k: ", k.dtype) - # print("dtype of v: ", v.dtype) - # print("dtype store_dtype: ", store_dtype) - # print("row_dim: ", row_dim, "store_dtype.itemsize: ", store_dtype.itemsize) - # print("shape of k: ", k.shape, "shape of v: ", v.shape, "shape of k_cache: ", k_cache.shape, "shape of v_cache: ", v_cache.shape) row_bytes = row_dim * store_dtype.itemsize - if (_is_cuda or _is_hip) and same_kv_dim and can_use_store_cache(row_bytes): + if _is_cuda and same_kv_dim and can_use_store_cache(row_bytes): return store_cache( k.view(-1, row_dim), v.view(-1, row_dim), @@ -126,10 +119,6 @@ def _set_kv_buffer_impl( v_cache[indices] = v current_stream.wait_stream(alt_stream) else: # fallback to naive implementation - # if k_cache.shape[-1] != k.shape[-1]: - # k_cache[indices, ..., :k.shape[-1]] = k - # v_cache[indices, ..., :v.shape[-1]] = v - # else: k_cache[indices] = k v_cache[indices] = v diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index 3b1d0f1c74ed..a02738f4b9cb 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -40,11 +40,9 @@ def __init__( self.dtype = dtype self.head_num = head_num self.head_dim = head_dim - # self.global_head_dim = head_dim * 2 self.device = device self.swa_layer_nums = len(swa_attention_layer_ids) self.full_layer_nums = len(full_attention_layer_ids) - print(f"SWA layer nums: {self.swa_layer_nums}, Full 
layer nums: {self.full_layer_nums}") self.start_layer = 0 self.page_size = page_size diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index be3c0d6289a8..dc66134cb381 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -100,11 +100,6 @@ def __init__( self.prefix = prefix def forward(self, x: torch.Tensor) -> torch.Tensor: - # if "layers.0.mlp" in self.prefix: - # print("---start", self.prefix) - # for p in self.gate_up_proj.parameters(): - # print(p) - # print("---end") gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 4fc2bb78e2c2..90f18cd33323 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -300,7 +300,6 @@ def __init__( self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index( current_layer_type ) - # print(f"layer {layer_id} rope_parameters: ", rope_parameters, self.head_dim) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, @@ -357,7 +356,6 @@ def forward( # Apply rotary embedding if k is not None: k = k.flatten(-2, -1) - # print(f"positions: {positions.shape}, q.shape: {q.shape}, k.shape: {k.shape}, self.head_dim: {self.head_dim}") q, k = self.rotary_emb(positions, q, k) k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) else: @@ -368,10 +366,8 @@ def forward( q, _ = self.rotary_emb(positions, q, dummy_k) q = q.unflatten(-1, (self.num_heads, self.head_dim)) - # print(f"attn positions: {positions.shape}, q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}, self.head_dim: {self.head_dim}") attn_output = self.attn(q, k, v, forward_batch=forward_batch, save_kv_cache=not self.is_kv_shared_layer) - # print(attn_output.shape) if attn_output.dim() == 3: attn_output = attn_output.flatten(-2, -1) output, _ = 
self.o_proj(attn_output) @@ -411,13 +407,6 @@ def __init__( quant_config=quant_config, prefix=add_prefix("self_attn", prefix), ) - - # Get intermediate_size for this layer - # Gemma4 may have variable intermediate_size per layer (e.g., 6144 for layers 0-14, 12288 for layers 15+) - # if hasattr(config, 'intermediate_sizes') and config.intermediate_sizes is not None: - # layer_intermediate_size = config.intermediate_sizes[self.layer_id] - # else: - # layer_intermediate_size = config.intermediate_size first_kv_shared_layer_idx = config.num_hidden_layers - getattr( config, "num_kv_shared_layers", 0 @@ -874,8 +863,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # if name not in params_dict: - # continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) From cff018c3872c0de63473dcc7810c6514f4601607 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Thu, 5 Mar 2026 22:27:18 +0000 Subject: [PATCH 003/112] temp fix for kv sharing --- python/sglang/srt/layers/attention/triton_backend.py | 5 +++++ python/sglang/srt/models/gemma4_causal.py | 3 +-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index fdc41288357e..a905e970f9bc 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -821,6 +821,10 @@ def forward_extend( if k is None and v is None: k, v = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + # FIXME: hacky way to make kv cache aligned + # why??? 
+ k = k[1:q.shape[0]+1] + v = v[1:q.shape[0]+1] # print(layer.layer_id, k.cpu(), v.cpu()) elif k is None or v is None: raise ValueError("Both k and v should be None or not None") @@ -846,6 +850,7 @@ def forward_extend( layer.v_scale, ) + logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap) causal = True diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 90f18cd33323..5b014b02824d 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -832,9 +832,8 @@ def forward( hidden_states = self.model( input_ids, positions, forward_batch, input_embeds, per_layer_inputs, **kwargs ) - return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): From 593927029bd85465680478d87d3b5d3dc8303778 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Fri, 6 Mar 2026 07:34:07 +0000 Subject: [PATCH 004/112] cleanup & tp --- python/sglang/srt/models/gemma4_causal.py | 41 +++++++++-------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 5b014b02824d..b8cc8bc4d448 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -11,25 +11,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -import copy + +import logging from typing import Iterable, Optional, Set, Tuple -import einops import torch import torch.nn.functional as F from torch import nn from transformers import ( - ROPE_INIT_FUNCTIONS, + AutoModel, Gemma4TextConfig, PretrainedConfig, PreTrainedModel, ) from sglang.srt.distributed import get_tensor_model_parallel_world_size -from sglang.srt.layers.activation import GeluAndMul -from sglang.srt.layers.layernorm import Gemma3RMSNorm, Gemma4RMSNorm from sglang.srt.layers.linear import ( - MergedColumnParallelLinear, ColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, @@ -39,7 +36,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import ( @@ -47,7 +43,6 @@ maybe_remap_kv_scale_name, ) from sglang.srt.utils import add_prefix, make_layers -from sglang.srt.models.gemma3_causal import Gemma3MLP from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.server_args import get_global_server_args from sglang.srt.layers.rotary_embedding import get_rope @@ -55,9 +50,11 @@ ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.layers.layernorm import RMSNorm, GemmaRMSNorm, RMSNormWithoutScale -from sglang.srt.models.gemma3_causal import Gemma3TextScaledWordEmbedding +from sglang.srt.layers.layernorm import RMSNorm, GemmaRMSNorm, Gemma4RMSNorm +from sglang.srt.models.gemma3_causal import Gemma3TextScaledWordEmbedding, Gemma3MLP + +logger = logging.getLogger(__name__) # Aligned with HF's implementation, using sliding window inclusive with the last 
token # SGLang assumes exclusive @@ -66,7 +63,7 @@ def get_attention_sliding_window_size(config): Gemma4MLP = Gemma3MLP - +Gemma4TextScaledWordEmbedding = Gemma3TextScaledWordEmbedding class Gemma4PerLayerEmbedding(nn.Module): """Per-Layer Embedding (PLE) system for Gemma 4. @@ -526,9 +523,6 @@ def forward( return hidden_states, None -Gemma4TextScaledWordEmbedding = Gemma3TextScaledWordEmbedding - - class Gemma4TextModel(PreTrainedModel): def __init__( self, @@ -571,7 +565,8 @@ def __init__( # self.hidden_size_per_layer_input**0.5, # ) - self.per_layer_model_projection = ColumnParallelLinear( + # FIXME: Use replicated for now. Use ColumnParallel?. + self.per_layer_model_projection = ReplicatedLinear( self.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input, bias=False, @@ -849,12 +844,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict.update(dict(self.named_buffers())) loaded_params: Set[str] = set() for name, loaded_weight in weights: - if "audio" in name or "vision" in name: - continue - - if ".language_model" in name: - name = name.replace(".language_model", "") - + name = name.replace("model.language_model.", "model.") for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue @@ -862,6 +852,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if name not in params_dict: + # Skip loading weights that are not in the model + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -888,13 +881,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_params.add(name) unloaded_params = params_dict.keys() - loaded_params if unloaded_params: - print( + logger.warning( "Some weights are not initialized from checkpoints: %s", unloaded_params ) return loaded_params EntryClass = Gemma4ForCausalLM - - - +AutoModel.register(Gemma4TextConfig, Gemma4ForCausalLM, exist_ok=True) From dea02a2288cf63bf92e64fb64e6cef7578417e1b Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Sat, 7 Mar 2026 07:33:31 +0000 Subject: [PATCH 005/112] Reasoning parser. --- .../srt/entrypoints/openai/serving_chat.py | 11 ++++++-- .../srt/managers/detokenizer_manager.py | 2 +- python/sglang/srt/parser/reasoning_parser.py | 27 +++++++++++++++++-- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 7913af172c3b..428c50993db9 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -119,6 +119,11 @@ def __init__( and hasattr(self.tokenizer_manager.model_config.hf_config, "model_type") and self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss" ) + self.is_gemma4 = ( + hasattr(self.tokenizer_manager.model_config, "hf_config") + and hasattr(self.tokenizer_manager.model_config.hf_config, "model_type") + and self.tokenizer_manager.model_config.hf_config.model_type == "gemma4" + ) self.use_dpsk_v32_encoding = self._use_dpsk_v32_encoding() @@ -320,7 +325,7 @@ def _process_messages( ) -> MessageProcessingResult: """Process chat messages and apply chat template""" # GptOss model needs to keep special 
tokens for harmony parsing - if self.is_gpt_oss: + if self.is_gpt_oss or self.is_gemma4: request.skip_special_tokens = False tool_call_constraint = None @@ -946,6 +951,7 @@ def _build_chat_response( self.template_manager.force_reasoning or self._get_reasoning_from_request(request) ) + print(f"is_force_reasoning: {self.template_manager.force_reasoning}, self._get_reasoning_from_request(request): {self._get_reasoning_from_request(request)}") try: parser = ReasoningParser( model_type=reasoning_parser, @@ -1228,7 +1234,8 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: """Judge whether the request needs reasoning""" if not self.reasoning_parser: return False - if self.reasoning_parser in ["deepseek-v3"]: + # Do we want to think by default? + if self.reasoning_parser in ["deepseek-v3", "gemma4"]: # Models that require explicit enable thinking (thinking=True) return ( request.chat_template_kwargs is not None diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 2c6db4b2f7d6..357f18e37b53 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -339,7 +339,7 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput): incremental_output = output_str[s.sent_offset :] s.sent_offset = len(output_str) output_strs.append(incremental_output) - + # print(output_strs) return output_strs def _extract_routed_experts( diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index 7eef5c0fbf32..afdc57a6e82a 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -37,6 +37,7 @@ def __init__( self._buffer = "" self.stripped_think_start = False + self.think_start_self_label = "" self.continue_final_message = continue_final_message if self.continue_final_message: @@ -62,7 +63,7 @@ def detect_and_parse(self, 
text: str) -> StreamingParseResult: return StreamingParseResult(normal_text=text) # The text is considered to be in a reasoning block. - processed_text = text.replace(self.think_start_token, "").strip() + processed_text = text.replace(self.think_start_token + self.think_start_self_label, "").strip() if ( self.think_end_token not in processed_text @@ -120,9 +121,11 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult: for token in tokens_to_check ): return StreamingParseResult() + + think_start_text = self.think_start_token + self.think_start_self_label # Strip `` token if present - if not self.stripped_think_start and self.think_start_token in current_text: + if not self.stripped_think_start and think_start_text in current_text: current_text = current_text.replace(self.think_start_token, "") self.stripped_think_start = True self._in_reasoning = True @@ -441,6 +444,25 @@ def __init__( previous_content=previous_content, ) +class Gemma4Detector(BaseReasoningFormatDetector): + """Gemma4 reasoning detector.""" + def __init__( + self, + stream_reasoning: bool = True, + force_reasoning: bool = False, + continue_final_message: bool = False, + previous_content: str = "", + ): + super().__init__( + "<|channel>", + "", + force_reasoning=force_reasoning, + stream_reasoning=stream_reasoning, + continue_final_message=continue_final_message, + previous_content=previous_content, + ) + self.think_start_self_label = "thought\n" + class ReasoningParser: """ @@ -468,6 +490,7 @@ class ReasoningParser: "step3p5": DeepSeekR1Detector, "nemotron_3": Nemotron3Detector, "interns1": Qwen3Detector, + "gemma4": Gemma4Detector, } def __init__( From 418ba402d6fce47fc9452105c22f6c932b0a19c1 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Sun, 8 Mar 2026 05:45:07 +0000 Subject: [PATCH 006/112] tool call parser --- .../srt/function_call/function_call_parser.py | 2 + .../srt/function_call/gemma4_detector.py | 405 ++++++++++++++++++ .../test_function_call_parser.py | 93 ++++ 
3 files changed, 500 insertions(+) create mode 100644 python/sglang/srt/function_call/gemma4_detector.py diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index 2f562192e219..efa44e5514ff 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -14,6 +14,7 @@ from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.deepseekv31_detector import DeepSeekV31Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector +from sglang.srt.function_call.gemma4_detector import Gemma4Detector from sglang.srt.function_call.gigachat3_detector import GigaChat3Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector @@ -69,6 +70,7 @@ class FunctionCallParser: "interns1": InternlmDetector, "hermes": HermesDetector, "gigachat3": GigaChat3Detector, + "gemma4": Gemma4Detector, } def __init__(self, tools: List[Tool], tool_call_parser: str): diff --git a/python/sglang/srt/function_call/gemma4_detector.py b/python/sglang/srt/function_call/gemma4_detector.py new file mode 100644 index 000000000000..6cd521d12e73 --- /dev/null +++ b/python/sglang/srt/function_call/gemma4_detector.py @@ -0,0 +1,405 @@ +import json +import logging +import re +from typing import Any, List, Optional + +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + ToolCallItem, + _GetInfoFunc, +) + +logger = logging.getLogger(__name__) + +# Gemma4 special tokens for tool calls +TOOL_CALL_START = "<|tool_call>" +TOOL_CALL_END = "" +STRING_DELIM = '<|"|>' + + +def _parse_gemma4_value(value_str: str) -> object: + """Parse a single Gemma4 value (after 
key:) into a Python object.""" + value_str = value_str.strip() + if not value_str: + return value_str + + # Boolean + if value_str == "true": + return True + if value_str == "false": + return False + + # Number (int or float) + try: + if "." in value_str: + return float(value_str) + return int(value_str) + except ValueError: + pass + + # Bare string (no <|"|> delimiters) + return value_str + + +def _parse_gemma4_array(arr_str: str) -> list: + """Parse a Gemma4 array content string into a Python list.""" + items: list = [] + i = 0 + n = len(arr_str) + + while i < n: + while i < n and arr_str[i] in (" ", ",", "\n", "\t"): + i += 1 + if i >= n: + break + + # String element + if arr_str[i:].startswith(STRING_DELIM): + i += len(STRING_DELIM) + end_pos = arr_str.find(STRING_DELIM, i) + if end_pos == -1: + items.append(arr_str[i:]) + break + items.append(arr_str[i:end_pos]) + i = end_pos + len(STRING_DELIM) + + # Nested object + elif arr_str[i] == "{": + depth = 1 + obj_start = i + 1 + i += 1 + while i < n and depth > 0: + if arr_str[i:].startswith(STRING_DELIM): + i += len(STRING_DELIM) + nd = arr_str.find(STRING_DELIM, i) + i = nd + len(STRING_DELIM) if nd != -1 else n + continue + if arr_str[i] == "{": + depth += 1 + elif arr_str[i] == "}": + depth -= 1 + i += 1 + items.append(_parse_gemma4_args(arr_str[obj_start : i - 1])) + + # Nested array + elif arr_str[i] == "[": + depth = 1 + sub_start = i + 1 + i += 1 + while i < n and depth > 0: + if arr_str[i] == "[": + depth += 1 + elif arr_str[i] == "]": + depth -= 1 + i += 1 + items.append(_parse_gemma4_array(arr_str[sub_start : i - 1])) + + # Bare value + else: + val_start = i + while i < n and arr_str[i] not in (",", "]"): + i += 1 + items.append(_parse_gemma4_value(arr_str[val_start:i])) + + return items + + +def _parse_gemma4_args(args_str: str) -> dict: + """Parse Gemma4's custom key:value format into a Python dict.""" + if not args_str or not args_str.strip(): + return {} + + result: dict = {} + i = 0 + n = 
len(args_str) + + while i < n: + # Skip whitespace and commas + while i < n and args_str[i] in (" ", ",", "\n", "\t"): + i += 1 + if i >= n: + break + + # Parse key (unquoted, ends at ':') + key_start = i + while i < n and args_str[i] != ":": + i += 1 + if i >= n: + break + key = args_str[key_start:i].strip() + i += 1 # skip ':' + + # Parse value + if i >= n: + result[key] = "" + break + + # Skip whitespace after ':' + while i < n and args_str[i] in (" ", "\n", "\t"): + i += 1 + if i >= n: + result[key] = "" + break + + # String value: <|"|>...<|"|> + if args_str[i:].startswith(STRING_DELIM): + i += len(STRING_DELIM) + val_start = i + end_pos = args_str.find(STRING_DELIM, i) + if end_pos == -1: + # Unterminated string — take rest + result[key] = args_str[val_start:] + break + result[key] = args_str[val_start:end_pos] + i = end_pos + len(STRING_DELIM) + + # Nested object: {...} + elif args_str[i] == "{": + depth = 1 + obj_start = i + 1 + i += 1 + while i < n and depth > 0: + if args_str[i:].startswith(STRING_DELIM): + # Skip over string contents + i += len(STRING_DELIM) + next_delim = args_str.find(STRING_DELIM, i) + if next_delim == -1: + i = n + else: + i = next_delim + len(STRING_DELIM) + continue + if args_str[i] == "{": + depth += 1 + elif args_str[i] == "}": + depth -= 1 + i += 1 + result[key] = _parse_gemma4_args(args_str[obj_start : i - 1]) + + # Array: [...] + elif args_str[i] == "[": + depth = 1 + arr_start = i + 1 + i += 1 + while i < n and depth > 0: + if args_str[i:].startswith(STRING_DELIM): + i += len(STRING_DELIM) + next_delim = args_str.find(STRING_DELIM, i) + if next_delim == -1: + i = n + else: + i = next_delim + len(STRING_DELIM) + continue + if args_str[i] == "[": + depth += 1 + elif args_str[i] == "]": + depth -= 1 + i += 1 + arr_content = args_str[arr_start : i - 1] + result[key] = _parse_gemma4_array(arr_content) + + # Bare value (number, boolean, etc.) 
+ else: + val_start = i + while i < n and args_str[i] not in (",", "}", "]"): + i += 1 + result[key] = _parse_gemma4_value(args_str[val_start:i]) + + return result + + +class Gemma4Detector(BaseFormatDetector): + def __init__(self): + super().__init__() + self.tool_call_start_token = TOOL_CALL_START + self.tool_call_end_token = TOOL_CALL_END + self.tool_call_regex = re.compile( + r"<\|tool_call>call:(\w+)\{(.*?)\}", + re.DOTALL, + ) + + # Streaming state + self.parsed_pos: int = 0 + self.is_inside_tool_call: bool = False + self.current_func_name: Optional[str] = None + self.json_started: bool = False + + def has_tool_call(self, text: str) -> bool: + return self.tool_call_start_token in text + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + if self.tool_call_start_token not in text: + return StreamingParseResult(normal_text=text) + + calls = [] + try: + matches = self.tool_call_regex.findall(text) + if not matches: + return StreamingParseResult(normal_text=text) + + tool_indices = self._get_tool_indices(tools) + for func_name, args_str in matches: + arguments = _parse_gemma4_args(args_str) + calls.append( + ToolCallItem( + tool_index=tool_indices.get(func_name, -1), + name=func_name, + parameters=json.dumps(arguments, ensure_ascii=False), + ) + ) + + # Content = text before first tool call + content_end = text.find(self.tool_call_start_token) + normal_text = text[:content_end] if content_end > 0 else "" + + return StreamingParseResult(normal_text=normal_text, calls=calls) + + except Exception as e: + logger.error(f"Error in detect_and_parse: {e}") + return StreamingParseResult(normal_text=text) + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + self._buffer += new_text + + if not self._buffer: + return StreamingParseResult() + + calls = [] + normal_text_chunks = [] + + while True: + current_slice = self._buffer[self.parsed_pos :] + if not current_slice: + break + + if not 
self.is_inside_tool_call: + # Step 4: Outside tool call block + next_start = current_slice.find(self.tool_call_start_token) + if next_start == -1: + # Check for partial match at the end + partial_len = self._ends_with_partial_token(current_slice, self.tool_call_start_token) + if partial_len > 0: + text_to_append = current_slice[:-partial_len] + if text_to_append: + normal_text_chunks.append(text_to_append) + self.parsed_pos += len(text_to_append) + break + else: + normal_text_chunks.append(current_slice) + self.parsed_pos += len(current_slice) + continue + elif next_start == 0: + self.parsed_pos += len(self.tool_call_start_token) + self.is_inside_tool_call = True + continue + else: + normal_text_chunks.append(current_slice[:next_start]) + self.parsed_pos += next_start + continue + else: + # Inside tool call block + + # Check for TOOL_CALL_END first + if current_slice.startswith(self.tool_call_end_token): + self.parsed_pos += len(self.tool_call_end_token) + self.is_inside_tool_call = False + self.current_func_name = None + continue + + if not self.current_func_name: + # Skip leading whitespace + if current_slice[0] in (" ", "\n", "\t"): + self.parsed_pos += 1 + continue + + if current_slice.startswith("call:"): + brace_pos = current_slice.find("{") + if brace_pos != -1: + func_name = current_slice[5:brace_pos] + self.current_tool_id += 1 + self.current_func_name = func_name + self.current_tool_name_sent = True + + tool_indices = self._get_tool_indices(tools) + calls.append( + ToolCallItem( + tool_index=tool_indices.get(func_name, -1), + name=func_name, + parameters="", + ) + ) + self.parsed_pos += brace_pos + 1 + continue + else: + # Incomplete call:name{ + break + else: + # Check for partial matches + if "call:".startswith(current_slice) or self.tool_call_end_token.startswith(current_slice): + break + + # Unexpected content, skip + self.parsed_pos += 1 + continue + else: + # Parsing arguments (looking for balancing }) + depth = 1 + i = 0 + n = len(current_slice) + 
found = False + while i < n: + if current_slice[i : i + len(STRING_DELIM)] == STRING_DELIM: + i += len(STRING_DELIM) + next_delim = current_slice.find(STRING_DELIM, i) + if next_delim == -1: + i = n # Force wait + break + i = next_delim + len(STRING_DELIM) + continue + + if current_slice[i] == "{": + depth += 1 + elif current_slice[i] == "}": + depth -= 1 + if depth == 0: + args_str = current_slice[:i] + arguments = _parse_gemma4_args(args_str) + + tool_indices = self._get_tool_indices(tools) + calls.append( + ToolCallItem( + tool_index=tool_indices.get( + self.current_func_name, -1 + ), + parameters=json.dumps( + arguments, ensure_ascii=False + ), + ) + ) + self.parsed_pos += i + 1 + self.current_func_name = None # Reset for next call: + found = True + break + i += 1 + + if found: + continue + else: + # Incomplete arguments block + break + + if self.parsed_pos > 0: + self._buffer = self._buffer[self.parsed_pos :] + self.parsed_pos = 0 + + normal_text = "".join(normal_text_chunks) if normal_text_chunks else "" + return StreamingParseResult(calls=calls, normal_text=normal_text) + + def supports_structural_tag(self) -> bool: + return False + + def structure_info(self) -> _GetInfoFunc: + raise NotImplementedError diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index 5e8c1928b606..eb15e413f198 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -6,6 +6,7 @@ from sglang.srt.function_call.core_types import StreamingParseResult from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector +from sglang.srt.function_call.gemma4_detector import Gemma4Detector from sglang.srt.function_call.gigachat3_detector import GigaChat3Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from 
sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector @@ -3724,5 +3725,97 @@ def test_streaming_json_split_at_quotes(self): self.assertEqual(params["city"], "Rome") +class TestGemma4Detector(unittest.TestCase): + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + ), + ) + ] + self.detector = Gemma4Detector() + + def test_detect_and_parse(self): + text = 'Some text before <|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, "Some text before ") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "Tokyo") + + def test_parse_streaming_increment(self): + chunks = [ + "Some text ", + "before <|tool", + "_call>call:get_we", + "ather{location:<|", + '"|>Tokyo<|"|>} after" + ] + + all_results = [] + for chunk in chunks: + res = self.detector.parse_streaming_increment(chunk, self.tools) + all_results.append(res) + + combined_normal_text = "".join(r.normal_text for r in all_results) + self.assertEqual(combined_normal_text, "Some text before after") + + found_name = False + found_params = False + for res in all_results: + for call in res.calls: + if call.name == "get_weather": + found_name = True + if call.parameters: + params = json.loads(call.parameters) + if params == {"location": "Tokyo"}: + found_params = True + + self.assertTrue(found_name) + self.assertTrue(found_params) + + def test_nested_array_streaming(self): + # Additional coverage for complex structure + chunks = [ + "<|tool_call>call:get_weather{location:<|\"", + "|>New 
York<|\"|>,nested:[1, 2, {inner:<|\"|>", + "val<|\"|>}]}" + ] + + all_results = [] + for chunk in chunks: + res = self.detector.parse_streaming_increment(chunk, self.tools) + all_results.append(res) + + found_params = False + for res in all_results: + for call in res.calls: + if call.parameters: + params = json.loads(call.parameters) + if "location" in params and params["location"] == "New York": + if "nested" in params and params["nested"] == [1, 2, {"inner": "val"}]: + found_params = True + + self.assertTrue(found_params) + + if __name__ == "__main__": unittest.main() From 3289b265b545e18e7c4fd87123dbef6199b3d67b Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Mon, 9 Mar 2026 00:09:05 +0000 Subject: [PATCH 007/112] mm init --- python/sglang/srt/configs/model_config.py | 10 + python/sglang/srt/models/gemma4_causal.py | 8 +- python/sglang/srt/models/gemma4_mm.py | 485 ++++++++++++++++++++++ 3 files changed, 498 insertions(+), 5 deletions(-) create mode 100644 python/sglang/srt/models/gemma4_mm.py diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index ab2f61bac530..580273a2ac8c 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -360,6 +360,7 @@ def _derive_hybrid_model(self): "MiMoV2FlashForCausalLM", "MiMoV2MTP", "Gemma4ForCausalLM", + "Gemma4ForConditionalGeneration", ] def _derive_context_length(self, context_length: int): @@ -1409,6 +1410,7 @@ def is_hybrid_swa_model(model_architectures: List[str]): "Step3p5ForCausalLM", "Step3p5MTP", "Gemma4ForCausalLM", + "Gemma4ForConditionalGeneration", } return any(arch in hybrid_swa_archs for arch in model_architectures) @@ -1467,6 +1469,14 @@ def get_hybrid_layer_ids( full_attention_layer_ids = [ i for i, x in enumerate(layer_types) if x == "full_attention" ] + elif "Gemma4ForConditionalGeneration" in model_architectures: + layer_types = getattr(hf_text_config, "layer_types", None) + swa_attention_layer_ids = [ + i 
for i, x in enumerate(layer_types) if x == "sliding_attention" + ] + full_attention_layer_ids = [ + i for i, x in enumerate(layer_types) if x == "full_attention" + ] else: swa_attention_layer_ids = None full_attention_layer_ids = None diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index b8cc8bc4d448..5ca8acee4c01 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -315,9 +315,7 @@ def __init__( layer_id=( self.kv_shared_layer_index if self.is_kv_shared_layer else self.layer_id ), - logit_cap=getattr( - config, "attn_logit_softcapping", 0.0 - ), + logit_cap=0.0, sliding_window_size=self.sliding_window, quant_config=quant_config, prefix=add_prefix("attn", prefix), @@ -711,8 +709,8 @@ def forward( if input_ids is not None: input_embeds = self.embed_tokens(input_ids) - per_layer_embeds = self.get_per_layer_inputs(input_ids) - per_layer_inputs = self.project_per_layer_inputs(input_embeds, per_layer_embeds) + per_layer_inputs = self.get_per_layer_inputs(input_ids) + per_layer_inputs = self.project_per_layer_inputs(input_embeds, per_layer_inputs) hidden_states = input_embeds diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py new file mode 100644 index 000000000000..f34f501edadd --- /dev/null +++ b/python/sglang/srt/models/gemma4_mm.py @@ -0,0 +1,485 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +import logging +import re +from functools import lru_cache +from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union + +import torch +from torch import nn +from transformers import ( + Gemma4AudioConfig, + Gemma4Config, + Gemma4TextConfig, + Gemma4VisionConfig, + PreTrainedModel, +) +from transformers.models.auto.modeling_auto import AutoModel + +from sglang.srt.layers.layernorm import Gemma4RMSNorm +from sglang.srt.layers.linear import RowParallelLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, + flatten_nested_list, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.models.gemma4_causal import Gemma4TextModel +from sglang.srt.utils import add_prefix +from sglang.srt.utils.hf_transformers_utils import get_processor + +logger = logging.getLogger(__name__) + +cached_get_processor = lru_cache(get_processor) + +class Gemma4ImagePixelInputs(TypedDict): + pixel_values: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + +class Gemma4AudioInputs(TypedDict): + input_features_padded: torch.Tensor + """Shape: `(batch_size * num_audio, seq_length, num_features)`""" + input_features_mask: torch.Tensor + """Shape: `(batch_size * num_audio, seq_length)`""" + +class Gemma4MultimodalEmbedder(nn.Module): + """Projects vision/audio soft tokens into LM embedding space.""" + + def __init__( + self, + multimodal_config: Union[Gemma4AudioConfig, Gemma4VisionConfig], + text_config: 
Gemma4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.eps = multimodal_config.rms_norm_eps + self.text_hidden_size = text_config.hidden_size + + # Audio tower uses output_proj_dims (1536) rather than hidden_size + # (1024); vision uses hidden_size (768) directly. + embedding_dim = ( + getattr(multimodal_config, "output_proj_dims", None) + or multimodal_config.hidden_size + ) + + self.embedding_projection = RowParallelLinear( + embedding_dim, + self.text_hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("embedding_projection", prefix), + ) + + self.embedding_post_projection_norm = Gemma4RMSNorm( + self.text_hidden_size, + eps=self.eps, + with_scale=False, + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + """Project soft tokens from a multimodal tower into LM space.""" + embs_proj, _ = self.embedding_projection(inputs_embeds) + return self.embedding_post_projection_norm(embs_proj) + +class Gemma4ForConditionalGeneration(PreTrainedModel): + config_class = Gemma4Config + """Gemma4 multimodal model for conditional generation.""" + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ".out_proj.", + ] + bitsandbytes_stacked_params_mapping = { + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + "out_proj": ("proj", 0), + } + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + # Gemma does not apply LoRA to the embedding layer + embedding_modules = {} + embedding_padding_modules = [] + 
supports_lora = True + + def __init__( + self, + config: Gemma4Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config) + self.config = config + self.quant_config = quant_config + + prefix = add_prefix("model", prefix) + + # Vision components + self.vision_tower = AutoModel.from_config(config=config.vision_config) + + self.embed_vision = Gemma4MultimodalEmbedder( + config.vision_config, + config.text_config, + quant_config=quant_config, + prefix=add_prefix("embed_vision", prefix), + ) + + # Audio components + if getattr(config, "audio_config", None) is not None: + self.audio_tower = AutoModel.from_config(config=config.audio_config) + self.audio_tower.post_init() + self.embed_audio = Gemma4MultimodalEmbedder( + config.audio_config, + config.text_config, + quant_config=quant_config, + prefix=add_prefix("embed_audio", prefix), + ) + else: + self.audio_tower = None + self.embed_audio = None + + self.vocab_size = config.text_config.vocab_size + self.vocab_size_per_layer_input = getattr(config.text_config, "vocab_size_per_layer_input", config.text_config.vocab_size) + + # Text model + self.language_model = Gemma4TextModel( + config.text_config, + quant_config, + prefix=add_prefix("language_model", prefix), + ) + + # Create logits processor for the multimodal model + self.logits_processor = LogitsProcessor(config.text_config) + + self.post_init() + + def pad_input_ids( + self, + input_ids: List[int], + mm_inputs: MultimodalInputs, + ) -> List[int]: + """Pad input IDs with image and audio tokens.""" + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + return pattern.pad_input_tokens(input_ids, mm_inputs) + + def get_input_embeddings(self) -> nn.Embedding: + return self.language_model.get_input_embeddings() + + def get_attention_sliding_window_size(self): + return getattr(self.config.text_config, "sliding_window", -1) - 1 + + def get_image_feature(self, items: List[MultimodalDataItem]) -> 
torch.Tensor: + all_pixel_values = flatten_nested_list([item.feature for item in items]) + vt = self.vision_tower + + all_embeds = [] + for pv in all_pixel_values: + if pv.dim() == 5: + pv = pv.squeeze(0) + if pv.dim() == 3: + pv = pv.unsqueeze(0) + elif pv.dim() != 4: + raise ValueError(f"Unexpected pixel_values shape: {pv.shape}") + + pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) + + # Step 1: Patchify, pad to max_patches (2520), build positions + patch_positions, padding_positions = vt._patch_positions(pv) + inputs_embeds = vt.patch_embedder( + pv, + patch_positions[:, :vt._num_real_patches(pv)], + padding_positions[:, :vt._num_real_patches(pv)], + ) + num_real = inputs_embeds.shape[1] + num_padding = vt.max_patches - num_real + if num_padding > 0: + pad_embeds = torch.zeros( + inputs_embeds.shape[0], num_padding, inputs_embeds.shape[2], + device=inputs_embeds.device, dtype=inputs_embeds.dtype, + ) + inputs_embeds = torch.cat([inputs_embeds, pad_embeds], dim=1) + + # Step 2: Encode + model_output = vt.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~padding_positions, + patch_positions=patch_positions, + ) + + # Step 3: Pool to default_output_length (280) tokens + pooler_output = vt.pooler( + hidden_states=model_output.last_hidden_state, + patch_positions=patch_positions, + padding_positions=padding_positions, + ) + hidden_states, pooler_mask = pooler_output[0] + + # Step 4: Strip padding per-image and embed + for hs, mask in zip(hidden_states, pooler_mask): + real_tokens = hs[mask] + all_embeds.append( + self.embed_vision(inputs_embeds=real_tokens.unsqueeze(0)).squeeze(0) + ) + + if all_embeds: + return torch.cat(all_embeds, dim=0) + else: + return torch.empty( + 0, self.language_model.config.hidden_size, + device=next(self.parameters()).device, + dtype=self.language_model.dtype() + ) + + def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + if self.audio_tower is None: + raise ValueError("Audio inputs 
provided but the model does not have an audio tower.") + + all_input_features = flatten_nested_list([item.feature for item in items]) + all_input_features_mask = flatten_nested_list([~item.input_features_mask for item in items]) + + all_embeds = [] + for input_features, input_features_mask in zip(all_input_features, all_input_features_mask): + if input_features.dim() == 2: + input_features = input_features.unsqueeze(0) + if input_features_mask.dim() == 1: + input_features_mask = input_features_mask.unsqueeze(0) + + input_features = input_features.to( + device=next(self.audio_tower.parameters()).device, + dtype=self.language_model.dtype(), + ) + input_features_mask = input_features_mask.to(device=input_features.device) + + # Run audio tower (mask True=padding) + audio_outputs = self.audio_tower(input_features, input_features_mask) + if isinstance(audio_outputs, tuple): + audio_encodings, audio_mask = audio_outputs + else: + audio_encodings = audio_outputs.last_hidden_state + audio_mask = audio_outputs.audio_mel_mask + + audio_features = self.embed_audio(inputs_embeds=audio_encodings) + + # Strip padding + for enc, mask in zip(audio_features, audio_mask): + all_embeds.append(enc[~mask]) + + if all_embeds: + return torch.cat(all_embeds, dim=0) + else: + return torch.empty( + 0, self.language_model.config.hidden_size, + device=next(self.parameters()).device, + dtype=self.language_model.dtype() + ) + + def get_per_layer_inputs( + self, input_ids: torch.LongTensor + ) -> Optional[torch.Tensor]: + return self.language_model.get_per_layer_inputs(input_ids) + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.language_model.project_per_layer_inputs( + inputs_embeds, per_layer_inputs + ) + + @torch.no_grad() + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + **kwargs: 
object, + ) -> LogitsProcessor: + """Forward pass for multimodal Gemma4.""" + if (input_ids is None) ^ (input_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + positions += 1 + if input_ids is not None: + per_layer_inputs = self.get_per_layer_inputs(input_ids) + + # Use general_mm_embed_routine for handling multimodal data + hidden_states = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.language_model, + data_embedding_funcs={ + Modality.IMAGE: self.get_image_feature, + Modality.AUDIO: self.get_audio_feature, + }, + positions=positions, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + + # Process hidden states through logits processor + return self.logits_processor( + input_ids, hidden_states, self.language_model.embed_tokens, forward_batch + ) + + def tie_weights(self, recompute_mapping=False): + return self.language_model.tie_weights() + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".up_proj", 1), + (".gate_up_proj", ".gate_proj", 0), + ] + """Load weights for the model.""" + params_dict = dict(self.named_parameters()) + params_dict.update(dict(self.named_buffers())) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + # Vestigial weights to ignore + if "embed_vision.embedding." in name or "embed_audio.embedding." in name: + continue + if self.audio_tower is None and ("audio_tower." in name or "embed_audio." 
in name): + continue + + name = re.sub(r"^model\.", "", name) + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "vision_model" in name: + # adapt to VisionAttention + name = name.replace(".self_attn.out_proj", ".self_attn.proj") + # Skip loading extra bias for GPTQ models + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + logger.warning( + "Some weights are not initialized from checkpoints: %s", unloaded_params + ) + return loaded_params + + lora_pattern = re.compile( + r"^language_model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)" + ) + + def should_apply_lora(self, module_name: str) -> bool: + return bool(self.lora_pattern.match(module_name)) + + def get_hidden_dim(self, module_name, layer_idx): + # return input_dim, output_dim + if module_name == "qkv_proj": + return ( + self.config.hidden_size, + self.config.head_dim + * ( + self.config.num_attention_heads + + self.config.num_key_value_heads * 2 + ), + ) + elif module_name == "o_proj": + return ( + self.config.head_dim * self.config.num_attention_heads, + self.config.hidden_size, + ) + elif module_name == "gate_up_proj": + assert len(set(self.config.intermediate_size)) == 1, 
( + "Currently SGLang requires uniform intermediate size for all layers. " + "Please file an issue if you need support for non-uniform intermediate sizes." + ) + return self.config.hidden_size, self.config.intermediate_size[0] * 2 + elif module_name == "down_proj": + assert len(set(self.config.intermediate_size)) == 1, ( + "Currently SGLang requires uniform intermediate size for all layers. " + "Please file an issue if you need support for non-uniform intermediate sizes." + ) + return self.config.intermediate_size[0], self.config.hidden_size + else: + raise NotImplementedError() + + +EntryClass = Gemma4ForConditionalGeneration From 67b7b29391cbff64bde48e521a07693a6aa89bcf Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Mon, 9 Mar 2026 18:05:02 +0000 Subject: [PATCH 008/112] config conversion global_head_dim <-> swa_head_dim --- python/sglang/srt/utils/hf_transformers_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 1b9a709fc7a6..5a584c5b09f4 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -397,6 +397,18 @@ def get_config( if config.model_type == "multi_modality": config.update({"architectures": ["MultiModalityCausalLM"]}) + if config.model_type == "gemma4": + global_head_dim = getattr(config.text_config, "global_head_dim", None) + num_global_key_value_heads = getattr(config.text_config, "num_global_key_value_heads", None) + + if global_head_dim is not None: + config.text_config.swa_head_dim = config.text_config.head_dim + config.text_config.head_dim = global_head_dim + + config.text_config.swa_num_key_value_heads = config.num_key_value_heads + if num_global_key_value_heads is not None: + config.text_config.num_key_value_heads = num_global_key_value_heads + if config.model_type == "longcat_flash": config.update({"architectures": ["LongcatFlashForCausalLM"]}) From 
9a87e88a618e78627c2f4d84635fa4d1bea004b3 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Mon, 9 Mar 2026 18:23:06 +0000 Subject: [PATCH 009/112] more mm --- python/sglang/srt/configs/model_config.py | 1 + .../srt/multimodal/processors/gemma4.py | 70 +++++++++++++++++++ python/sglang/srt/parser/conversation.py | 1 + 3 files changed, 72 insertions(+) create mode 100644 python/sglang/srt/multimodal/processors/gemma4.py diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 580273a2ac8c..3c8816b03534 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -1274,6 +1274,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal "Ernie4_5_VLMoeForConditionalGeneration", "Gemma3ForConditionalGeneration", "Gemma3nForConditionalGeneration", + "Gemma4ForConditionalGeneration", "Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration", diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py new file mode 100644 index 000000000000..4618b1c0559f --- /dev/null +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -0,0 +1,70 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from typing import Dict, List, Optional, Union + +from sglang.srt.managers.multimodal_processor import ( + BaseMultimodalProcessor as SGLangBaseProcessor, +) +from sglang.srt.models.gemma4_mm import Gemma4ForConditionalGeneration +from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens + + +class Gemma4SGLangProcessor(SGLangBaseProcessor): + """Multimodal processor for Gemma4 supporting image and audio inputs.""" + + models = [Gemma4ForConditionalGeneration] + + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) + + self.IM_START_TOKEN_ID = hf_config.boi_token_id + self.IM_END_TOKEN_ID = hf_config.eoi_token_id + + self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id + self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id + self.mm_tokens = MultimodalSpecialTokens( + image_token="", + image_token_id=hf_config.image_token_id, + audio_token="", + audio_token_id=hf_config.audio_token_id, + ).build(_processor) + + async def process_mm_data_async( + self, + image_data: Optional[List[Union[str, bytes, Dict]]] = None, + audio_data: Optional[List[Union[str, bytes, Dict]]] = None, + input_text: str = "", + request_obj=None, + *args, + **kwargs, + ): + """Process multimodal data including images and audio.""" + base_output = self.load_mm_data( + prompt=input_text, + image_data=image_data, + audio_data=audio_data, + multimodal_tokens=self.mm_tokens, + ) + + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) + + return { + "input_ids": input_ids.tolist(), + "mm_items": mm_items, + "im_token_id": self.mm_tokens.image_token_id, + "audio_token_id": self.mm_tokens.audio_token_id, + } diff --git a/python/sglang/srt/parser/conversation.py b/python/sglang/srt/parser/conversation.py index 954cb168ba34..092b1bbd93cd 100644 --- 
a/python/sglang/srt/parser/conversation.py +++ b/python/sglang/srt/parser/conversation.py @@ -65,6 +65,7 @@ class SeparatorStyle(IntEnum): QWEN2_VL_EMBED = auto() QWEN2_AUDIO = auto() GEMMA3 = auto() + GEMMA4 = auto() MPT = auto() PADDLE_OCR = auto() From 416eccbb75b2e5e37455bbab9e0f9ece05a10bd4 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Tue, 10 Mar 2026 03:36:16 +0000 Subject: [PATCH 010/112] re-add gemma4 rope. (was removed as part of rebase) --- python/sglang/srt/entrypoints/openai/serving_chat.py | 1 - python/sglang/srt/layers/rotary_embedding/factory.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 428c50993db9..ee8eae9cd0cf 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -951,7 +951,6 @@ def _build_chat_response( self.template_manager.force_reasoning or self._get_reasoning_from_request(request) ) - print(f"is_force_reasoning: {self.template_manager.force_reasoning}, self._get_reasoning_from_request(request): {self._get_reasoning_from_request(request)}") try: parser = ReasoningParser( model_type=reasoning_parser, diff --git a/python/sglang/srt/layers/rotary_embedding/factory.py b/python/sglang/srt/layers/rotary_embedding/factory.py index 9e24a2e5ed60..53ca927b3594 100644 --- a/python/sglang/srt/layers/rotary_embedding/factory.py +++ b/python/sglang/srt/layers/rotary_embedding/factory.py @@ -22,6 +22,7 @@ FourierRotaryEmbedding, Llama3RotaryEmbedding, Phi3LongRoPEScaledRotaryEmbedding, + Gemma4RotaryEmbedding, ) from sglang.srt.layers.rotary_embedding.yarn import YaRNScalingRotaryEmbedding from sglang.srt.utils import get_bool_env_var, is_hip @@ -275,6 +276,15 @@ def get_rope( long_factor, **extra_kwargs, ) + elif scaling_type == "proportional": + rotary_emb = Gemma4RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, 
+ is_neox_style, + dtype, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb From 1614769cf56b3e8bef51c39f1273fc8dfa6020f7 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Wed, 18 Mar 2026 12:57:12 -0700 Subject: [PATCH 011/112] lint on main --- .../srt/function_call/gemma4_detector.py | 18 ++- .../srt/layers/attention/triton_backend.py | 11 +- python/sglang/srt/layers/layernorm.py | 3 +- .../srt/layers/rotary_embedding/factory.py | 2 +- .../layers/rotary_embedding/rope_variant.py | 22 ++-- python/sglang/srt/mem_cache/memory_pool.py | 19 ++- .../sglang/srt/mem_cache/swa_memory_pool.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/models/gemma4_causal.py | 122 +++++++++--------- python/sglang/srt/models/gemma4_mm.py | 69 ++++++---- python/sglang/srt/parser/reasoning_parser.py | 8 +- .../sglang/srt/utils/hf_transformers_utils.py | 8 +- scripts/playground/reference_hf.py | 4 +- .../test_function_call_parser.py | 30 +++-- 14 files changed, 182 insertions(+), 138 deletions(-) diff --git a/python/sglang/srt/function_call/gemma4_detector.py b/python/sglang/srt/function_call/gemma4_detector.py index 6cd521d12e73..abf890e7987e 100644 --- a/python/sglang/srt/function_call/gemma4_detector.py +++ b/python/sglang/srt/function_call/gemma4_detector.py @@ -1,7 +1,7 @@ import json import logging import re -from typing import Any, List, Optional +from typing import List, Optional from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector @@ -280,7 +280,9 @@ def parse_streaming_increment( next_start = current_slice.find(self.tool_call_start_token) if next_start == -1: # Check for partial match at the end - partial_len = self._ends_with_partial_token(current_slice, self.tool_call_start_token) + partial_len = self._ends_with_partial_token( + current_slice, self.tool_call_start_token + ) if partial_len > 0: text_to_append = 
current_slice[:-partial_len] if text_to_append: @@ -301,14 +303,14 @@ def parse_streaming_increment( continue else: # Inside tool call block - + # Check for TOOL_CALL_END first if current_slice.startswith(self.tool_call_end_token): self.parsed_pos += len(self.tool_call_end_token) self.is_inside_tool_call = False self.current_func_name = None continue - + if not self.current_func_name: # Skip leading whitespace if current_slice[0] in (" ", "\n", "\t"): @@ -338,9 +340,11 @@ def parse_streaming_increment( break else: # Check for partial matches - if "call:".startswith(current_slice) or self.tool_call_end_token.startswith(current_slice): + if "call:".startswith( + current_slice + ) or self.tool_call_end_token.startswith(current_slice): break - + # Unexpected content, skip self.parsed_pos += 1 continue @@ -355,7 +359,7 @@ def parse_streaming_increment( i += len(STRING_DELIM) next_delim = current_slice.find(STRING_DELIM, i) if next_delim == -1: - i = n # Force wait + i = n # Force wait break i = next_delim + len(STRING_DELIM) continue diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index a905e970f9bc..30614f2867cf 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -823,8 +823,8 @@ def forward_extend( k, v = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) # FIXME: hacky way to make kv cache aligned # why??? 
- k = k[1:q.shape[0]+1] - v = v[1:q.shape[0]+1] + k = k[1 : q.shape[0] + 1] + v = v[1 : q.shape[0] + 1] # print(layer.layer_id, k.cpu(), v.cpu()) elif k is None or v is None: raise ValueError("Both k and v should be None or not None") @@ -834,8 +834,8 @@ def forward_extend( if ( self.use_mla or layer.k_scale is None ): # Triton MLA currently doesn't support quantized kv cache - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v, @@ -848,8 +848,7 @@ def forward_extend( v.clone(), layer.k_scale, layer.v_scale, - ) - + ) logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 2f0016c53437..b7ca8f7b334f 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -596,7 +596,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return normed_output.type_as(x) - class RMSNormWithoutScale(MultiPlatformOp): def __init__(self, hidden_size: int, eps=1e-6): super().__init__() @@ -617,4 +616,4 @@ def forward_cuda(self, x): return self.forward_native(x) def extra_repr(self): - return f"{self.hidden_size}, eps={self.eps}" \ No newline at end of file + return f"{self.hidden_size}, eps={self.eps}" diff --git a/python/sglang/srt/layers/rotary_embedding/factory.py b/python/sglang/srt/layers/rotary_embedding/factory.py index 53ca927b3594..2bbbac6f1882 100644 --- a/python/sglang/srt/layers/rotary_embedding/factory.py +++ b/python/sglang/srt/layers/rotary_embedding/factory.py @@ -20,9 +20,9 @@ DynamicNTKAlphaRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, FourierRotaryEmbedding, + Gemma4RotaryEmbedding, Llama3RotaryEmbedding, Phi3LongRoPEScaledRotaryEmbedding, - Gemma4RotaryEmbedding, ) from sglang.srt.layers.rotary_embedding.yarn import YaRNScalingRotaryEmbedding from sglang.srt.utils import get_bool_env_var, is_hip diff --git 
a/python/sglang/srt/layers/rotary_embedding/rope_variant.py b/python/sglang/srt/layers/rotary_embedding/rope_variant.py index 9dd539f40137..2fe9d5da280d 100644 --- a/python/sglang/srt/layers/rotary_embedding/rope_variant.py +++ b/python/sglang/srt/layers/rotary_embedding/rope_variant.py @@ -868,13 +868,12 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: return cache - class Gemma4RotaryEmbedding(RotaryEmbedding): """Gemma4-specific RoPE with cross-mixing. - + Instead of rotating the first `rotary_dim` dimensions contiguously, splits the head into two halves and applies rotation across both. - + For a head_dim of D and rotary_dim of R: - Standard RoPE rotates: [0, R) - Gemma4 RoPE rotates: [0, R/2) cross-mixed with [D/2, D/2 + R/2) @@ -906,21 +905,22 @@ def __init__( def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute frequencies only for the rotated dimensions. - + Non-rotated dims are padded with 0.0 to produce identity rotation. """ freq_exponents = ( - torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float) - / self.head_size + torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float) / self.head_size ) - inv_freq = 1.0 / (base ** freq_exponents) + inv_freq = 1.0 / (base**freq_exponents) # Zero-pad for non-rotated dims (identity rotation: cos=1, sin=0) if self.nope_angles > 0: - inv_freq = torch.cat([ - inv_freq, - torch.zeros(self.nope_angles, dtype=torch.float), - ]) + inv_freq = torch.cat( + [ + inv_freq, + torch.zeros(self.nope_angles, dtype=torch.float), + ] + ) return inv_freq def extra_repr(self) -> str: diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 5856da31ea2c..9943e4715b05 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -754,7 +754,20 @@ def __init__( ) self.head_num = swa_head_num if swa_head_num is not None else head_num self.head_dim = swa_head_dim if swa_head_dim is not None else head_dim - 
print("head_num: ", self.head_num, "head_dim: ", self.head_dim, "swa_head_num: ", swa_head_num, "swa_head_dim: ", swa_head_dim, "head_num: ", head_num, "head_dim: ", head_dim) + print( + "head_num: ", + self.head_num, + "head_dim: ", + self.head_dim, + "swa_head_num: ", + swa_head_num, + "swa_head_dim: ", + swa_head_dim, + "head_num: ", + head_num, + "head_dim: ", + head_dim, + ) self.v_head_dim = ( swa_v_head_dim if swa_v_head_dim is not None @@ -833,7 +846,9 @@ def _create_buffers(self): if self.enable_custom_mem_pool else nullcontext() ): - print(f"Allocating KV cache buffers with size {self.size}, page_size {self.page_size}, head_num {self.head_num}, head_dim {self.head_dim}, v_head_dim {self.v_head_dim}, dtype {self.store_dtype}, device {self.device}") + print( + f"Allocating KV cache buffers with size {self.size}, page_size {self.page_size}, head_num {self.head_num}, head_dim {self.head_dim}, v_head_dim {self.v_head_dim}, dtype {self.store_dtype}, device {self.device}" + ) # [size, head_num, head_dim] for each layer # The padded slot 0 is used for writing dummy outputs from padded tokens. 
# adjust for global diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index a02738f4b9cb..06b0cb01fc97 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -158,7 +158,7 @@ def set_kv_buffer( layer_id = layer.layer_id layer_id_pool, is_swa_layer = self.layers_mapping[layer_id] - + if is_swa_layer: if self.swa_loc is not None: loc = self.swa_loc diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 20c8766cd53f..e35af50a83fd 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1037,7 +1037,7 @@ def load_model(self): logger.info( f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}" ) - + # Note(pyc): gemma4 has different swa def self.dtype = self.model_config.dtype diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 5ca8acee4c01..1219e76bb0d9 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -16,7 +16,6 @@ from typing import Iterable, Optional, Set, Tuple import torch -import torch.nn.functional as F from torch import nn from transformers import ( AutoModel, @@ -26,36 +25,34 @@ ) from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.layernorm import Gemma4RMSNorm, GemmaRMSNorm, RMSNorm from sglang.srt.layers.linear import ( - ColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) -from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention 
import RadixAttention -from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.utils import add_prefix, make_layers -from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.models.gemma3_causal import Gemma3MLP, Gemma3TextScaledWordEmbedding from sglang.srt.server_args import get_global_server_args -from sglang.srt.layers.rotary_embedding import get_rope -from sglang.srt.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from sglang.srt.layers.layernorm import RMSNorm, GemmaRMSNorm, Gemma4RMSNorm -from sglang.srt.models.gemma3_causal import Gemma3TextScaledWordEmbedding, Gemma3MLP - +from sglang.srt.utils import add_prefix, make_layers logger = logging.getLogger(__name__) + # Aligned with HF's implementation, using sliding window inclusive with the last token # SGLang assumes exclusive def get_attention_sliding_window_size(config): @@ -65,13 +62,14 @@ def get_attention_sliding_window_size(config): Gemma4MLP = Gemma3MLP Gemma4TextScaledWordEmbedding = Gemma3TextScaledWordEmbedding + class Gemma4PerLayerEmbedding(nn.Module): """Per-Layer Embedding (PLE) system for Gemma 4. - + Gemma 4 uses a secondary embedding stream that provides layer-specific token embeddings. These are combined with the main hidden states via a gating mechanism in each decoder layer. 
- + The PLE embedding stores embeddings for all layers packed together: (vocab_size, hidden_size_per_layer_input * num_hidden_layers) """ @@ -91,7 +89,7 @@ def __init__( self.hidden_size_per_layer = hidden_size_per_layer_input self.hidden_size = hidden_size self.num_layers = num_hidden_layers - + # Packed embedding: (vocab_size, hidden_size_per_layer * num_layers) # We store embeddings for ALL layers together total_embed_dim = hidden_size_per_layer_input * num_hidden_layers @@ -101,7 +99,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.embed_tokens_per_layer", ) - + # Projection from PLE space to hidden space # (hidden_size_per_layer * num_layers, hidden_size) self.per_layer_model_projection = nn.Linear( @@ -109,26 +107,26 @@ def __init__( hidden_size, bias=False, ) - + # Normalization for PLE output # JAX uses scale_plus_one=False for this norm (x * scale, not x * (1+scale)) self.per_layer_projection_norm = RMSNorm( self.hidden_size_per_layer, eps=rms_norm_eps, ) - + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: """Compute per-layer embeddings and project to hidden size. 
- + Args: input_ids: Token IDs (batch_size, seq_len) - + Returns: Per-layer input tensor (batch_size, seq_len, hidden_size) """ # Get packed per-layer embeddings per_layer_embeds = self.embed_tokens_per_layer(input_ids) - + # Apply normalization (reshape to apply per-layer, then reshape back) # Original shape: (batch, seq, hidden_size_per_layer * num_layers) batch_size, seq_len, _ = per_layer_embeds.shape @@ -136,10 +134,8 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: batch_size, seq_len, self.num_layers, self.hidden_size_per_layer ) per_layer_embeds = self.per_layer_projection_norm(per_layer_embeds) - per_layer_embeds = per_layer_embeds.view( - batch_size, seq_len, -1 - ) - + per_layer_embeds = per_layer_embeds.view(batch_size, seq_len, -1) + # Project to hidden size per_layer_input = self.per_layer_model_projection(per_layer_embeds) return per_layer_input @@ -202,7 +198,7 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - + self.layer_id = layer_id self.config = config tp_size = get_tensor_model_parallel_world_size() @@ -251,10 +247,7 @@ def __init__( eps=config.rms_norm_eps, ) self.v_norm = Gemma4RMSNorm( - self.head_dim, - eps=config.rms_norm_eps, - scale_shift=0.0, - with_scale=False + self.head_dim, eps=config.rms_norm_eps, scale_shift=0.0, with_scale=False ) # Determine if layer uses sliding window based on pattern @@ -270,7 +263,7 @@ def __init__( # JAX reference uses global_rope_proportion=0.25 for global attention if layer_type == "full_attention": global_prf = getattr(config, "global_partial_rotary_factor", 0.25) - rope_parameters["partial_rotary_factor"] = global_prf + rope_parameters["partial_rotary_factor"] = global_prf else: # Fallback for older config format rope_parameters = dict( @@ -278,7 +271,6 @@ def __init__( rope_theta=getattr(config, "rope_theta", 10000.0), ) - # Check if this is a KV shared layer first_kv_shared_layer_idx = ( config.num_hidden_layers - config.num_kv_shared_layers @@ -294,8 +286,8 @@ def 
__init__( # Find the last non-shared layer of the same type (sliding/full) prev_layers = config.layer_types[:first_kv_shared_layer_idx] current_layer_type = config.layer_types[self.layer_id] - self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index( - current_layer_type + self.kv_shared_layer_index = ( + len(prev_layers) - 1 - prev_layers[::-1].index(current_layer_type) ) self.rotary_emb = get_rope( self.head_dim, @@ -347,7 +339,7 @@ def forward( v = v.unflatten(-1, (self.num_kv_heads, self.head_dim)) v = self.v_norm(v) - + # Apply rotary embedding if k is not None: k = k.flatten(-2, -1) @@ -361,12 +353,17 @@ def forward( q, _ = self.rotary_emb(positions, q, dummy_k) q = q.unflatten(-1, (self.num_heads, self.head_dim)) - attn_output = self.attn(q, k, v, forward_batch=forward_batch, - save_kv_cache=not self.is_kv_shared_layer) + attn_output = self.attn( + q, + k, + v, + forward_batch=forward_batch, + save_kv_cache=not self.is_kv_shared_layer, + ) if attn_output.dim() == 3: attn_output = attn_output.flatten(-2, -1) output, _ = self.o_proj(attn_output) - + return output @@ -390,7 +387,7 @@ def __init__( layer_type = config.layer_types[layer_id] self.is_full_attention = layer_type == "full_attention" if self.is_full_attention: - head_dim = config.head_dim # following sglang naming + head_dim = config.head_dim # following sglang naming else: head_dim = getattr(config, "swa_head_dim", config.head_dim) @@ -402,7 +399,7 @@ def __init__( quant_config=quant_config, prefix=add_prefix("self_attn", prefix), ) - + first_kv_shared_layer_idx = config.num_hidden_layers - getattr( config, "num_kv_shared_layers", 0 ) @@ -422,9 +419,7 @@ def __init__( prefix=add_prefix("mlp", prefix), ) - self.input_layernorm = GemmaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GemmaRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) @@ -451,7 +446,7 @@ 
def __init__( self.hidden_size, bias=False, quant_config=quant_config, - prefix=add_prefix("per_layer_projection", prefix) + prefix=add_prefix("per_layer_projection", prefix), ) self.post_per_layer_input_norm = Gemma4RMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -480,7 +475,7 @@ def forward( # 1. input_norm(x) -> attn -> post_attn_norm -> ADD residual # 2. pre_ff_norm -> mlp -> post_ff_norm -> ADD residual residual = hidden_states - + # Apply input layernorm hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( @@ -514,9 +509,9 @@ def forward( per_layer_contribution ) hidden_states = hidden_states + per_layer_contribution - + # Apply layer scalar for full-attention layers - if self.is_full_attention and hasattr(self, 'layer_scalar'): + if self.is_full_attention and hasattr(self, "layer_scalar"): hidden_states = hidden_states * self.layer_scalar return hidden_states, None @@ -538,9 +533,9 @@ def __init__( config.vocab_size, config.hidden_size, self.padding_idx, - embed_scale=self.config.hidden_size**0.5, # embeded normalizer + embed_scale=self.config.hidden_size**0.5, # embeded normalizer ) - + # Per-layer input embeddings self.hidden_size = config.hidden_size self.hidden_size_per_layer_input = getattr( @@ -562,7 +557,7 @@ def __init__( # self.embed_scale_per_layer = torch.tensor( # self.hidden_size_per_layer_input**0.5, # ) - + # FIXME: Use replicated for now. Use ColumnParallel?. 
self.per_layer_model_projection = ReplicatedLinear( self.hidden_size, @@ -599,7 +594,7 @@ def __init__( ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + # self.per_layer_projection_scale = torch.tensor( # config.hidden_size**-0.5, # ) @@ -669,9 +664,7 @@ def project_per_layer_inputs( per_layer_projection, _ = self.per_layer_model_projection(inputs_embeds) # Apply w_scale (HF: Gemma4ScaledLinear with w_scale=hidden_size^{-0.5}) - per_layer_projection = ( - per_layer_projection * self.per_layer_projection_scale - ) + per_layer_projection = per_layer_projection * self.per_layer_projection_scale # Reshape to (num_tokens, num_layers, hidden_size_per_layer_input) per_layer_projection = per_layer_projection.reshape( @@ -681,17 +674,13 @@ def project_per_layer_inputs( ) # Normalize - per_layer_projection = self.per_layer_projection_norm( - per_layer_projection - ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) if per_layer_inputs is None: return per_layer_projection # Combine: (projection + per_layer_inputs) * scale - return ( - per_layer_projection + per_layer_inputs - ) * self.per_layer_input_scale + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale def forward( self, @@ -823,7 +812,12 @@ def forward( **kwargs, ) -> LogitsProcessor: hidden_states = self.model( - input_ids, positions, forward_batch, input_embeds, per_layer_inputs, **kwargs + input_ids, + positions, + forward_batch, + input_embeds, + per_layer_inputs, + **kwargs, ) return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch @@ -881,7 +875,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if unloaded_params: logger.warning( "Some weights are not initialized from checkpoints: %s", unloaded_params - ) + ) return loaded_params diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index f34f501edadd..f6abf04902be 100644 --- 
a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -56,16 +56,19 @@ cached_get_processor = lru_cache(get_processor) + class Gemma4ImagePixelInputs(TypedDict): pixel_values: torch.Tensor """Shape: `(batch_size * num_images, num_channels, height, width)`""" + class Gemma4AudioInputs(TypedDict): input_features_padded: torch.Tensor """Shape: `(batch_size * num_audio, seq_length, num_features)`""" input_features_mask: torch.Tensor """Shape: `(batch_size * num_audio, seq_length)`""" + class Gemma4MultimodalEmbedder(nn.Module): """Projects vision/audio soft tokens into LM embedding space.""" @@ -110,6 +113,7 @@ def forward( embs_proj, _ = self.embedding_projection(inputs_embeds) return self.embedding_post_projection_norm(embs_proj) + class Gemma4ForConditionalGeneration(PreTrainedModel): config_class = Gemma4Config """Gemma4 multimodal model for conditional generation.""" @@ -195,7 +199,11 @@ def __init__( self.embed_audio = None self.vocab_size = config.text_config.vocab_size - self.vocab_size_per_layer_input = getattr(config.text_config, "vocab_size_per_layer_input", config.text_config.vocab_size) + self.vocab_size_per_layer_input = getattr( + config.text_config, + "vocab_size_per_layer_input", + config.text_config.vocab_size, + ) # Text model self.language_model = Gemma4TextModel( @@ -227,7 +235,7 @@ def get_attention_sliding_window_size(self): def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: all_pixel_values = flatten_nested_list([item.feature for item in items]) vt = self.vision_tower - + all_embeds = [] for pv in all_pixel_values: if pv.dim() == 5: @@ -236,22 +244,25 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: pv = pv.unsqueeze(0) elif pv.dim() != 4: raise ValueError(f"Unexpected pixel_values shape: {pv.shape}") - + pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) - + # Step 1: Patchify, pad to max_patches (2520), build positions patch_positions, 
padding_positions = vt._patch_positions(pv) inputs_embeds = vt.patch_embedder( pv, - patch_positions[:, :vt._num_real_patches(pv)], - padding_positions[:, :vt._num_real_patches(pv)], + patch_positions[:, : vt._num_real_patches(pv)], + padding_positions[:, : vt._num_real_patches(pv)], ) num_real = inputs_embeds.shape[1] num_padding = vt.max_patches - num_real if num_padding > 0: pad_embeds = torch.zeros( - inputs_embeds.shape[0], num_padding, inputs_embeds.shape[2], - device=inputs_embeds.device, dtype=inputs_embeds.dtype, + inputs_embeds.shape[0], + num_padding, + inputs_embeds.shape[2], + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, ) inputs_embeds = torch.cat([inputs_embeds, pad_embeds], dim=1) @@ -281,31 +292,38 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: return torch.cat(all_embeds, dim=0) else: return torch.empty( - 0, self.language_model.config.hidden_size, + 0, + self.language_model.config.hidden_size, device=next(self.parameters()).device, - dtype=self.language_model.dtype() + dtype=self.language_model.dtype(), ) def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: if self.audio_tower is None: - raise ValueError("Audio inputs provided but the model does not have an audio tower.") - + raise ValueError( + "Audio inputs provided but the model does not have an audio tower." 
+ ) + all_input_features = flatten_nested_list([item.feature for item in items]) - all_input_features_mask = flatten_nested_list([~item.input_features_mask for item in items]) - + all_input_features_mask = flatten_nested_list( + [~item.input_features_mask for item in items] + ) + all_embeds = [] - for input_features, input_features_mask in zip(all_input_features, all_input_features_mask): + for input_features, input_features_mask in zip( + all_input_features, all_input_features_mask + ): if input_features.dim() == 2: input_features = input_features.unsqueeze(0) if input_features_mask.dim() == 1: input_features_mask = input_features_mask.unsqueeze(0) - + input_features = input_features.to( device=next(self.audio_tower.parameters()).device, dtype=self.language_model.dtype(), ) input_features_mask = input_features_mask.to(device=input_features.device) - + # Run audio tower (mask True=padding) audio_outputs = self.audio_tower(input_features, input_features_mask) if isinstance(audio_outputs, tuple): @@ -313,20 +331,21 @@ def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: else: audio_encodings = audio_outputs.last_hidden_state audio_mask = audio_outputs.audio_mel_mask - + audio_features = self.embed_audio(inputs_embeds=audio_encodings) - + # Strip padding for enc, mask in zip(audio_features, audio_mask): all_embeds.append(enc[~mask]) - + if all_embeds: return torch.cat(all_embeds, dim=0) else: return torch.empty( - 0, self.language_model.config.hidden_size, + 0, + self.language_model.config.hidden_size, device=next(self.parameters()).device, - dtype=self.language_model.dtype() + dtype=self.language_model.dtype(), ) def get_per_layer_inputs( @@ -402,7 +421,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Vestigial weights to ignore if "embed_vision.embedding." in name or "embed_audio.embedding." in name: continue - if self.audio_tower is None and ("audio_tower." in name or "embed_audio." 
in name): + if self.audio_tower is None and ( + "audio_tower." in name or "embed_audio." in name + ): continue name = re.sub(r"^model\.", "", name) @@ -440,7 +461,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if unloaded_params: logger.warning( "Some weights are not initialized from checkpoints: %s", unloaded_params - ) + ) return loaded_params lora_pattern = re.compile( diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index afdc57a6e82a..98f11b69167a 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -63,7 +63,9 @@ def detect_and_parse(self, text: str) -> StreamingParseResult: return StreamingParseResult(normal_text=text) # The text is considered to be in a reasoning block. - processed_text = text.replace(self.think_start_token + self.think_start_self_label, "").strip() + processed_text = text.replace( + self.think_start_token + self.think_start_self_label, "" + ).strip() if ( self.think_end_token not in processed_text @@ -121,7 +123,7 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult: for token in tokens_to_check ): return StreamingParseResult() - + think_start_text = self.think_start_token + self.think_start_self_label # Strip `` token if present @@ -444,8 +446,10 @@ def __init__( previous_content=previous_content, ) + class Gemma4Detector(BaseReasoningFormatDetector): """Gemma4 reasoning detector.""" + def __init__( self, stream_reasoning: bool = True, diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 5a584c5b09f4..a950f703259f 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -399,12 +399,14 @@ def get_config( if config.model_type == "gemma4": global_head_dim = getattr(config.text_config, "global_head_dim", None) - num_global_key_value_heads = 
getattr(config.text_config, "num_global_key_value_heads", None) - + num_global_key_value_heads = getattr( + config.text_config, "num_global_key_value_heads", None + ) + if global_head_dim is not None: config.text_config.swa_head_dim = config.text_config.head_dim config.text_config.head_dim = global_head_dim - + config.text_config.swa_num_key_value_heads = config.num_key_value_heads if num_global_key_value_heads is not None: config.text_config.num_key_value_heads = num_global_key_value_heads diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py index 6887f658b165..ab6b31677b18 100644 --- a/scripts/playground/reference_hf.py +++ b/scripts/playground/reference_hf.py @@ -164,7 +164,9 @@ def synthetic_tokens(args): for p in prompts: input_ids = p for i in range(output_len + 1): - output = m.forward(torch.tensor([input_ids], device="cuda"), output_hidden_states=True).logits[0][-1] + output = m.forward( + torch.tensor([input_ids], device="cuda"), output_hidden_states=True + ).logits[0][-1] prefill_logits = output if i == 0: print("prefill logits", prefill_logits) diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index eb15e413f198..2c2be7e1ca46 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -3752,11 +3752,11 @@ def setUp(self): def test_detect_and_parse(self): text = 'Some text before <|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' result = self.detector.detect_and_parse(text, self.tools) - + self.assertEqual(result.normal_text, "Some text before ") self.assertEqual(len(result.calls), 1) self.assertEqual(result.calls[0].name, "get_weather") - + params = json.loads(result.calls[0].parameters) self.assertEqual(params["location"], "Tokyo") @@ -3767,17 +3767,17 @@ def test_parse_streaming_increment(self): "_call>call:get_we", "ather{location:<|", '"|>Tokyo<|"|>} 
after" + "call|> after", ] - + all_results = [] for chunk in chunks: res = self.detector.parse_streaming_increment(chunk, self.tools) all_results.append(res) - + combined_normal_text = "".join(r.normal_text for r in all_results) self.assertEqual(combined_normal_text, "Some text before after") - + found_name = False found_params = False for res in all_results: @@ -3788,18 +3788,18 @@ def test_parse_streaming_increment(self): params = json.loads(call.parameters) if params == {"location": "Tokyo"}: found_params = True - + self.assertTrue(found_name) self.assertTrue(found_params) def test_nested_array_streaming(self): # Additional coverage for complex structure chunks = [ - "<|tool_call>call:get_weather{location:<|\"", - "|>New York<|\"|>,nested:[1, 2, {inner:<|\"|>", - "val<|\"|>}]}" + '<|tool_call>call:get_weather{location:<|"', + '|>New York<|"|>,nested:[1, 2, {inner:<|"|>', + 'val<|"|>}]}', ] - + all_results = [] for chunk in chunks: res = self.detector.parse_streaming_increment(chunk, self.tools) @@ -3811,9 +3811,13 @@ def test_nested_array_streaming(self): if call.parameters: params = json.loads(call.parameters) if "location" in params and params["location"] == "New York": - if "nested" in params and params["nested"] == [1, 2, {"inner": "val"}]: + if "nested" in params and params["nested"] == [ + 1, + 2, + {"inner": "val"}, + ]: found_params = True - + self.assertTrue(found_params) From 2af1b413da17245aa784103bf7cb3fe8327f0b2b Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 9 Mar 2026 21:35:39 +0000 Subject: [PATCH 012/112] gemma4 mm init and kvcache fix --- python/sglang/srt/configs/model_config.py | 1 + python/sglang/srt/layers/attention/triton_backend.py | 9 +++------ python/sglang/srt/utils/hf_transformers_utils.py | 6 +++++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 3c8816b03534..e0e8483b820c 100644 --- 
a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -143,6 +143,7 @@ def __init__( if enable_multimodal is None: mm_disabled_models = [ "Gemma3ForConditionalGeneration", + # "Gemma4ForConditionalGeneration", "Llama4ForConditionalGeneration", "Step3VLForConditionalGeneration", ] diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 30614f2867cf..359882e72026 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -820,12 +820,9 @@ def forward_extend( o = torch.empty_like(q) if k is None and v is None: - k, v = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - # FIXME: hacky way to make kv cache aligned - # why??? - k = k[1 : q.shape[0] + 1] - v = v[1 : q.shape[0] + 1] - # print(layer.layer_id, k.cpu(), v.cpu()) + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + k = k_buffer[forward_batch.out_cache_loc] + v = v_buffer[forward_batch.out_cache_loc] elif k is None or v is None: raise ValueError("Both k and v should be None or not None") else: diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index a950f703259f..7627b73cc310 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -639,7 +639,11 @@ def get_processor( kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520} if config.model_type not in {"llava", "clip"}: - kwargs["use_fast"] = use_fast + if config.model_type == "gemma4": + # TODO(kpham-sgl): revert this once we have a fast tokenizer for gemma4 + kwargs["use_fast"] = False + else: + kwargs["use_fast"] = use_fast try: if "InternVL3_5" in tokenizer_name: processor = AutoTokenizer.from_pretrained( From 89bf65c2a6050a135f65c011f447df82ae83f3b0 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 9 
Mar 2026 22:47:07 +0000 Subject: [PATCH 013/112] add vision tower support, pending some refactor --- python/sglang/srt/models/gemma4_mm.py | 61 +- python/sglang/srt/models/gemma4_vision.py | 593 ++++++++++++++++++ .../srt/multimodal/processors/gemma4.py | 2 - 3 files changed, 610 insertions(+), 46 deletions(-) create mode 100644 python/sglang/srt/models/gemma4_vision.py diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index f6abf04902be..ceb3ef704ce2 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -29,6 +29,8 @@ ) from transformers.models.auto.modeling_auto import AutoModel +from sglang.srt.models.gemma4_vision import Gemma4VisionEncoder + from sglang.srt.layers.layernorm import Gemma4RMSNorm from sglang.srt.layers.linear import RowParallelLinear from sglang.srt.layers.logits_processor import LogitsProcessor @@ -127,7 +129,6 @@ class Gemma4ForConditionalGeneration(PreTrainedModel): ".k_proj.", ".v_proj.", ".o_proj.", - ".out_proj.", ] bitsandbytes_stacked_params_mapping = { "q_proj": ("qkv_proj", 0), @@ -135,7 +136,6 @@ class Gemma4ForConditionalGeneration(PreTrainedModel): "v_proj": ("qkv_proj", 2), "gate_proj": ("gate_up_proj", 0), "up_proj": ("gate_up_proj", 1), - "out_proj": ("proj", 0), } packed_modules_mapping = { @@ -174,8 +174,11 @@ def __init__( prefix = add_prefix("model", prefix) - # Vision components - self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.vision_tower = Gemma4VisionEncoder( + config=config.vision_config, + quant_config=quant_config, + prefix=add_prefix("vision_tower", prefix), + ) self.embed_vision = Gemma4MultimodalEmbedder( config.vision_config, @@ -247,42 +250,9 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) - # Step 1: Patchify, pad to max_patches (2520), build positions - patch_positions, padding_positions = 
vt._patch_positions(pv) - inputs_embeds = vt.patch_embedder( - pv, - patch_positions[:, : vt._num_real_patches(pv)], - padding_positions[:, : vt._num_real_patches(pv)], - ) - num_real = inputs_embeds.shape[1] - num_padding = vt.max_patches - num_real - if num_padding > 0: - pad_embeds = torch.zeros( - inputs_embeds.shape[0], - num_padding, - inputs_embeds.shape[2], - device=inputs_embeds.device, - dtype=inputs_embeds.dtype, - ) - inputs_embeds = torch.cat([inputs_embeds, pad_embeds], dim=1) - - # Step 2: Encode - model_output = vt.encoder( - inputs_embeds=inputs_embeds, - attention_mask=~padding_positions, - patch_positions=patch_positions, - ) + pooled, pooler_mask = vt(pv) - # Step 3: Pool to default_output_length (280) tokens - pooler_output = vt.pooler( - hidden_states=model_output.last_hidden_state, - patch_positions=patch_positions, - padding_positions=padding_positions, - ) - hidden_states, pooler_mask = pooler_output[0] - - # Step 4: Strip padding per-image and embed - for hs, mask in zip(hidden_states, pooler_mask): + for hs, mask in zip(pooled, pooler_mask): real_tokens = hs[mask] all_embeds.append( self.embed_vision(inputs_embeds=real_tokens.unsqueeze(0)).squeeze(0) @@ -378,8 +348,12 @@ def forward( ) positions += 1 + per_layer_inputs = None if input_ids is not None: - per_layer_inputs = self.get_per_layer_inputs(input_ids) + ple_ids = input_ids.clone() + ple_ids[input_ids == self.config.image_token_id] = 0 + ple_ids[input_ids == self.config.audio_token_id] = 0 + per_layer_inputs = self.get_per_layer_inputs(ple_ids) # Use general_mm_embed_routine for handling multimodal data hidden_states = general_mm_embed_routine( @@ -427,23 +401,22 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = re.sub(r"^model\.", "", name) + orig_name = name for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ 
models if name.endswith(".bias") and name not in params_dict: + name = orig_name continue if name not in params_dict: + name = orig_name continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - if "vision_model" in name: - # adapt to VisionAttention - name = name.replace(".self_attn.out_proj", ".self_attn.proj") # Skip loading extra bias for GPTQ models if name.endswith(".bias") and name not in params_dict: continue diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py new file mode 100644 index 000000000000..9650cd6a2f7c --- /dev/null +++ b/python/sglang/srt/models/gemma4_vision.py @@ -0,0 +1,593 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import Gemma4VisionConfig + +from sglang.srt.layers.attention.vision import QKV_BACKEND_IMPL, VisionAttention +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size +from sglang.srt.layers.layernorm import Gemma4RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.utils import add_prefix, get_device_capability, is_cuda + +# --------------------------------------------------------------------------- +# 2-D Multidimensional RoPE (matches HF Gemma4RotaryEmbedding for vision) +# --------------------------------------------------------------------------- + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + return (x * cos) + (_rotate_half(x) * sin) + + +class Gemma4VisionRotaryEmbedding(nn.Module): + """Compute 2-D multidimensional RoPE cos/sin for patch positions.""" + + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.head_dim = config.head_dim + rope_params = config.rope_parameters.get("full_attention", {}) + self.rope_theta: float = rope_params.get("rope_theta", 100.0) + + @torch.no_grad() + def forward( + self, x: torch.Tensor, patch_positions: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: [batch, seq, hidden] – only used for device/dtype. + patch_positions: [batch, num_patches, 2] – (x, y) coordinates. 
+ Returns: + (cos, sin) each of shape [batch, num_patches, head_dim]. + """ + ndim = patch_positions.shape[-1] # 2 + head_dim_per_dim = self.head_dim // ndim + + all_embs = [] + for d in range(ndim): + dim_inv_freq = 1.0 / ( + self.rope_theta + ** ( + torch.arange(0, head_dim_per_dim, 2, device=x.device, dtype=torch.float) + / head_dim_per_dim + ) + ) + dim_inv_freq_expanded = dim_inv_freq[None, :, None].expand( + patch_positions.shape[0], -1, 1 + ) + dim_positions = patch_positions[:, :, d].float() + dim_positions_expanded = dim_positions[:, None, :] + + dim_freqs = (dim_inv_freq_expanded @ dim_positions_expanded).transpose(1, 2) + dim_emb = torch.cat((dim_freqs, dim_freqs), dim=-1) + all_embs.append(dim_emb) + + emb = torch.cat(all_embs, dim=-1) + cos = emb.cos().to(dtype=x.dtype) + sin = emb.sin().to(dtype=x.dtype) + return cos, sin + + +def _apply_multidimensional_rope( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: + """Apply 2-D RoPE to x of shape [batch*seq, heads, head_dim]. + + cos/sin have shape [batch, seq, head_dim]. We split along head_dim into + ndim=2 parts and apply standard rotary to each independently. + """ + ndim = 2 + chunk_size = x.shape[-1] // ndim + x_parts = x.split(chunk_size, dim=-1) + cos_parts = cos.split(chunk_size, dim=-1) + sin_parts = sin.split(chunk_size, dim=-1) + y_parts = [_apply_rotary(x_parts[k], cos_parts[k], sin_parts[k]) for k in range(ndim)] + return torch.cat(y_parts, dim=-1) + + +# --------------------------------------------------------------------------- +# Vision Attention (TP-sharded via QKVParallelLinear + RowParallelLinear) +# --------------------------------------------------------------------------- + + +class Gemma4VisionAttention(nn.Module): + """Multi-head attention for the Gemma 4 vision encoder. 
+ + Uses SGLang's QKVParallelLinear and RowParallelLinear for tensor-parallel + sharding, Gemma4RMSNorm for per-head QK/V normalization, and the same + multimodal attention backends as VisionAttention. + """ + + def __init__( + self, + config: Gemma4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = config.head_dim + + tp_size = get_attention_tp_size() + self.num_heads_per_partition = self.num_heads // tp_size + self.num_kv_heads_per_partition = self.num_kv_heads // tp_size + + self.q_size = self.num_heads_per_partition * self.head_dim + self.kv_size = self.num_kv_heads_per_partition * self.head_dim + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.num_heads, + total_num_kv_heads=self.num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.o_proj = RowParallelLinear( + input_size=self.num_heads * self.head_dim, + output_size=self.hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + + self.q_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.v_norm = Gemma4RMSNorm( + self.head_dim, eps=config.rms_norm_eps, scale_shift=0.0, with_scale=False + ) + + backend = self._select_backend() + self.qkv_backend = QKV_BACKEND_IMPL[backend]( + head_dim=self.head_dim, + num_heads=self.num_heads_per_partition, + num_kv_heads=self.num_kv_heads_per_partition, + dropout=0.0, + flatten_batch=True, + softmax_in_single_precision=False, + ) + + @staticmethod + def _select_backend() -> str: + """Mirror VisionAttention._determine_attention_backend for consistency.""" + from 
sglang.srt.server_args import get_global_server_args + + override = get_global_server_args().mm_attention_backend + if override is not None: + return override + if is_cuda(): + major, _ = get_device_capability() + if major == 9: + from sglang.srt.utils import is_blackwell_supported + + if is_blackwell_supported(): + return "triton_attn" + return "fa3" + return "triton_attn" + return "sdpa" + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states: [batch, seq, hidden_size] + cos, sin: [batch, seq, head_dim] from Gemma4VisionRotaryEmbedding + attention_mask: [batch, seq] — True = valid, False = padding + """ + bsz, seq_len, _ = hidden_states.shape + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.reshape(bsz * seq_len, self.num_heads_per_partition, self.head_dim) + k = k.reshape(bsz * seq_len, self.num_kv_heads_per_partition, self.head_dim) + v = v.reshape(bsz * seq_len, self.num_kv_heads_per_partition, self.head_dim) + + # Per-head QK norm + q = self.q_norm(q.reshape(-1, self.head_dim)).reshape(q.shape) + k = self.k_norm(k.reshape(-1, self.head_dim)).reshape(k.shape) + v = self.v_norm(v.reshape(-1, self.head_dim)).reshape(v.shape) + + # 2-D RoPE: cos/sin are [batch, seq, head_dim]; broadcast to [batch*seq, 1, head_dim] + cos_flat = cos.reshape(bsz * seq_len, 1, self.head_dim) + sin_flat = sin.reshape(bsz * seq_len, 1, self.head_dim) + q = _apply_multidimensional_rope(q, cos_flat, sin_flat) + k = _apply_multidimensional_rope(k, cos_flat, sin_flat) + + # Build 4-D attention mask for backends that expect it + if attention_mask is not None: + attn_mask_4d = ( + attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(1) + ).unsqueeze(1) + else: + attn_mask_4d = None + + output = self.qkv_backend.forward( + q=q, k=k, v=v, + cu_seqlens=None, + bsz=bsz, 
seq_len=seq_len, + attention_mask=attn_mask_4d, + ) + + output = rearrange(output, "(b s) h d -> b s (h d)", b=bsz) + output, _ = self.o_proj(output) + return output + + +# --------------------------------------------------------------------------- +# Vision MLP (GeGLU, TP-sharded) +# --------------------------------------------------------------------------- + + +class Gemma4VisionMLP(nn.Module): + def __init__( + self, + config: Gemma4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + output_sizes=[self.intermediate_size, self.intermediate_size], + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.down_proj = RowParallelLinear( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + ) + from sglang.srt.layers.activation import SiluAndMul + + self.act_fn = SiluAndMul() # GeGLU: GELU variant handled by weight init + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +# --------------------------------------------------------------------------- +# Encoder Layer +# --------------------------------------------------------------------------- + + +class Gemma4VisionEncoderLayer(nn.Module): + def __init__( + self, + config: Gemma4VisionConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.self_attn = Gemma4VisionAttention( + config, quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + self.mlp = Gemma4VisionMLP( + config, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + 
eps = config.rms_norm_eps + hs = config.hidden_size + self.input_layernorm = Gemma4RMSNorm(hs, eps=eps) + self.post_attention_layernorm = Gemma4RMSNorm(hs, eps=eps) + self.pre_feedforward_layernorm = Gemma4RMSNorm(hs, eps=eps) + self.post_feedforward_layernorm = Gemma4RMSNorm(hs, eps=eps) + + self.register_buffer("layer_scalar", torch.ones(())) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, cos, sin, attention_mask) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + hidden_states = hidden_states * self.layer_scalar + return hidden_states + + +# --------------------------------------------------------------------------- +# Vision Transformer (stack of encoder layers + RoPE) +# --------------------------------------------------------------------------- + + +class Gemma4VisionTransformer(nn.Module): + def __init__( + self, + config: Gemma4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.rotary_emb = Gemma4VisionRotaryEmbedding(config) + self.layers = nn.ModuleList([ + Gemma4VisionEncoderLayer( + config, layer_idx=i, quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), + ) + for i in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor, + patch_positions: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + inputs_embeds: [batch, seq, hidden_size] + 
attention_mask: [batch, seq] — True = valid token + patch_positions: [batch, seq, 2] + Returns: + last_hidden_state: [batch, seq, hidden_size] + """ + cos, sin = self.rotary_emb(inputs_embeds, patch_positions) + hidden_states = inputs_embeds + for layer in self.layers: + hidden_states = layer(hidden_states, cos, sin, attention_mask) + return hidden_states + + +# --------------------------------------------------------------------------- +# Patch Embedder +# --------------------------------------------------------------------------- + + +class Gemma4VisionPatchEmbedder(nn.Module): + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.patch_size = config.patch_size + self.hidden_size = config.hidden_size + self.position_embedding_size = config.position_embedding_size + + self.input_proj = nn.Linear(3 * self.patch_size ** 2, self.hidden_size, bias=False) + self.position_embedding_table = nn.Parameter( + torch.ones(2, self.position_embedding_size, self.hidden_size) + ) + + def _position_embeddings( + self, patch_positions: torch.Tensor, padding_positions: torch.Tensor + ) -> torch.Tensor: + one_hot = F.one_hot(patch_positions, num_classes=self.position_embedding_size) + one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table) + position_embeddings = one_hot @ self.position_embedding_table + position_embeddings = position_embeddings.sum(dim=1) + position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings) + return position_embeddings + + def _patchify(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + patch_height = height // self.patch_size + patch_width = width // self.patch_size + patchified_shape = (batch_size, num_channels, patch_height, self.patch_size, patch_width, self.patch_size) + consolidated_shape = (batch_size, patch_height * patch_width, num_channels * self.patch_size ** 2) + patches = 
pixel_values.reshape(patchified_shape).permute(0, 2, 4, 3, 5, 1).reshape(consolidated_shape) + patches = 2 * (patches - 0.5) + return self.input_proj(patches.to(self.input_proj.weight.dtype)) + + def forward( + self, pixel_values: torch.Tensor, patch_positions: torch.Tensor, padding_positions: torch.Tensor + ) -> torch.Tensor: + hidden_states = self._patchify(pixel_values) + position_embeddings = self._position_embeddings(patch_positions, padding_positions) + return hidden_states + position_embeddings + + +# --------------------------------------------------------------------------- +# Pooler +# --------------------------------------------------------------------------- + + +class Gemma4VisionPooler(nn.Module): + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.default_output_length = config.default_output_length + self.root_hidden_size = self.hidden_size ** 0.5 + + def _avg_pool_by_positions( + self, x: torch.Tensor, patch_positions: torch.Tensor, length: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + input_seq_len = x.shape[1] + k = int((input_seq_len // length) ** 0.5) + k_squared = k ** 2 + if k_squared * length != input_seq_len: + raise ValueError( + f"Cannot pool {x.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}." 
+ ) + clamped_positions = patch_positions.clamp(min=0) + max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1 + kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor") + kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1] + weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared + output = weights.transpose(1, 2).to(x.dtype) @ x + mask = torch.logical_not((weights == 0).all(dim=1)) + return output, mask + + def forward( + self, + hidden_states: torch.Tensor, + patch_positions: torch.Tensor, + padding_positions: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns: + (pooled_hidden_states, mask) where mask is True for valid tokens. + """ + length = self.default_output_length + if isinstance(length, (list, tuple)): + length = length[0] + if hidden_states.shape[1] == length: + mask = padding_positions + else: + hidden_states, mask = self._avg_pool_by_positions( + hidden_states, patch_positions, length + ) + hidden_states = hidden_states * self.root_hidden_size + return hidden_states, mask + + +# --------------------------------------------------------------------------- +# Top-level Vision Encoder (patch_embedder → transformer → pooler) +# --------------------------------------------------------------------------- + + +class Gemma4VisionEncoder(nn.Module): + """Drop-in replacement for HF ``Gemma4VisionEncoder`` with TP support.""" + + def __init__( + self, + config: Gemma4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.patch_size = config.patch_size + self.pooling_kernel_size = config.pooling_kernel_size + self.default_output_length = config.default_output_length + self.max_patches = self.default_output_length * self.pooling_kernel_size ** 2 + + self.patch_embedder = Gemma4VisionPatchEmbedder(config) + self.encoder = Gemma4VisionTransformer( + config, quant_config=quant_config, + 
prefix=add_prefix("encoder", prefix), + ) + self.pooler = Gemma4VisionPooler(config) + + @property + def device(self) -> torch.device: + return self.patch_embedder.input_proj.weight.device + + def _num_real_patches(self, pixel_values: torch.Tensor) -> int: + _, _, height, width = pixel_values.shape + return (height // self.patch_size) * (width // self.patch_size) + + def _patch_positions( + self, pixel_values: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, _, height, width = pixel_values.shape + device = pixel_values.device + patch_height = height // self.patch_size + patch_width = width // self.patch_size + num_patches = patch_height * patch_width + num_padding = self.max_patches - num_patches + + patch_grid = torch.meshgrid( + torch.arange(patch_width, device=device), + torch.arange(patch_height, device=device), + indexing="xy", + ) + stacked_grid = torch.stack(patch_grid, dim=-1) + real_positions = stacked_grid.reshape(num_patches, 2).unsqueeze(0).repeat(batch_size, 1, 1) + + if num_padding > 0: + pad_positions = torch.full( + (batch_size, num_padding, 2), -1, device=device, dtype=torch.long + ) + patch_positions = torch.cat([real_positions, pad_positions], dim=1) + else: + patch_positions = real_positions + + padding_positions = torch.zeros(batch_size, self.max_patches, device=device, dtype=torch.bool) + if num_padding > 0: + padding_positions[:, num_patches:] = True + + return patch_positions.long(), padding_positions + + def forward(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode pixel_values into soft tokens. + + Args: + pixel_values: [batch, channels, height, width] + + Returns: + (hidden_states, pooler_mask) — hidden_states [batch, output_len, hidden], + pooler_mask [batch, output_len] True = valid. 
+ """ + patch_positions, padding_positions = self._patch_positions(pixel_values) + + inputs_embeds = self.patch_embedder( + pixel_values, + patch_positions[:, : self._num_real_patches(pixel_values)], + padding_positions[:, : self._num_real_patches(pixel_values)], + ) + + num_real = inputs_embeds.shape[1] + num_padding = self.max_patches - num_real + if num_padding > 0: + pad_embeds = torch.zeros( + inputs_embeds.shape[0], num_padding, inputs_embeds.shape[2], + device=inputs_embeds.device, dtype=inputs_embeds.dtype, + ) + inputs_embeds = torch.cat([inputs_embeds, pad_embeds], dim=1) + + last_hidden = self.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~padding_positions, + patch_positions=patch_positions, + ) + + pooled, pooler_mask = self.pooler(last_hidden, patch_positions, padding_positions) + return pooled, pooler_mask diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index 4618b1c0559f..5c115da37bb6 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -35,9 +35,7 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id self.mm_tokens = MultimodalSpecialTokens( - image_token="", image_token_id=hf_config.image_token_id, - audio_token="", audio_token_id=hf_config.audio_token_id, ).build(_processor) From c9959ab90db1541e6aaf317ddedf9a9fd75cbf5c Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 10 Mar 2026 00:06:07 +0000 Subject: [PATCH 014/112] don't partite the embedding projection because vision tower already does AR --- python/sglang/srt/models/gemma4_mm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index ceb3ef704ce2..c4083eebd35e 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ 
b/python/sglang/srt/models/gemma4_mm.py @@ -32,7 +32,7 @@ from sglang.srt.models.gemma4_vision import Gemma4VisionEncoder from sglang.srt.layers.layernorm import Gemma4RMSNorm -from sglang.srt.layers.linear import RowParallelLinear +from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.mm_utils import ( @@ -93,7 +93,7 @@ def __init__( or multimodal_config.hidden_size ) - self.embedding_projection = RowParallelLinear( + self.embedding_projection = ReplicatedLinear( embedding_dim, self.text_hidden_size, bias=False, From 18115f99ea97974a4a130210f90ec04536858873 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 10 Mar 2026 03:21:58 +0000 Subject: [PATCH 015/112] so many changes to make vision encoder work --- .../attention/triton_ops/prefill_attention.py | 6 +- python/sglang/srt/layers/attention/vision.py | 28 ++- python/sglang/srt/models/gemma4_mm.py | 44 ++++ python/sglang/srt/models/gemma4_vision.py | 209 +++++++++++++----- 4 files changed, 230 insertions(+), 57 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py index ac0fc72af140..a50b89787f2a 100644 --- a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py @@ -168,13 +168,14 @@ def _fwd_kernel( def context_attention_fwd( - q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True + q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True, sm_scale=None ): """ q, k, v: [b * s, head, head_dim] b_start_loc: [b] b_seq_len: [b] out: [b * s, head, head_dim] + sm_scale: softmax scale, defaults to 1/sqrt(head_dim) """ if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8: BLOCK = 128 @@ -183,7 +184,8 @@ def context_attention_fwd( Lq, Lk, Lv = 
q.shape[-1], k.shape[-1], v.shape[-1] - sm_scale = 1.0 / (Lq**0.5) + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 77a8cde46410..84d115e1bb18 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -165,6 +165,7 @@ def __init__( dropout: float = 0.0, flatten_batch: bool = False, softmax_in_single_precision: bool = False, + softmax_scale: float | None = None, **kwargs, ): super().__init__() @@ -174,7 +175,7 @@ def __init__( self.flatten_batch = flatten_batch self.softmax_in_single_precision = softmax_in_single_precision self.dropout = dropout - self.scale = 1.0 / math.sqrt(self.head_size) + self.scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(self.head_size) @staticmethod @lru_cache(maxsize=128) @@ -296,6 +297,7 @@ def forward( attn_mask=attention_mask, dropout_p=self.dropout, is_causal=False, + scale=self.scale, ) # [b, h, s, head_size] --> [b * s, h, head_size] @@ -332,9 +334,12 @@ def forward( r""" Args: cu_seqlens: [b] + softmax_scale: override softmax scale (default 1/sqrt(head_dim)) Returns: [b * s, h, head_size] """ + softmax_scale = kwargs.get("softmax_scale", None) + if envs.SGLANG_VIT_ENABLE_CUDA_GRAPH.get(): if "output_ws" not in kwargs: raise RuntimeError("output_ws should be prepared for cuda-graph mode") @@ -352,6 +357,7 @@ def forward( cu_seqlens[1], cu_seqlens[2], is_causal=False, + sm_scale=softmax_scale, ) else: cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device) @@ -370,6 +376,7 @@ def forward( seq_lens.cuda(), max_seqlen, is_causal=False, + sm_scale=softmax_scale, ) return output @@ -404,6 +411,8 @@ def forward( Returns: [b * s, h, head_size] """ + softmax_scale = kwargs.get("softmax_scale", None) + if envs.SGLANG_VIT_ENABLE_CUDA_GRAPH.get(): 
max_seqlen = cu_seqlens[1] output = flash_attn_func( @@ -414,6 +423,7 @@ def forward( cu_seqlens_k=cu_seqlens[0], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, ) else: cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device) @@ -429,6 +439,7 @@ def forward( cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, ) return output @@ -468,6 +479,8 @@ def forward( ) cu_seqlens = cu_seqlens.get_data() + softmax_scale = kwargs.get("softmax_scale", None) + cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device) seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() @@ -480,6 +493,7 @@ def forward( cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, ver=4, ) @@ -635,6 +649,8 @@ def forward( seq_len: int, **kwargs, ) -> torch.Tensor: + softmax_scale = kwargs.get("softmax_scale", None) + cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device) cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device) @@ -649,6 +665,7 @@ def forward( cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, ) @@ -692,15 +709,18 @@ def forward( output = torch.empty_like(q) seq_len_arg = seq_lens.to(torch.int32) + _, num_heads, head_size = q.shape num_kv_heads = k.shape[1] + scale_value = kwargs.get("softmax_scale") or head_size**-0.5 + torch_npu._npu_flash_attention_unpad( query=q, key=k, value=v, seq_len=seq_len_arg, - scale_value=head_size**-0.5, + scale_value=scale_value, num_heads=num_heads, num_kv_heads=num_kv_heads, out=output, @@ -742,6 +762,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, dropout: float = 0.0, softmax_in_single_precision: bool = False, + softmax_scale: Optional[float] = None, flatten_batch: bool = False, prefix: str = "", proj_bias: bool = True, @@ -806,6 +827,7 @@ def __init__( 
self.customized_position_embedding_applier = ( customized_position_embedding_applier ) + self.softmax_scale = softmax_scale self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend]( head_dim=self.head_size, num_heads=self.num_attention_heads_per_partition, @@ -813,6 +835,7 @@ def __init__( dropout=dropout, flatten_batch=flatten_batch, softmax_in_single_precision=softmax_in_single_precision, + softmax_scale=softmax_scale, use_data_parallel=use_data_parallel, workspace_buffer=workspace_buffer, ) @@ -1108,6 +1131,7 @@ def forward( sequence_lengths=sequence_lengths, max_seqlen=max_seqlen, output_ws=attn_output_ws, + softmax_scale=self.softmax_scale, ) assert output.dim() == 3, output.shape diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index c4083eebd35e..68b48478e60e 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -347,6 +347,9 @@ def forward( "You must specify exactly one of input_ids or inputs_embeds" ) + # DEBUG: check mm_inputs in forward + has_mm = forward_batch.contains_mm_inputs() if hasattr(forward_batch, 'contains_mm_inputs') else False + is_decode = forward_batch.forward_mode.is_decode() positions += 1 per_layer_inputs = None if input_ids is not None: @@ -401,6 +404,47 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = re.sub(r"^model\.", "", name) + + # Vision encoder non-fused linears are wrapped in ClippableLinear: + # checkpoint "proj.weight" → our "proj.linear.weight" + if "vision_tower." 
in name and (name.endswith(".weight") or name.endswith(".bias")): + base, attr = name.rsplit(".", 1) + alt = f"{base}.linear.{attr}" + if alt in params_dict: + name = alt + + # Vision encoder fused projections (ClippableQKV / ClippableGateUp): + # weight/bias: *.q_proj.weight → *.qkv.q_proj.weight (stacked params then fuses) + # output bound: *.q_proj.output_min → *.qkv.q_output_min + # input bound: *.{q,k,v}_proj.input_min → *.qkv.input_min + # (all are identical in the checkpoint -- same hidden_states input -- + # so they collapse to a single shared buffer; last write wins) + if "vision_tower." in name: + m = re.match( + r"(.+\.self_attn)\.(q_proj|k_proj|v_proj)\.(.*)", name + ) + if m: + pfx, proj, attr = m.groups() + if attr in ("weight", "bias"): + name = f"{pfx}.qkv.{proj}.{attr}" + elif attr.startswith("output_"): + name = f"{pfx}.qkv.{proj[0]}_{attr}" + elif attr.startswith("input_"): + name = f"{pfx}.qkv.{attr}" + + m = re.match( + r"(.+\.mlp)\.(gate_proj|up_proj)\.(.*)", name + ) + if m: + pfx, proj, attr = m.groups() + short = proj.split("_")[0] # "gate" or "up" + if attr in ("weight", "bias"): + name = f"{pfx}.gate_up.{proj}.{attr}" + elif attr.startswith("output_"): + name = f"{pfx}.gate_up.{short}_{attr}" + elif attr.startswith("input_"): + name = f"{pfx}.gate_up.{attr}" + orig_name = name for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index 9650cd6a2f7c..889bfa1add24 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -32,6 +32,139 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import add_prefix, get_device_capability, is_cuda + +# --------------------------------------------------------------------------- +# Activation clamping (matches HF Gemma4ClippableLinear) +# 
--------------------------------------------------------------------------- + +_INF = float("inf") + + +class ClippableLinear(nn.Module): + """``RowParallelLinear`` with input/output activation clamping. + + Mirrors HF's ``Gemma4ClippableLinear``: owns the linear layer and applies + ``torch.clamp`` before and after the linear forward pass. Clip bounds + default to ±inf (no-op) and are populated from the checkpoint. + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.linear = RowParallelLinear( + input_size=input_size, + output_size=output_size, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor): + x = torch.clamp(x, self.input_min, self.input_max) + x, _ = self.linear(x) + x = torch.clamp(x, self.output_min, self.output_max) + return x + + +class ClippableQKVParallelLinear(nn.Module): + """Fused QKV projection with per-projection activation clamping. + + Owns a single ``QKVParallelLinear`` for the fused matmul. Clip bounds + are stored as flat buffers: shared ``input_min/max`` (applied before the + matmul) and per-projection ``q/k/v_output_min/max`` (applied after split). 
+ """ + + def __init__( + self, + config: Gemma4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_attention_tp_size() + self.q_size = (config.num_attention_heads // tp_size) * config.head_dim + self.kv_size = (config.num_key_value_heads // tp_size) * config.head_dim + + self.qkv_proj = QKVParallelLinear( + hidden_size=config.hidden_size, + head_size=config.head_dim, + total_num_heads=config.num_attention_heads, + total_num_kv_heads=config.num_key_value_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.q_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.q_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.k_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.k_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.v_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.v_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, hidden_states: torch.Tensor): + x = torch.clamp(hidden_states, self.input_min, self.input_max) + qkv, _ = self.qkv_proj(x) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = torch.clamp(q, self.q_output_min, self.q_output_max) + k = torch.clamp(k, self.k_output_min, self.k_output_max) + v = torch.clamp(v, self.v_output_min, self.v_output_max) + return q, k, v + + +class ClippableGateUpParallelLinear(nn.Module): + """Fused gate/up projection with per-projection activation clamping. 
+ + Same pattern as ``ClippableQKVParallelLinear``: owns a single + ``MergedColumnParallelLinear`` for the fused matmul, with shared input + bounds and per-projection output bounds as flat buffers. + """ + + def __init__( + self, + config: Gemma4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_attention_tp_size() + self.proj_size = config.intermediate_size // tp_size + + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size, config.intermediate_size], + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.gate_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.gate_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.up_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.up_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor): + x = torch.clamp(x, self.input_min, self.input_max) + gate_up, _ = self.gate_up_proj(x) + gate, up = gate_up.split([self.proj_size, self.proj_size], dim=-1) + gate = torch.clamp(gate, self.gate_output_min, self.gate_output_max) + up = torch.clamp(up, self.up_output_min, self.up_output_max) + return gate, up + + # --------------------------------------------------------------------------- # 2-D Multidimensional RoPE (matches HF Gemma4RotaryEmbedding for vision) # --------------------------------------------------------------------------- @@ -115,16 +248,15 @@ def _apply_multidimensional_rope( # --------------------------------------------------------------------------- -# Vision Attention (TP-sharded via QKVParallelLinear + RowParallelLinear) +# Vision Attention 
(TP-sharded, fused QKV) # --------------------------------------------------------------------------- class Gemma4VisionAttention(nn.Module): """Multi-head attention for the Gemma 4 vision encoder. - Uses SGLang's QKVParallelLinear and RowParallelLinear for tensor-parallel - sharding, Gemma4RMSNorm for per-head QK/V normalization, and the same - multimodal attention backends as VisionAttention. + QKV uses a fused ``ClippableQKVParallelLinear`` for efficient matmul with + per-projection clip bounds. Output projection uses ``ClippableLinear``. """ def __init__( @@ -134,30 +266,18 @@ def __init__( prefix: str = "", ): super().__init__() - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads self.head_dim = config.head_dim tp_size = get_attention_tp_size() - self.num_heads_per_partition = self.num_heads // tp_size - self.num_kv_heads_per_partition = self.num_kv_heads // tp_size + self.num_heads_per_partition = config.num_attention_heads // tp_size + self.num_kv_heads_per_partition = config.num_key_value_heads // tp_size - self.q_size = self.num_heads_per_partition * self.head_dim - self.kv_size = self.num_kv_heads_per_partition * self.head_dim - - self.qkv_proj = QKVParallelLinear( - hidden_size=self.hidden_size, - head_size=self.head_dim, - total_num_heads=self.num_heads, - total_num_kv_heads=self.num_kv_heads, - bias=config.attention_bias, - quant_config=quant_config, - prefix=add_prefix("qkv_proj", prefix), + self.qkv = ClippableQKVParallelLinear( + config, quant_config=quant_config, prefix=prefix, ) - self.o_proj = RowParallelLinear( - input_size=self.num_heads * self.head_dim, - output_size=self.hidden_size, + self.o_proj = ClippableLinear( + input_size=config.num_attention_heads * config.head_dim, + output_size=config.hidden_size, bias=config.attention_bias, quant_config=quant_config, prefix=add_prefix("o_proj", prefix), @@ -171,12 +291,13 @@ def __init__( backend = 
self._select_backend() self.qkv_backend = QKV_BACKEND_IMPL[backend]( - head_dim=self.head_dim, + head_dim=config.head_dim, num_heads=self.num_heads_per_partition, num_kv_heads=self.num_kv_heads_per_partition, dropout=0.0, flatten_batch=True, softmax_in_single_precision=False, + softmax_scale=1.0, ) @staticmethod @@ -205,33 +326,23 @@ def forward( sin: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """ - Args: - hidden_states: [batch, seq, hidden_size] - cos, sin: [batch, seq, head_dim] from Gemma4VisionRotaryEmbedding - attention_mask: [batch, seq] — True = valid, False = padding - """ bsz, seq_len, _ = hidden_states.shape - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k, v = self.qkv(hidden_states) q = q.reshape(bsz * seq_len, self.num_heads_per_partition, self.head_dim) k = k.reshape(bsz * seq_len, self.num_kv_heads_per_partition, self.head_dim) v = v.reshape(bsz * seq_len, self.num_kv_heads_per_partition, self.head_dim) - # Per-head QK norm q = self.q_norm(q.reshape(-1, self.head_dim)).reshape(q.shape) k = self.k_norm(k.reshape(-1, self.head_dim)).reshape(k.shape) v = self.v_norm(v.reshape(-1, self.head_dim)).reshape(v.shape) - # 2-D RoPE: cos/sin are [batch, seq, head_dim]; broadcast to [batch*seq, 1, head_dim] cos_flat = cos.reshape(bsz * seq_len, 1, self.head_dim) sin_flat = sin.reshape(bsz * seq_len, 1, self.head_dim) q = _apply_multidimensional_rope(q, cos_flat, sin_flat) k = _apply_multidimensional_rope(k, cos_flat, sin_flat) - # Build 4-D attention mask for backends that expect it if attention_mask is not None: attn_mask_4d = ( attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(1) @@ -244,10 +355,11 @@ def forward( cu_seqlens=None, bsz=bsz, seq_len=seq_len, attention_mask=attn_mask_4d, + softmax_scale=1.0, ) output = rearrange(output, "(b s) h d -> b s (h d)", b=bsz) - output, _ = self.o_proj(output) + output = self.o_proj(output) return output 
@@ -264,30 +376,21 @@ def __init__( prefix: str = "", ): super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_up_proj = MergedColumnParallelLinear( - input_size=self.hidden_size, - output_sizes=[self.intermediate_size, self.intermediate_size], - bias=False, - quant_config=quant_config, - prefix=add_prefix("gate_up_proj", prefix), + self.gate_up = ClippableGateUpParallelLinear( + config, quant_config=quant_config, prefix=prefix, ) - self.down_proj = RowParallelLinear( - input_size=self.intermediate_size, - output_size=self.hidden_size, + self.down_proj = ClippableLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, bias=False, quant_config=quant_config, prefix=add_prefix("down_proj", prefix), ) - from sglang.srt.layers.activation import SiluAndMul - - self.act_fn = SiluAndMul() # GeGLU: GELU variant handled by weight init def forward(self, x: torch.Tensor) -> torch.Tensor: - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) + gate, up = self.gate_up(x) + x = F.silu(gate) * up + x = self.down_proj(x) return x From 898b52de76b86e40e277e54474010520d342240b Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 10 Mar 2026 18:27:58 +0000 Subject: [PATCH 016/112] clean up --- python/sglang/srt/configs/model_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index e0e8483b820c..3c8816b03534 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -143,7 +143,6 @@ def __init__( if enable_multimodal is None: mm_disabled_models = [ "Gemma3ForConditionalGeneration", - # "Gemma4ForConditionalGeneration", "Llama4ForConditionalGeneration", "Step3VLForConditionalGeneration", ] From d6652f7fee62c9ac66405f8f0481aee1a8876831 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 10 Mar 2026 18:34:58 +0000 Subject: 
[PATCH 017/112] add more comments --- python/sglang/srt/models/gemma4_mm.py | 8 ++++++++ python/sglang/srt/models/gemma4_vision.py | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 68b48478e60e..fb21cc9b6c2d 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -414,11 +414,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = alt # Vision encoder fused projections (ClippableQKV / ClippableGateUp): + # + # QKV (self_attn.{q,k,v}_proj → self_attn.qkv): # weight/bias: *.q_proj.weight → *.qkv.q_proj.weight (stacked params then fuses) # output bound: *.q_proj.output_min → *.qkv.q_output_min # input bound: *.{q,k,v}_proj.input_min → *.qkv.input_min # (all are identical in the checkpoint -- same hidden_states input -- # so they collapse to a single shared buffer; last write wins) + # + # GateUp (mlp.{gate,up}_proj → mlp.gate_up): + # weight/bias: *.gate_proj.weight → *.gate_up.gate_proj.weight + # output bound: *.gate_proj.output_min → *.gate_up.gate_output_min + # input bound: *.{gate,up}_proj.input_min → *.gate_up.input_min + # (same collapse as QKV -- both see the same MLP input) if "vision_tower." in name: m = re.match( r"(.+\.self_attn)\.(q_proj|k_proj|v_proj)\.(.*)", name diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index 889bfa1add24..d4c9afde4924 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -83,6 +83,11 @@ class ClippableQKVParallelLinear(nn.Module): Owns a single ``QKVParallelLinear`` for the fused matmul. Clip bounds are stored as flat buffers: shared ``input_min/max`` (applied before the matmul) and per-projection ``q/k/v_output_min/max`` (applied after split). 
+ + The checkpoint stores separate ``input_min/max`` for each of q, k, v but + they are identical (all three projections receive the same hidden_states), + so we collapse them into a single shared buffer (last write wins during + weight loading). """ def __init__( @@ -130,6 +135,10 @@ class ClippableGateUpParallelLinear(nn.Module): Same pattern as ``ClippableQKVParallelLinear``: owns a single ``MergedColumnParallelLinear`` for the fused matmul, with shared input bounds and per-projection output bounds as flat buffers. + + The checkpoint stores separate ``input_min/max`` for gate and up but they + are identical (both projections receive the same MLP input), so we collapse + them into a single shared buffer (last write wins during weight loading). """ def __init__( From 8b4c06f271f9f8684f21f6bfc73330d3dc4cd3ad Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 10 Mar 2026 20:37:02 +0000 Subject: [PATCH 018/112] init audio support --- python/sglang/srt/layers/clippable_linear.py | 207 +++++ python/sglang/srt/models/gemma4_audio.py | 860 +++++++++++++++++++ python/sglang/srt/models/gemma4_mm.py | 179 ++-- python/sglang/srt/models/gemma4_vision.py | 169 +--- 4 files changed, 1180 insertions(+), 235 deletions(-) create mode 100644 python/sglang/srt/layers/clippable_linear.py create mode 100644 python/sglang/srt/models/gemma4_audio.py diff --git a/python/sglang/srt/layers/clippable_linear.py b/python/sglang/srt/layers/clippable_linear.py new file mode 100644 index 000000000000..b34895b9c910 --- /dev/null +++ b/python/sglang/srt/layers/clippable_linear.py @@ -0,0 +1,207 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TP-sharded linear wrappers with per-tensor activation clamping. + +Used by the Gemma 4 vision and audio encoders. Each wrapper owns a parallel +linear and four scalar clip buffers (``input_min/max``, ``output_min/max``) +that default to ±inf (no-op) and are populated from the checkpoint. + +For fused projections (QKV, GateUp), input bounds are shared (the checkpoint +stores identical copies per projection — last write wins during loading) and +output bounds are per-projection. +""" + +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.utils import add_prefix + + +_INF = float("inf") + + +class ClippableRowParallelLinear(nn.Module): + """``RowParallelLinear`` with input/output activation clamping. + + Checkpoint weight at ``.weight`` is remapped to ``.linear.weight`` + by the model's ``load_weights``. 
+ """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.linear = RowParallelLinear( + input_size=input_size, + output_size=output_size, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, self.input_min, self.input_max) + x, _ = self.linear(x) + x = torch.clamp(x, self.output_min, self.output_max) + return x + + +class ClippableColumnParallelLinear(nn.Module): + """``ColumnParallelLinear`` with input/output activation clamping.""" + + def __init__( + self, + input_size: int, + output_size: int, + *, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.linear = ColumnParallelLinear( + input_size=input_size, + output_size=output_size, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, self.input_min, self.input_max) + x, _ = self.linear(x) + x = torch.clamp(x, self.output_min, self.output_max) + return x + + +class ClippableQKVParallelLinear(nn.Module): + """Fused QKV projection with per-projection activation clamping. 
class ClippableGateUpParallelLinear(nn.Module):
    """Fused gate/up projection with per-projection activation clamping.

    Owns a single ``MergedColumnParallelLinear`` (``self.gate_up_proj``) for
    the fused matmul. Clip bounds are flat non-persistent buffers: a shared
    ``input_min/max`` applied before the matmul and per-projection
    ``gate_output_min/max`` / ``up_output_min/max`` applied after the split.

    Args:
        input_size: Input feature width of the fused projection.
        intermediate_size: Unsharded width of each of the gate and up outputs.
        bias: Forwarded to ``MergedColumnParallelLinear``.
        quant_config: Forwarded to ``MergedColumnParallelLinear``.
        prefix: Parent prefix; ``"gate_up_proj"`` is appended for the layer.
    """

    def __init__(
        self,
        input_size: int,
        intermediate_size: int,
        *,
        bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_size = get_attention_tp_size()
        # Kept as a public attribute for existing readers. forward() no longer
        # relies on it, because the fused layer below is not given this
        # tp_size and may shard with a different degree.
        self.proj_size = intermediate_size // tp_size

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=input_size,
            output_sizes=[intermediate_size, intermediate_size],
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False)
        self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False)
        self.gate_output_min = nn.parameter.Buffer(
            torch.tensor(-_INF), persistent=False
        )
        self.gate_output_max = nn.parameter.Buffer(
            torch.tensor(_INF), persistent=False
        )
        self.up_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False)
        self.up_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the clamped ``(gate, up)`` halves of the fused projection."""
        x = torch.clamp(x, self.input_min, self.input_max)
        gate_up, _ = self.gate_up_proj(x)
        # Both projections are requested with the same width, so the fused
        # output is two equal halves. Splitting by the tensor's real width
        # stays correct whatever sharding degree the fused layer used.
        half_width = gate_up.shape[-1] // 2
        gate, up = gate_up.split([half_width, half_width], dim=-1)
        gate = torch.clamp(gate, self.gate_output_min, self.gate_output_max)
        up = torch.clamp(up, self.up_output_min, self.up_output_max)
        return gate, up
class Gemma4AudioRelativePositionEmbedding(nn.Module):
    """Attention logits with a relative-position term for blocked attention.

    ``forward`` adds a content term (queries times transposed keys) to a
    position term. The position term multiplies the queries with sinusoidal
    embeddings of relative offsets (projected by ``pos_proj``) and is then
    re-aligned to key positions by ``_relative_shift``.
    """

    def __init__(
        self,
        config: Gemma4AudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.num_heads = config.conf_num_attention_heads
        self.channels = config.hidden_size
        self.head_dim = self.channels // self.num_heads
        # Relative offsets run from +max_backward down to -max_forward.
        self.max_backward = max(0, config.conf_attention_context_left - 1)
        self.max_forward = config.conf_attention_context_right

        self.pos_proj = ColumnParallelLinear(
            self.channels,
            self.num_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("pos_proj", prefix),
        )

        # Inverse timescales form a geometric progression over channels // 2
        # entries, starting at 1.0.
        lowest_timescale = 1.0
        highest_timescale = 1.0e4
        half_channels = self.channels // 2
        log_step = math.log(
            float(highest_timescale) / float(lowest_timescale)
        ) / max(half_channels - 1, 1)
        inverse_timescales = lowest_timescale * torch.exp(
            torch.arange(half_channels) * -log_step
        )
        self.register_buffer(
            "inv_timescales",
            inverse_timescales.float().unsqueeze(0).unsqueeze(0),
            persistent=False,
        )

    def _get_timing_signal_1d_pos(
        self, position: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return ``[sin, cos]`` timing features for a 2-D ``position`` tensor."""
        assert position.ndim == 2
        expanded = position.float().unsqueeze(-1)
        timescales = self.inv_timescales.to(
            device=expanded.device, dtype=torch.float32
        )
        phase = expanded * timescales
        features = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        return features.to(dtype)

    def _relative_shift(
        self,
        term_bd_before_shift: torch.Tensor,
        batch_size: int,
        num_heads: int,
        num_query_blocks: int,
        query_block_size: int,
        key_context_size: int,
        max_span_plus_1: int,
    ) -> torch.Tensor:
        """Re-index the position term from relative offset to key position.

        Pads the last axis to ``key_context_size + 1``, flattens the
        (query, offset) axes, drops the trailing overflow and reshapes to
        ``(batch, heads, blocks, query_block_size, key_context_size)``.
        """
        row_width = key_context_size + 1
        padded = F.pad(term_bd_before_shift, (0, row_width - max_span_plus_1))
        flattened = padded.reshape(
            (batch_size, num_heads, num_query_blocks, query_block_size * row_width)
        )
        kept = flattened[:, :, :, : query_block_size * key_context_size]
        return kept.reshape(
            (
                batch_size,
                num_heads,
                num_query_blocks,
                query_block_size,
                key_context_size,
            )
        )

    def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
        """Return content-plus-position logits for 5-D ``queries`` and ``keys``."""
        batch_size, num_query_blocks, query_block_size, num_heads, head_dim = (
            queries.shape
        )
        key_context_size = keys.shape[2]

        # One row of offsets: max_backward, ..., 0, ..., -max_forward.
        offsets = torch.arange(
            self.max_backward, -self.max_forward - 1, -1, device=queries.device
        ).unsqueeze(0)
        max_span_plus_1 = offsets.shape[1]

        timing = self._get_timing_signal_1d_pos(offsets, dtype=queries.dtype)
        # pos_proj is a ColumnParallelLinear (no implicit dtype promotion);
        # project in weight dtype, then cast back to queries' dtype.
        projected, _ = self.pos_proj(timing.to(self.pos_proj.weight.dtype))
        # NOTE(review): the reshape uses the full num_heads, so this looks
        # like it expects pos_proj's output not to be sharded; confirm.
        sin_emb = projected.to(queries.dtype).reshape(
            max_span_plus_1, self.num_heads, self.head_dim
        )

        # Shared by the content term and the position term.
        queries_by_head = queries.permute(0, 3, 1, 2, 4)

        term_ac = torch.matmul(queries_by_head, keys.permute(0, 3, 1, 4, 2))

        flat_queries = queries_by_head.reshape(
            batch_size, num_heads, num_query_blocks * query_block_size, head_dim
        )
        term_bd = torch.matmul(flat_queries, sin_emb.permute(1, 2, 0)).reshape(
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            max_span_plus_1,
        )

        return term_ac + self._relative_shift(
            term_bd,
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            key_context_size,
            max_span_plus_1,
        )
class Gemma4AudioAttention(nn.Module):
    """Chunked local dot-product attention with per-dimension Q and K scaling.

    Queries are cut into blocks of ``chunk_size`` frames. Each block attends
    to a key/value window of ``context_size`` frames (chunk plus past and
    future horizons). Logits come from the relative position embedding, are
    soft-capped with ``tanh``, masked, and normalised with softmax.
    """

    def __init__(
        self,
        config: Gemma4AudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.num_heads = config.conf_num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_dim = self.hidden_size // self.num_heads

        self.chunk_size = config.conf_attention_chunk_size
        self.max_future_horizon = config.conf_attention_context_right
        self.max_past_horizon = max(0, config.conf_attention_context_left - 1)
        self.attention_logits_soft_cap = config.conf_attention_logit_cap
        self.context_size = (
            self.chunk_size + self.max_past_horizon + self.max_future_horizon
        )

        self.relative_position_embedding = Gemma4AudioRelativePositionEmbedding(
            config,
            quant_config,
            prefix=add_prefix("relative_position_embedding", prefix),
        )
        self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
        self.per_dim_key_scale = nn.Parameter(torch.ones((self.head_dim,)))

        self.qkv = ClippableQKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            total_num_kv_heads=self.num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=prefix,
        )

        # softplus(0) = log(2); pre-fold its reciprocal into the scale factors.
        inv_softplus_zero = 1.0 / math.log(2)
        self.q_scale = (self.head_dim ** -0.5) * inv_softplus_zero
        self.k_scale = inv_softplus_zero

        self.register_buffer(
            "softcap",
            torch.tensor(self.attention_logits_soft_cap).float(),
            persistent=False,
        )

    # ------ block / context helpers ------------------------------------------

    def _pad_dim1(
        self, x: torch.Tensor, dim10_val: int, dim11_val: int
    ) -> torch.Tensor:
        """Zero-pad dimension 1 of ``x`` by ``dim10_val`` / ``dim11_val``."""
        # F.pad lists (left, right) pairs starting from the last dimension.
        pads = [0] * (2 * x.ndim)
        offset = 2 * (x.ndim - 2)
        pads[offset] = dim10_val
        pads[offset + 1] = dim11_val
        return F.pad(x, tuple(pads))

    def _convert_to_block(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(b, t, ...)`` to ``(b, blocks, chunk_size, ...)``."""
        trailing_dims = x.shape[2:]
        b, t = x.shape[:2]
        num_blocks = (t + self.chunk_size - 1) // self.chunk_size
        shortfall = num_blocks * self.chunk_size - t
        if shortfall > 0:
            # Right-pad the time axis up to a whole number of blocks.
            x = self._pad_dim1(x, 0, shortfall)
        target_shape = (b, num_blocks, self.chunk_size) + trailing_dims
        return x.reshape(target_shape).contiguous()

    def _extract_block_context(self, x: torch.Tensor) -> torch.Tensor:
        """Return overlapping ``context_size`` windows along dimension 1."""
        padded = self._pad_dim1(
            x,
            self.max_past_horizon,
            self.max_future_horizon + self.chunk_size - 1,
        )
        windows = padded.unfold(
            dimension=1, size=self.context_size, step=self.chunk_size
        )
        if padded.ndim > 2 and windows.ndim > 3:
            # unfold puts the window axis last; move it next to the block axis.
            windows = torch.movedim(windows, source=-1, destination=2)
        return windows.contiguous()

    # ------ forward ----------------------------------------------------------

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.BoolTensor,
        causal_valid_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """Run blocked local attention.

        Args:
            x: Input activations; the last axis is projected to Q/K/V.
            mask: Padding mask, inverted here to obtain the valid positions.
            causal_valid_mask: Allowed (query, context) pairs within a block.

        Returns:
            Context vectors with trailing axes ``(num_heads, head_dim)``,
            trimmed to the original number of time steps.
        """
        q, k, v = self.qkv(x)
        # NOTE(review): reshapes with the full num_heads; looks like it
        # expects the QKV output not to be sharded across heads; confirm.
        heads_shape = (*x.shape[:-1], self.num_heads, self.head_dim)
        query_states = q.float().reshape(heads_shape).contiguous()
        key_states = k.float().reshape(heads_shape).contiguous()
        value_states = v.float().reshape(heads_shape).contiguous()

        per_dim_view = (1, 1, 1, self.head_dim)
        query_gain = F.softplus(self.per_dim_scale).view(per_dim_view)
        query_states = query_states * self.q_scale * query_gain
        key_gain = F.softplus(self.per_dim_key_scale).view(per_dim_view)
        key_states = key_states * self.k_scale * key_gain

        batch_size, q_time = query_states.shape[:2]

        query_blocks = self._convert_to_block(query_states)
        key_blocks = self._extract_block_context(key_states)
        value_blocks = self._extract_block_context(value_states)
        num_query_blocks = query_blocks.shape[1]

        valid_windows = self._extract_block_context(~mask)
        needs_flatten = (
            valid_windows.ndim == 4
            and valid_windows.shape[0] == batch_size
            and valid_windows.shape[1] == num_query_blocks
            and valid_windows.shape[2] * valid_windows.shape[3] == self.context_size
        )
        if needs_flatten:
            valid_windows = valid_windows.reshape(
                batch_size, num_query_blocks, self.context_size
            )

        # Broadcast both masks against the logits and combine them.
        validity_condition = valid_windows.unsqueeze(1).unsqueeze(-2)
        causality_condition = causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        keep = torch.logical_and(
            validity_condition,
            causality_condition.to(validity_condition.device),
        )

        logits = self.relative_position_embedding(query_blocks, key_blocks)

        # Soft cap: cap * tanh(logits / cap).
        cap = self.softcap.to(logits.device)
        logits = torch.tanh(logits / cap) * cap

        logits = torch.where(
            keep,
            logits,
            self.config.conf_attention_invalid_logits_value,
        )

        probabilities = F.softmax(logits, dim=-1, dtype=torch.float32).to(
            dtype=value_blocks.dtype
        )

        b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
        h_dim = value_blocks.shape[-1]
        # Fold (batch, block, head) into one axis for a single batched matmul.
        flat_probs = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
        flat_values = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
        attended = torch.bmm(flat_probs, flat_values)
        context_vectors = attended.reshape(
            b_dim, u_dim, n_dim, w_dim, h_dim
        ).permute(0, 1, 3, 2, 4)
        context_vectors = context_vectors.reshape(
            batch_size,
            num_query_blocks * self.chunk_size,
            self.num_heads,
            self.head_dim,
        )
        # Drop the frames added when padding up to whole blocks.
        return context_vectors[:, :q_time]
class Gemma4AudioSSCPConvBlock(nn.Module):
    """One 2-D convolution block: mask, pad, conv, LayerNorm, ReLU.

    Padding is applied manually before the convolution (the conv itself uses
    zero padding). The ``manual_padding`` constructor argument is accepted
    but not read; ``self.manual_padding`` is always derived from ``config``.
    """

    def __init__(
        self,
        config: Gemma4AudioConfig,
        idx: int,
        input_freq_dim: int,
        manual_padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
    ):
        super().__init__()
        self.config = config

        # The first block consumes a single-channel input.
        if idx == 0:
            in_channels = 1
        else:
            in_channels = config.sscp_conv_channel_size[idx - 1]
        out_channels = config.sscp_conv_channel_size[idx]
        kernel_t, kernel_f = config.sscp_conv_kernel_size[idx]
        stride_t, stride_f = config.sscp_conv_stride_size[idx]
        self.time_stride = stride_t

        # Time padding: explicit config values win, then "semicausal",
        # otherwise bottom-only padding.
        if (
            config.sscp_conv_time_pad_top is not None
            and config.sscp_conv_time_pad_bottom is not None
        ):
            pad_t_top = config.sscp_conv_time_pad_top
            pad_t_bottom = config.sscp_conv_time_pad_bottom
        elif config.sscp_conv_padding_type == "semicausal":
            pad_t_top = kernel_t // 2
            pad_t_bottom = 0 if config.streaming else kernel_t // 2
        else:
            pad_t_top = 0
            pad_t_bottom = 0 if config.streaming else kernel_t - 1

        pad_f_left, pad_f_right = 1, 1

        # F.pad order: last axis (frequency) first, then time.
        self.manual_padding = (pad_f_left, pad_f_right, pad_t_top, pad_t_bottom)

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(kernel_t, kernel_f),
            stride=(stride_t, stride_f),
            padding=(0, 0),
            bias=False,
        )

        # Frequency size after the padded convolution.
        padded_freq = input_freq_dim + pad_f_left + pad_f_right
        self.f_out_conv = (padded_freq - kernel_f) // stride_f + 1

        self.norm = nn.LayerNorm(
            [out_channels],
            eps=config.sscp_conv_group_norm_eps,
            elementwise_affine=True,
            bias=False,
        )
        self.activation = nn.ReLU()

    def forward(
        self, audio_encodings: torch.Tensor, audio_mel_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the activated features and the time-subsampled mask."""
        # Zero the masked frames so they do not leak into the convolution.
        fill_mask = audio_mel_mask.unsqueeze(1).unsqueeze(-1)
        zeroed = audio_encodings.masked_fill(fill_mask, 0.0)

        padded = F.pad(zeroed, self.manual_padding, mode="constant", value=0.0)
        convolved = self.conv(padded.to(self.conv.weight.dtype))

        # Subsample the mask by the time stride, then trim to the conv length.
        strided_mask = audio_mel_mask[:, :: self.time_stride]
        output_mask = strided_mask[:, : convolved.shape[2]]

        # LayerNorm acts on the last axis, so move channels last and back.
        channels_last = convolved.permute(0, 2, 3, 1)
        normed = self.norm(channels_last).permute(0, 3, 1, 2).contiguous()
        return self.activation(normed), output_mask
class Gemma4AudioConformerAttention(nn.Module):
    """Conformer attention sub-layer with a residual connection.

    Pipeline: clamp, RMS norm, local attention, output projection, clamp,
    RMS norm. The result is added to the unclamped input.
    """

    def __init__(
        self,
        config: Gemma4AudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.post_in_features = config.hidden_size

        # Symmetric clamp bound, kept as a buffer that is not saved in state.
        clip_bound = torch.tensor(config.gradient_clipping)
        self.register_buffer("gradient_clipping", clip_bound, persistent=False)

        self.pre_attn_norm = Gemma4RMSNorm(config.hidden_size, scale_shift=0.0)
        self.attn = Gemma4AudioAttention(
            config, quant_config, prefix=add_prefix("attn", prefix)
        )
        self.post = ClippableRowParallelLinear(
            self.post_in_features,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("post", prefix),
        )
        self.post_norm = Gemma4RMSNorm(config.hidden_size, scale_shift=0.0)

    def forward(
        self,
        audio_encodings: torch.Tensor,
        audio_mel_mask: torch.BoolTensor,
        causal_valid_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """Return ``audio_encodings`` plus the normalised attention output."""
        residual = audio_encodings
        bound = self.gradient_clipping

        clipped = torch.clamp(audio_encodings, -bound, bound)
        attn_out = self.attn(
            self.pre_attn_norm(clipped), audio_mel_mask, causal_valid_mask
        )

        # Merge the (heads, head_dim) axes and return to the input dtype.
        b, t, num_heads, head_dim = attn_out.shape
        merged = attn_out.reshape(b, t, num_heads * head_dim).to(
            dtype=residual.dtype
        )

        projected = torch.clamp(self.post(merged), -bound, bound)
        return residual + self.post_norm(projected)
torch.tensor(config.gradient_clipping), + persistent=False, + ) + + self.pre_layer_norm = Gemma4RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, scale_shift=0.0 + ) + self.linear_start = ClippableColumnParallelLinear( + config.hidden_size, + config.hidden_size * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("linear_start", prefix), + ) + self.depthwise_conv1d = nn.Conv1d( + in_channels=config.hidden_size, + out_channels=config.hidden_size, + kernel_size=config.conf_conv_kernel_size, + stride=1, + padding=0, + groups=config.hidden_size, + bias=False, + ) + self.conv_norm = Gemma4RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, scale_shift=0.0 + ) + self.linear_end = ClippableRowParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("linear_end", prefix), + ) + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + audio_encodings_residual = audio_encodings + + audio_encodings = self.pre_layer_norm(audio_encodings) + audio_encodings = self.linear_start(audio_encodings) + audio_encodings = F.glu(audio_encodings, dim=-1) + + audio_encodings_permuted = audio_encodings.permute(0, 2, 1) + audio_encodings_permuted_padded = F.pad( + audio_encodings_permuted, (self.causal_padding, 0) + ) + audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded) + audio_encodings = audio_encodings.permute(0, 2, 1) + audio_encodings = torch.clamp( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings = self.conv_norm(audio_encodings) + audio_encodings = F.silu(audio_encodings) + audio_encodings = self.linear_end(audio_encodings) + return audio_encodings + audio_encodings_residual + + +class Gemma4AudioConformerBlock(nn.Module): + def __init__( + self, + config: Gemma4AudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + self.ffw_layer_start = 
class Gemma4AudioEncoder(nn.Module):
    """SGLang-native TP-sharded Gemma 4 audio encoder (USM Conformer + SSCP).

    Pipeline: sub-sample convolution projection, a stack of conformer
    blocks, optional temporal reduction, optional output projection, and a
    final zero-fill of padded frames.
    """

    def __init__(
        self,
        config: Gemma4AudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.subsample_conv_projection = Gemma4AudioSubSampleConvProjection(
            config,
            quant_config,
            prefix=add_prefix("subsample_conv_projection", prefix),
        )
        self.conformer = make_layers(
            config.conf_num_hidden_layers,
            lambda idx, prefix: Gemma4AudioConformerBlock(
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("conformer", prefix),
        )

        if config.output_proj_dims is not None:
            self.output_proj = RowParallelLinear(
                config.hidden_size,
                config.output_proj_dims,
                bias=True,
                quant_config=quant_config,
                prefix=add_prefix("output_proj", prefix),
            )
        else:
            self.output_proj = None

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    def _build_causal_valid_mask(self, device: torch.device) -> torch.Tensor:
        """Build the ``(chunk_size, context_size)`` boolean attention mask.

        Entry ``[w, c]`` is True when ``w <= c <= w + past + future``, where
        ``past``/``future`` are the attention horizons from the config. The
        mask is created directly on ``device`` so the attention layers do not
        each have to copy it over from the CPU.
        """
        chunk_size = self.config.conf_attention_chunk_size
        max_future_horizon = self.config.conf_attention_context_right
        max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
        upper_diagonal = max_past_horizon + max_future_horizon
        context_size = chunk_size + max_past_horizon + max_future_horizon

        lower_causal_mask = torch.tril(
            torch.ones((context_size, chunk_size), dtype=torch.bool, device=device),
            diagonal=0,
        ).T
        upper_causal_mask = torch.tril(
            torch.ones((chunk_size, context_size), dtype=torch.bool, device=device),
            diagonal=upper_diagonal,
        )
        # Boolean AND of the two triangular bands.
        return lower_causal_mask & upper_causal_mask

    def forward(
        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
    ) -> Tuple[torch.Tensor, torch.BoolTensor]:
        """Encode a batch of mel spectrograms.

        Args:
            audio_mel: [batch, num_frames, mel_bins]
            audio_mel_mask: [batch, num_frames], True = padding

        Returns:
            audio_encodings: [batch, reduced_frames, hidden_size/output_proj_dims]
            audio_mel_mask: [batch, reduced_frames], True = padding
        """
        audio_encodings, current_mask = self.subsample_conv_projection(
            audio_mel, audio_mel_mask
        )

        # Built once per call, on the activations' device, and shared by
        # every conformer block.
        causal_valid_mask = self._build_causal_valid_mask(audio_encodings.device)

        for block in self.conformer:
            audio_encodings = block(audio_encodings, current_mask, causal_valid_mask)

        reduction = self.config.conf_reduction_factor
        if reduction > 1:
            # Keep every `reduction`-th frame of both features and mask.
            audio_encodings = audio_encodings[:, ::reduction]
            current_mask = current_mask[:, ::reduction]

        if self.output_proj is not None:
            audio_encodings, _ = self.output_proj(audio_encodings)

        # Make the mask length match the encodings: pad with True (padding)
        # or truncate.
        target_len = audio_encodings.shape[1]
        mask_len = current_mask.shape[1]
        if mask_len != target_len:
            if target_len > mask_len:
                current_mask = F.pad(
                    current_mask, (0, target_len - mask_len), value=True
                )
            else:
                current_mask = current_mask[:, :target_len]

        audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
        return audio_encodings, current_mask
- ) + raise ValueError("Audio inputs provided but the model does not have an audio tower.") all_input_features = flatten_nested_list([item.feature for item in items]) - all_input_features_mask = flatten_nested_list( - [~item.input_features_mask for item in items] - ) + all_input_features_mask = flatten_nested_list([~item.input_features_mask for item in items]) all_embeds = [] for input_features, input_features_mask in zip( @@ -289,22 +288,18 @@ def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: input_features_mask = input_features_mask.unsqueeze(0) input_features = input_features.to( - device=next(self.audio_tower.parameters()).device, + device=self.audio_tower.device, dtype=self.language_model.dtype(), ) input_features_mask = input_features_mask.to(device=input_features.device) - # Run audio tower (mask True=padding) - audio_outputs = self.audio_tower(input_features, input_features_mask) - if isinstance(audio_outputs, tuple): - audio_encodings, audio_mask = audio_outputs - else: - audio_encodings = audio_outputs.last_hidden_state - audio_mask = audio_outputs.audio_mel_mask + # audio_mel_mask convention: True = padding + audio_encodings, audio_mask = self.audio_tower( + input_features, input_features_mask + ) audio_features = self.embed_audio(inputs_embeds=audio_encodings) - # Strip padding for enc, mask in zip(audio_features, audio_mask): all_embeds.append(enc[~mask]) @@ -347,9 +342,6 @@ def forward( "You must specify exactly one of input_ids or inputs_embeds" ) - # DEBUG: check mm_inputs in forward - has_mm = forward_batch.contains_mm_inputs() if hasattr(forward_batch, 'contains_mm_inputs') else False - is_decode = forward_batch.forward_mode.is_decode() positions += 1 per_layer_inputs = None if input_ids is not None: @@ -380,22 +372,86 @@ def forward( def tie_weights(self, recompute_mapping=False): return self.language_model.tie_weights() + # Standard stacked-params mapping for fused QKV / GateUp linears + # in the text decoder. 
Also consumed by the tower QKV remap (step 2). + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".up_proj", 1), + (".gate_up_proj", ".gate_proj", 0), + ] + + # Regex for fused QKV in vision/audio towers. + # Vision: *.self_attn.{q,k,v}_proj.* Audio: *.attn.{q,k,v}_proj.* + _RE_TOWER_QKV = re.compile( + r"(.+\.(?:self_attn|attn))\.(q_proj|k_proj|v_proj)\.(.*)" + ) + # Regex for fused GateUp in the vision tower MLP. + _RE_TOWER_GATE_UP = re.compile( + r"(.+\.mlp)\.(gate_proj|up_proj)\.(.*)" + ) + + @staticmethod + def _remap_tower_name(name: str, params_dict: dict) -> str: + """Remap a vision/audio tower checkpoint name to our module tree. + + Three transformations, applied in order: + + 1. **Fused QKV** — ``{q,k,v}_proj.*`` → ``qkv.*`` + Weight/bias are redirected into the fused ``qkv.{proj}.{attr}`` + namespace (stacked-params then merges them into ``qkv_proj``). + Clip buffers are split: ``input_*`` → shared ``qkv.input_*``, + ``output_*`` → per-projection ``qkv.{q,k,v}_output_*``. + + 2. **Fused GateUp** — ``{gate,up}_proj.*`` → ``gate_up.*`` + Same pattern as QKV. + + 3. **Clippable wrapper** — ``*.weight``/``*.bias`` → ``*.linear.weight`` + Catches the remaining (non-fused) clippable linears whose inner + ``RowParallelLinear``/``ColumnParallelLinear`` lives at ``.linear``. + Falls back to the original name when ``.linear.`` does not exist + in ``params_dict`` (plain linears, norms, conv weights, etc.). 
+ """ + # Step 1: fused QKV + m = Gemma4ForConditionalGeneration._RE_TOWER_QKV.match(name) + if m: + pfx, proj, attr = m.groups() + if attr in ("weight", "bias"): + return f"{pfx}.qkv.{proj}.{attr}" + if attr.startswith("output_"): + return f"{pfx}.qkv.{proj[0]}_{attr}" + if attr.startswith("input_"): + return f"{pfx}.qkv.{attr}" + + # Step 2: fused GateUp + m = Gemma4ForConditionalGeneration._RE_TOWER_GATE_UP.match(name) + if m: + pfx, proj, attr = m.groups() + short = proj.split("_")[0] # "gate" or "up" + if attr in ("weight", "bias"): + return f"{pfx}.gate_up.{proj}.{attr}" + if attr.startswith("output_"): + return f"{pfx}.gate_up.{short}_{attr}" + if attr.startswith("input_"): + return f"{pfx}.gate_up.{attr}" + + # Step 3: clippable wrapper (.weight → .linear.weight) + if name.endswith(".weight") or name.endswith(".bias"): + base, attr = name.rsplit(".", 1) + alt = f"{base}.linear.{attr}" + if alt in params_dict: + return alt + + return name + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".up_proj", 1), - (".gate_up_proj", ".gate_proj", 0), - ] - """Load weights for the model.""" params_dict = dict(self.named_parameters()) params_dict.update(dict(self.named_buffers())) loaded_params: Set[str] = set() for name, loaded_weight in weights: - # Vestigial weights to ignore if "embed_vision.embedding." in name or "embed_audio.embedding." in name: continue if self.audio_tower is None and ( @@ -405,62 +461,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = re.sub(r"^model\.", "", name) - # Vision encoder non-fused linears are wrapped in ClippableLinear: - # checkpoint "proj.weight" → our "proj.linear.weight" - if "vision_tower." 
in name and (name.endswith(".weight") or name.endswith(".bias")): - base, attr = name.rsplit(".", 1) - alt = f"{base}.linear.{attr}" - if alt in params_dict: - name = alt - - # Vision encoder fused projections (ClippableQKV / ClippableGateUp): - # - # QKV (self_attn.{q,k,v}_proj → self_attn.qkv): - # weight/bias: *.q_proj.weight → *.qkv.q_proj.weight (stacked params then fuses) - # output bound: *.q_proj.output_min → *.qkv.q_output_min - # input bound: *.{q,k,v}_proj.input_min → *.qkv.input_min - # (all are identical in the checkpoint -- same hidden_states input -- - # so they collapse to a single shared buffer; last write wins) - # - # GateUp (mlp.{gate,up}_proj → mlp.gate_up): - # weight/bias: *.gate_proj.weight → *.gate_up.gate_proj.weight - # output bound: *.gate_proj.output_min → *.gate_up.gate_output_min - # input bound: *.{gate,up}_proj.input_min → *.gate_up.input_min - # (same collapse as QKV -- both see the same MLP input) - if "vision_tower." in name: - m = re.match( - r"(.+\.self_attn)\.(q_proj|k_proj|v_proj)\.(.*)", name - ) - if m: - pfx, proj, attr = m.groups() - if attr in ("weight", "bias"): - name = f"{pfx}.qkv.{proj}.{attr}" - elif attr.startswith("output_"): - name = f"{pfx}.qkv.{proj[0]}_{attr}" - elif attr.startswith("input_"): - name = f"{pfx}.qkv.{attr}" - - m = re.match( - r"(.+\.mlp)\.(gate_proj|up_proj)\.(.*)", name - ) - if m: - pfx, proj, attr = m.groups() - short = proj.split("_")[0] # "gate" or "up" - if attr in ("weight", "bias"): - name = f"{pfx}.gate_up.{proj}.{attr}" - elif attr.startswith("output_"): - name = f"{pfx}.gate_up.{short}_{attr}" - elif attr.startswith("input_"): - name = f"{pfx}.gate_up.{attr}" + # Remap vision / audio tower names (fused QKV/GateUp, clippable wrappers) + if "vision_tower." in name or "audio_tower." 
in name: + name = self._remap_tower_name(name, params_dict) + # Try stacked (fused) params first orig_name = name - for param_name, weight_name, shard_id in stacked_params_mapping: + for param_name, weight_name, shard_id in self.stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name.endswith(".bias") and name not in params_dict: - name = orig_name - continue if name not in params_dict: name = orig_name continue @@ -469,10 +479,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: - # Skip loading extra bias for GPTQ models if name.endswith(".bias") and name not in params_dict: continue - # Remapping the name of FP8 kv-scale name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue @@ -485,7 +493,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): unloaded_params = params_dict.keys() - loaded_params if unloaded_params: logger.warning( - "Some weights are not initialized from checkpoints: %s", unloaded_params + "Some weights are not initialized from checkpoints: %s", + unloaded_params, ) return loaded_params diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index d4c9afde4924..daea6875d22e 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -22,158 +22,17 @@ from transformers import Gemma4VisionConfig from sglang.srt.layers.attention.vision import QKV_BACKEND_IMPL, VisionAttention +from sglang.srt.layers.clippable_linear import ( + ClippableGateUpParallelLinear, + ClippableQKVParallelLinear, + ClippableRowParallelLinear, +) from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.layernorm import Gemma4RMSNorm -from sglang.srt.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) from 
sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import add_prefix, get_device_capability, is_cuda -# --------------------------------------------------------------------------- -# Activation clamping (matches HF Gemma4ClippableLinear) -# --------------------------------------------------------------------------- - -_INF = float("inf") - - -class ClippableLinear(nn.Module): - """``RowParallelLinear`` with input/output activation clamping. - - Mirrors HF's ``Gemma4ClippableLinear``: owns the linear layer and applies - ``torch.clamp`` before and after the linear forward pass. Clip bounds - default to ±inf (no-op) and are populated from the checkpoint. - """ - - def __init__( - self, - input_size: int, - output_size: int, - *, - bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.linear = RowParallelLinear( - input_size=input_size, - output_size=output_size, - bias=bias, - quant_config=quant_config, - prefix=add_prefix("linear", prefix), - ) - self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) - self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) - self.output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) - self.output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) - - def forward(self, x: torch.Tensor): - x = torch.clamp(x, self.input_min, self.input_max) - x, _ = self.linear(x) - x = torch.clamp(x, self.output_min, self.output_max) - return x - - -class ClippableQKVParallelLinear(nn.Module): - """Fused QKV projection with per-projection activation clamping. - - Owns a single ``QKVParallelLinear`` for the fused matmul. Clip bounds - are stored as flat buffers: shared ``input_min/max`` (applied before the - matmul) and per-projection ``q/k/v_output_min/max`` (applied after split). 
- - The checkpoint stores separate ``input_min/max`` for each of q, k, v but - they are identical (all three projections receive the same hidden_states), - so we collapse them into a single shared buffer (last write wins during - weight loading). - """ - - def __init__( - self, - config: Gemma4VisionConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - tp_size = get_attention_tp_size() - self.q_size = (config.num_attention_heads // tp_size) * config.head_dim - self.kv_size = (config.num_key_value_heads // tp_size) * config.head_dim - - self.qkv_proj = QKVParallelLinear( - hidden_size=config.hidden_size, - head_size=config.head_dim, - total_num_heads=config.num_attention_heads, - total_num_kv_heads=config.num_key_value_heads, - bias=config.attention_bias, - quant_config=quant_config, - prefix=add_prefix("qkv_proj", prefix), - ) - self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) - self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) - self.q_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) - self.q_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) - self.k_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) - self.k_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) - self.v_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) - self.v_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) - - def forward(self, hidden_states: torch.Tensor): - x = torch.clamp(hidden_states, self.input_min, self.input_max) - qkv, _ = self.qkv_proj(x) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = torch.clamp(q, self.q_output_min, self.q_output_max) - k = torch.clamp(k, self.k_output_min, self.k_output_max) - v = torch.clamp(v, self.v_output_min, self.v_output_max) - return q, k, v - - -class ClippableGateUpParallelLinear(nn.Module): - 
"""Fused gate/up projection with per-projection activation clamping. - - Same pattern as ``ClippableQKVParallelLinear``: owns a single - ``MergedColumnParallelLinear`` for the fused matmul, with shared input - bounds and per-projection output bounds as flat buffers. - - The checkpoint stores separate ``input_min/max`` for gate and up but they - are identical (both projections receive the same MLP input), so we collapse - them into a single shared buffer (last write wins during weight loading). - """ - - def __init__( - self, - config: Gemma4VisionConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - tp_size = get_attention_tp_size() - self.proj_size = config.intermediate_size // tp_size - - self.gate_up_proj = MergedColumnParallelLinear( - input_size=config.hidden_size, - output_sizes=[config.intermediate_size, config.intermediate_size], - bias=False, - quant_config=quant_config, - prefix=add_prefix("gate_up_proj", prefix), - ) - self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) - self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) - self.gate_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) - self.gate_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) - self.up_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) - self.up_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) - - def forward(self, x: torch.Tensor): - x = torch.clamp(x, self.input_min, self.input_max) - gate_up, _ = self.gate_up_proj(x) - gate, up = gate_up.split([self.proj_size, self.proj_size], dim=-1) - gate = torch.clamp(gate, self.gate_output_min, self.gate_output_max) - up = torch.clamp(up, self.up_output_min, self.up_output_max) - return gate, up - - # --------------------------------------------------------------------------- # 2-D Multidimensional RoPE (matches HF Gemma4RotaryEmbedding for vision) # 
--------------------------------------------------------------------------- @@ -282,9 +141,15 @@ def __init__( self.num_kv_heads_per_partition = config.num_key_value_heads // tp_size self.qkv = ClippableQKVParallelLinear( - config, quant_config=quant_config, prefix=prefix, + hidden_size=config.hidden_size, + head_size=config.head_dim, + total_num_heads=config.num_attention_heads, + total_num_kv_heads=config.num_key_value_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=prefix, ) - self.o_proj = ClippableLinear( + self.o_proj = ClippableRowParallelLinear( input_size=config.num_attention_heads * config.head_dim, output_size=config.hidden_size, bias=config.attention_bias, @@ -386,9 +251,13 @@ def __init__( ): super().__init__() self.gate_up = ClippableGateUpParallelLinear( - config, quant_config=quant_config, prefix=prefix, + input_size=config.hidden_size, + intermediate_size=config.intermediate_size, + bias=False, + quant_config=quant_config, + prefix=prefix, ) - self.down_proj = ClippableLinear( + self.down_proj = ClippableRowParallelLinear( input_size=config.intermediate_size, output_size=config.hidden_size, bias=False, From f85490eb9c0b638bb2bf7031dfbd7e162545f5d2 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 10 Mar 2026 22:37:53 +0000 Subject: [PATCH 019/112] TP fix for audio encoder, change act_fn for vision_encoder, and update test --- python/sglang/srt/models/gemma4_audio.py | 43 +++-- python/sglang/srt/models/gemma4_vision.py | 9 +- .../multimodal/processors/base_processor.py | 1 + python/sglang/test/runners.py | 7 +- test/manual/test_vlm_accuracy.py | 163 +++++++++++++++++- 5 files changed, 208 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/models/gemma4_audio.py b/python/sglang/srt/models/gemma4_audio.py index 89d15e1bb396..c9b81a953c77 100644 --- a/python/sglang/srt/models/gemma4_audio.py +++ b/python/sglang/srt/models/gemma4_audio.py @@ -36,13 +36,17 @@ ClippableQKVParallelLinear, ClippableRowParallelLinear, 
) +from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + get_attention_tp_size, +) from sglang.srt.layers.layernorm import Gemma4RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, RowParallelLinear, ) from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.utils import add_prefix, make_layers +from sglang.srt.utils import add_prefix, make_layers, set_weight_attrs # --------------------------------------------------------------------------- @@ -60,15 +64,17 @@ def __init__( super().__init__() self.config = config - self.num_heads = config.conf_num_attention_heads + tp_size = get_attention_tp_size() + total_num_heads = config.conf_num_attention_heads self.channels = config.hidden_size - self.head_dim = self.channels // self.num_heads + self.head_dim = self.channels // total_num_heads + self.num_heads = total_num_heads // tp_size self.max_backward = max(0, config.conf_attention_context_left - 1) self.max_forward = config.conf_attention_context_right self.pos_proj = ColumnParallelLinear( self.channels, - self.num_heads * self.head_dim, + config.hidden_size, bias=False, quant_config=quant_config, prefix=add_prefix("pos_proj", prefix), @@ -208,9 +214,11 @@ def __init__( super().__init__() self.config = config - self.num_heads = config.conf_num_attention_heads + tp_size = get_attention_tp_size() + total_num_heads = config.conf_num_attention_heads self.hidden_size = config.hidden_size - self.head_dim = self.hidden_size // self.num_heads + self.head_dim = self.hidden_size // total_num_heads + self.num_heads = total_num_heads // tp_size self.chunk_size = config.conf_attention_chunk_size self.max_future_horizon = config.conf_attention_context_right @@ -494,6 +502,7 @@ def __init__( self.input_proj_in_features, config.hidden_size, bias=False, + input_is_parallel=False, quant_config=quant_config, prefix=add_prefix("input_proj_linear", prefix), ) @@ -641,6 +650,8 @@ def __init__( super().__init__() 
self.config = config self.causal_padding = config.conf_conv_kernel_size - 1 + tp_size = get_attention_tp_size() + hidden_per_tp = config.hidden_size // tp_size self.register_buffer( "gradient_clipping", @@ -659,17 +670,28 @@ def __init__( prefix=add_prefix("linear_start", prefix), ) self.depthwise_conv1d = nn.Conv1d( - in_channels=config.hidden_size, - out_channels=config.hidden_size, + in_channels=hidden_per_tp, + out_channels=hidden_per_tp, kernel_size=config.conf_conv_kernel_size, stride=1, padding=0, - groups=config.hidden_size, + groups=hidden_per_tp, bias=False, ) self.conv_norm = Gemma4RMSNorm( - config.hidden_size, eps=config.rms_norm_eps, scale_shift=0.0 + hidden_per_tp, eps=config.rms_norm_eps, scale_shift=0.0 ) + + tp_rank = get_attention_tp_rank() + + def _shard_dim0(param, loaded_weight, _rank=tp_rank, _tp=tp_size): + shard = param.shape[0] + loaded_weight = loaded_weight.narrow(0, _rank * shard, shard) + param.data.copy_(loaded_weight) + + set_weight_attrs(self.depthwise_conv1d.weight, {"weight_loader": _shard_dim0}) + set_weight_attrs(self.conv_norm.weight, {"weight_loader": _shard_dim0}) + self.linear_end = ClippableRowParallelLinear( config.hidden_size, config.hidden_size, @@ -788,6 +810,7 @@ def __init__( config.hidden_size, config.output_proj_dims, bias=True, + input_is_parallel=False, quant_config=quant_config, prefix=add_prefix("output_proj", prefix), ) diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index daea6875d22e..cb0a82d2938b 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -238,7 +238,7 @@ def forward( # --------------------------------------------------------------------------- -# Vision MLP (GeGLU, TP-sharded) +# Vision MLP (GatedGELU, TP-sharded) # --------------------------------------------------------------------------- @@ -250,6 +250,11 @@ def __init__( prefix: str = "", ): super().__init__() + if config.hidden_activation 
!= "gelu_pytorch_tanh": + raise ValueError( + f"Gemma4VisionMLP expects hidden_activation='gelu_pytorch_tanh', " + f"got {config.hidden_activation!r}" + ) self.gate_up = ClippableGateUpParallelLinear( input_size=config.hidden_size, intermediate_size=config.intermediate_size, @@ -267,7 +272,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: gate, up = self.gate_up(x) - x = F.silu(gate) * up + x = F.gelu(gate, approximate="tanh") * up x = self.down_proj(x) return x diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 4cc4cfb500e7..0f769bf7d990 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -321,6 +321,7 @@ def process_mm_data( if audios: if self._processor.__class__.__name__ in { "Gemma3nProcessor", + "Gemma4Processor", "GlmAsrProcessor", "Qwen2AudioProcessor", "Qwen3OmniMoeProcessor", diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 1f10c0cd06bf..8e95d85569e1 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -25,11 +25,16 @@ AutoConfig, AutoModel, AutoModelForCausalLM, - AutoModelForVision2Seq, AutoProcessor, GenerationConfig, ) +try: + # TODO(kpham-sgl): For whatever reason the provided transformers package does not have this module + from transformers import AutoModelForVision2Seq +except ImportError: + AutoModelForVision2Seq = None + from sglang.srt.entrypoints.engine import Engine from sglang.srt.model_loader.ci_weight_validation import ci_validate_and_clean_hf_cache from sglang.srt.utils import get_device, is_npu, load_image diff --git a/test/manual/test_vlm_accuracy.py b/test/manual/test_vlm_accuracy.py index 6e26c012a7eb..0f480293ffb0 100644 --- a/test/manual/test_vlm_accuracy.py +++ b/test/manual/test_vlm_accuracy.py @@ -1,4 +1,8 @@ -""" """ +"""Multimodal encoder accuracy tests: compare HF vs SGLang 
encoder outputs. + +# TODO(kpham-sgl): Rename this file to test_mm_accuracy.py — it now covers both +# vision and audio encoder comparisons, not just VLM embeddings. +""" import unittest from typing import List, Optional @@ -6,7 +10,7 @@ import numpy as np import torch import torch.nn.functional as F -from transformers import AutoModel, AutoProcessor, AutoTokenizer +from transformers import AutoConfig, AutoModel, AutoProcessor, AutoTokenizer from sglang.srt.configs.model_config import ModelConfig from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest @@ -318,3 +322,158 @@ async def test_vlm_embedding_output(self): ) self.compare_outputs(sglang_output, hf_output) + + +# --------------------------------------------------------------------------- +# Gemma 4 encoder accuracy: vision tower + audio tower vs HF reference +# --------------------------------------------------------------------------- + + +class TestGemma4EncoderAccuracy(unittest.TestCase): + """Compare Gemma 4 vision and audio encoder outputs between HF and SGLang. + + For each encoder we compare: + 1. Raw tower output (before the multimodal embedder projection). + 2. Projected output (tower + ``embed_vision`` / ``embed_audio``). + + Inputs are random tensors so that the test is self-contained and does not + depend on image / audio files. 
+ """ + + MODEL_PATH = "gg-hf-gg/gemma-4-e4b-it" + COSINE_THRESHOLD = 0.99 + + @classmethod + def setUpClass(cls): + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # -- HF model: extract encoder components, discard the rest ----------- + from transformers import ( + Gemma4ForConditionalGeneration as HFGemma4ForConditionalGeneration, + ) + + hf_full = HFGemma4ForConditionalGeneration.from_pretrained( + cls.MODEL_PATH, torch_dtype=torch.bfloat16 + ) + + cls.hf_vision_tower = hf_full.model.vision_tower.eval().to(cls.device) + cls.hf_embed_vision = hf_full.model.embed_vision.eval().to(cls.device) + + cls.hf_audio_tower = None + cls.hf_embed_audio = None + cls.mel_bins = None + if hf_full.model.audio_tower is not None: + cls.hf_audio_tower = hf_full.model.audio_tower.eval().to(cls.device) + cls.hf_embed_audio = hf_full.model.embed_audio.eval().to(cls.device) + config = AutoConfig.from_pretrained(cls.MODEL_PATH) + cls.mel_bins = config.audio_config.input_feat_size + + del hf_full + torch.cuda.empty_cache() + + # -- SGLang model via ModelRunner ------------------------------------- + cls.model_runner = ModelRunner( + model_config=ModelConfig(cls.MODEL_PATH, model_override_args="{}"), + mem_fraction_static=0.8, + gpu_id=0, + tp_rank=0, + tp_size=1, + moe_ep_rank=0, + moe_ep_size=1, + pp_rank=0, + pp_size=1, + nccl_port=12435, + server_args=ServerArgs( + model_path=cls.MODEL_PATH, + disable_cuda_graph=True, + mm_attention_backend="sdpa", + ), + ) + cls.sg_model = cls.model_runner.model + + # -- helpers -------------------------------------------------------------- + + @staticmethod + def _cosine_stats(a: torch.Tensor, b: torch.Tensor): + cos = F.cosine_similarity(a.float(), b.float()) + return cos.mean().item(), cos.min().item() + + def _assert_cosine_close(self, hf: torch.Tensor, sg: torch.Tensor, label: str): + mean_cos, min_cos = self._cosine_stats(hf, sg) + print(f" {label}: mean_cos={mean_cos:.6f} min_cos={min_cos:.6f}") + 
self.assertGreater( + min_cos, + self.COSINE_THRESHOLD, + f"{label} min cosine {min_cos:.6f} < {self.COSINE_THRESHOLD}", + ) + + # -- vision --------------------------------------------------------------- + + def test_vision_encoder(self): + """Vision tower + embed_vision should match HF on random pixels.""" + pixel_values = torch.randn( + 1, 3, 768, 768, device=self.device, dtype=torch.bfloat16 + ) + + with torch.no_grad(): + # HF: last_hidden_state is [1, num_real_tokens, hidden] (padding stripped) + hf_out = self.hf_vision_tower(pixel_values) + hf_tokens = hf_out.last_hidden_state.squeeze(0) + hf_projected = self.hf_embed_vision( + hf_tokens.unsqueeze(0) + ).squeeze(0) + + # SGLang: returns (pooled, pooler_mask) with mask True = valid + sg_pooled, sg_mask = self.sg_model.vision_tower(pixel_values) + sg_tokens = torch.cat( + [hs[m] for hs, m in zip(sg_pooled, sg_mask)] + ) + sg_projected = self.sg_model.embed_vision( + sg_tokens.unsqueeze(0) + ).squeeze(0) + + self.assertEqual(hf_tokens.shape, sg_tokens.shape) + print() + self._assert_cosine_close(hf_tokens, sg_tokens, "vision tower") + self._assert_cosine_close(hf_projected, sg_projected, "vision projected") + + # -- audio ---------------------------------------------------------------- + + def test_audio_encoder(self): + """Audio tower + embed_audio should match HF on random mel input.""" + if self.hf_audio_tower is None: + self.skipTest("Model does not have an audio tower") + + num_frames = 200 + audio_mel = torch.randn( + 1, num_frames, self.mel_bins, device=self.device, dtype=torch.bfloat16 + ) + audio_mel_mask = torch.zeros( + 1, num_frames, device=self.device, dtype=torch.bool + ) + + with torch.no_grad(): + # HF: returns (encodings, mask) — does NOT zero-fill padding + hf_enc, hf_mask = self.hf_audio_tower(audio_mel, audio_mel_mask) + hf_valid_mask = ~hf_mask + hf_valid = hf_enc[hf_valid_mask.unsqueeze(-1).expand_as(hf_enc)].reshape( + -1, hf_enc.shape[-1] + ) + hf_projected = self.hf_embed_audio( + 
hf_valid.unsqueeze(0) + ).squeeze(0) + + # SGLang: returns (encodings, mask) — zero-fills padding positions + sg_enc, sg_mask = self.sg_model.audio_tower(audio_mel, audio_mel_mask) + sg_valid_mask = ~sg_mask + sg_valid = sg_enc[sg_valid_mask.unsqueeze(-1).expand_as(sg_enc)].reshape( + -1, sg_enc.shape[-1] + ) + sg_projected = self.sg_model.embed_audio( + sg_valid.unsqueeze(0) + ).squeeze(0) + + self.assertEqual(hf_valid.shape, sg_valid.shape) + print() + self._assert_cosine_close(hf_valid, sg_valid, "audio tower") + self._assert_cosine_close(hf_projected, sg_projected, "audio projected") From 5b77767aa78f424a889170c72ba655dfc969bd3f Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Wed, 11 Mar 2026 02:06:10 +0000 Subject: [PATCH 020/112] audio, vision, and text all work correctly now --- python/sglang/srt/layers/clippable_linear.py | 81 ++++++- python/sglang/srt/models/gemma4_audio.py | 10 +- .../srt/multimodal/processors/gemma4.py | 33 ++- test/manual/test_vlm_accuracy.py | 214 ++++++++++++++++++ 4 files changed, 329 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/clippable_linear.py b/python/sglang/srt/layers/clippable_linear.py index b34895b9c910..a010e8939474 100644 --- a/python/sglang/srt/layers/clippable_linear.py +++ b/python/sglang/srt/layers/clippable_linear.py @@ -163,12 +163,87 @@ def forward( return q, k, v +class ClippableGLUParallelLinear(nn.Module): + """Fused linear + GLU gating with correct TP sharding. + + Used by the audio encoder's ``LightConv1d``, where a single linear + projects to ``[hidden * 2]`` and GLU splits into value/gate halves. + A plain ``ColumnParallelLinear`` is *incorrect* here under TP because it + shards the output contiguously, mixing value and gate across ranks. + This wrapper uses ``MergedColumnParallelLinear`` to shard each half + independently, then applies GLU (``value * sigmoid(gate)``) on each + rank's correctly-paired shard. 
+ + Output clamping is applied once *after* the GLU gate, using a single + ``output_min/max`` pair (matching the checkpoint layout). + + The checkpoint stores a single fused ``[hidden * 2, input]`` weight. + A custom ``weight_loader`` on the inner param automatically splits it + into value (first half) and gate (second half) shards, so no special + handling is needed in the model's ``load_weights``. + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + *, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_attention_tp_size() + self.proj_size = hidden_size // tp_size + + self.linear = MergedColumnParallelLinear( + input_size=input_size, + output_sizes=[hidden_size, hidden_size], + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear", prefix), + ) + + # The checkpoint has a single fused weight; MergedColumnParallelLinear + # expects per-shard loading. Wrap the original weight_loader so that + # a call *without* shard_id (the generic load_weights path) splits + # automatically. 
+ orig_loader = self.linear.weight.weight_loader + + def _fused_weight_loader(param, loaded_weight, loaded_shard_id=None): + if loaded_shard_id is not None: + return orig_loader(param, loaded_weight, loaded_shard_id) + half = loaded_weight.shape[0] // 2 + orig_loader(param, loaded_weight[:half], 0) + orig_loader(param, loaded_weight[half:], 1) + + self.linear.weight.weight_loader = _fused_weight_loader + + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, self.input_min, self.input_max) + merged, _ = self.linear(x) + value, gate = merged.split([self.proj_size, self.proj_size], dim=-1) + x = value * torch.sigmoid(gate) + x = torch.clamp(x, self.output_min, self.output_max) + return x + + class ClippableGateUpParallelLinear(nn.Module): """Fused gate/up projection with per-projection activation clamping. - Same pattern as ``ClippableQKVParallelLinear``: owns a single - ``MergedColumnParallelLinear`` for the fused matmul, with shared input - bounds and per-projection output bounds as flat buffers. + Used by the MLP layers in the vision/audio encoders. Owns a single + ``MergedColumnParallelLinear`` for the fused matmul and returns the + two projections separately so the caller can apply its own activation + (e.g. ``SiLU(gate) * up``). + + Output clamping is applied *per-projection before* the caller's + activation, using separate ``gate_output_min/max`` and + ``up_output_min/max`` bounds. 
""" def __init__( diff --git a/python/sglang/srt/models/gemma4_audio.py b/python/sglang/srt/models/gemma4_audio.py index c9b81a953c77..e19ea9ac9fc5 100644 --- a/python/sglang/srt/models/gemma4_audio.py +++ b/python/sglang/srt/models/gemma4_audio.py @@ -33,6 +33,7 @@ from sglang.srt.layers.clippable_linear import ( ClippableColumnParallelLinear, + ClippableGLUParallelLinear, ClippableQKVParallelLinear, ClippableRowParallelLinear, ) @@ -239,8 +240,8 @@ def __init__( self.qkv = ClippableQKVParallelLinear( hidden_size=self.hidden_size, head_size=self.head_dim, - total_num_heads=self.num_heads, - total_num_kv_heads=self.num_heads, + total_num_heads=total_num_heads, + total_num_kv_heads=total_num_heads, bias=False, quant_config=quant_config, prefix=prefix, @@ -662,9 +663,9 @@ def __init__( self.pre_layer_norm = Gemma4RMSNorm( config.hidden_size, eps=config.rms_norm_eps, scale_shift=0.0 ) - self.linear_start = ClippableColumnParallelLinear( + self.linear_start = ClippableGLUParallelLinear( + config.hidden_size, config.hidden_size, - config.hidden_size * 2, bias=False, quant_config=quant_config, prefix=add_prefix("linear_start", prefix), @@ -705,7 +706,6 @@ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: audio_encodings = self.pre_layer_norm(audio_encodings) audio_encodings = self.linear_start(audio_encodings) - audio_encodings = F.glu(audio_encodings, dim=-1) audio_encodings_permuted = audio_encodings.permute(0, 2, 1) audio_encodings_permuted_padded = F.pad( diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index 5c115da37bb6..861197b8347c 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -14,13 +14,14 @@ from typing import Dict, List, Optional, Union +import numpy as np + from sglang.srt.managers.multimodal_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, ) from sglang.srt.models.gemma4_mm import 
Gemma4ForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens - class Gemma4SGLangProcessor(SGLangBaseProcessor): """Multimodal processor for Gemma4 supporting image and audio inputs.""" @@ -39,6 +40,36 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): audio_token_id=hf_config.audio_token_id, ).build(_processor) + def _get_audio_pad_multiple(self) -> int: + """Derive the waveform padding alignment from processor config. + + The HF processor's ceil(duration_ms / audio_ms_per_token) formula can + overshoot by 1 token relative to what the SSCP convolutions produce. + Padding waveforms to a multiple of (hop_length * first_conv_stride) + aligns the two calculations. + See: gemma-4-eap-extras/examples/gemma-4-audio-examples.ipynb + """ + fe = getattr(self._processor, "feature_extractor", None) + hop = getattr(fe, "hop_length", 160) + ac = getattr(self.hf_config, "audio_config", None) + first_stride = ac.sscp_conv_stride_size[0][0] if ac is not None else 2 + return hop * first_stride + + def process_mm_data(self, input_text, images=None, videos=None, audios=None, **kwargs): + if audios: + pad_multiple = self._get_audio_pad_multiple() + padded = [] + for a in audios: + a = np.asarray(a) + remainder = len(a) % pad_multiple + if remainder != 0: + a = np.pad(a, (0, pad_multiple - remainder), mode="constant") + padded.append(a) + audios = padded + return super().process_mm_data( + input_text, images=images, videos=videos, audios=audios, **kwargs + ) + async def process_mm_data_async( self, image_data: Optional[List[Union[str, bytes, Dict]]] = None, diff --git a/test/manual/test_vlm_accuracy.py b/test/manual/test_vlm_accuracy.py index 0f480293ffb0..cb759115cd53 100644 --- a/test/manual/test_vlm_accuracy.py +++ b/test/manual/test_vlm_accuracy.py @@ -4,11 +4,15 @@ # vision and audio encoder comparisons, not just VLM embeddings. 
""" +import os +import socket +import tempfile import unittest from typing import List, Optional import numpy as np import torch +import torch.multiprocessing as mp import torch.nn.functional as F from transformers import AutoConfig, AutoModel, AutoProcessor, AutoTokenizer @@ -477,3 +481,213 @@ def test_audio_encoder(self): print() self._assert_cosine_close(hf_valid, sg_valid, "audio tower") self._assert_cosine_close(hf_projected, sg_projected, "audio projected") + + +# --------------------------------------------------------------------------- +# Gemma 4 encoder accuracy at TP=2: compare SGLang (TP=2) vs HF reference +# --------------------------------------------------------------------------- + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + s.listen(1) + return s.getsockname()[1] + + +def _tp2_encoder_worker( + local_rank: int, + world_size: int, + nccl_port: int, + model_path: str, + mel_bins: int, + num_frames: int, + result_file: str, +): + """Worker spawned by mp.spawn — loads SGLang model with TP and runs encoders.""" + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size)) + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + model_runner = ModelRunner( + model_config=ModelConfig(model_path, model_override_args="{}"), + mem_fraction_static=0.5, + gpu_id=local_rank, + tp_rank=local_rank, + tp_size=world_size, + moe_ep_rank=0, + moe_ep_size=1, + pp_rank=0, + pp_size=1, + nccl_port=nccl_port, + server_args=ServerArgs( + model_path=model_path, + disable_cuda_graph=True, + mm_attention_backend="sdpa", + mem_fraction_static=0.5, + ), + ) + sg_model = model_runner.model + + # Deterministic input — identical on every rank. 
+ torch.manual_seed(42) + audio_mel = torch.randn( + 1, num_frames, mel_bins, device=device, dtype=torch.bfloat16 + ) + audio_mel_mask = torch.zeros(1, num_frames, device=device, dtype=torch.bool) + pixel_values = torch.randn(1, 3, 768, 768, device=device, dtype=torch.bfloat16) + + with torch.no_grad(): + # Audio + sg_audio_enc, sg_audio_mask = sg_model.audio_tower(audio_mel, audio_mel_mask) + sg_audio_valid_mask = ~sg_audio_mask + sg_audio_valid = sg_audio_enc[ + sg_audio_valid_mask.unsqueeze(-1).expand_as(sg_audio_enc) + ].reshape(-1, sg_audio_enc.shape[-1]) + sg_audio_proj = sg_model.embed_audio(sg_audio_valid.unsqueeze(0)).squeeze(0) + + # Vision + sg_vis_pooled, sg_vis_mask = sg_model.vision_tower(pixel_values) + sg_vis_tokens = torch.cat([hs[m] for hs, m in zip(sg_vis_pooled, sg_vis_mask)]) + sg_vis_proj = sg_model.embed_vision(sg_vis_tokens.unsqueeze(0)).squeeze(0) + + if local_rank == 0: + torch.save( + { + "audio_valid": sg_audio_valid.cpu(), + "audio_projected": sg_audio_proj.cpu(), + "vision_tokens": sg_vis_tokens.cpu(), + "vision_projected": sg_vis_proj.cpu(), + }, + result_file, + ) + + +class TestGemma4EncoderAccuracyTP2(unittest.TestCase): + """Compare Gemma 4 vision + audio encoder outputs at TP=2 against HF. + + Uses ``mp.spawn`` to create 2 workers that jointly load the SGLang model + with tensor parallelism, then compares rank-0 output with the HF reference + computed in the parent process. + """ + + MODEL_PATH = "gg-hf-gg/gemma-4-e4b-it" + # TP=2 all-reduce introduces small bf16 rounding that compounds across + # 12 conformer blocks; 0.98 is the practical floor. 
+ COSINE_THRESHOLD = 0.98 + NUM_FRAMES = 200 + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + raise unittest.SkipTest("Need >= 2 GPUs for TP=2 test") + + cls.device = torch.device("cuda:0") + config = AutoConfig.from_pretrained(cls.MODEL_PATH) + cls.mel_bins = config.audio_config.input_feat_size + + # -- HF reference (run on GPU 0, then free) ---------------------------- + from transformers import ( + Gemma4ForConditionalGeneration as HFGemma4ForConditionalGeneration, + ) + + hf_full = HFGemma4ForConditionalGeneration.from_pretrained( + cls.MODEL_PATH, torch_dtype=torch.bfloat16 + ) + hf_audio_tower = hf_full.model.audio_tower.eval().to(cls.device) + hf_embed_audio = hf_full.model.embed_audio.eval().to(cls.device) + hf_vision_tower = hf_full.model.vision_tower.eval().to(cls.device) + hf_embed_vision = hf_full.model.embed_vision.eval().to(cls.device) + del hf_full + torch.cuda.empty_cache() + + torch.manual_seed(42) + audio_mel = torch.randn( + 1, cls.NUM_FRAMES, cls.mel_bins, device=cls.device, dtype=torch.bfloat16 + ) + audio_mel_mask = torch.zeros( + 1, cls.NUM_FRAMES, device=cls.device, dtype=torch.bool + ) + pixel_values = torch.randn( + 1, 3, 768, 768, device=cls.device, dtype=torch.bfloat16 + ) + + with torch.no_grad(): + hf_enc, hf_mask = hf_audio_tower(audio_mel, audio_mel_mask) + hf_valid_mask = ~hf_mask + cls.hf_audio_valid = hf_enc[ + hf_valid_mask.unsqueeze(-1).expand_as(hf_enc) + ].reshape(-1, hf_enc.shape[-1]).cpu() + cls.hf_audio_proj = hf_embed_audio( + cls.hf_audio_valid.unsqueeze(0).to(cls.device) + ).squeeze(0).cpu() + + hf_vis_out = hf_vision_tower(pixel_values) + cls.hf_vis_tokens = hf_vis_out.last_hidden_state.squeeze(0).cpu() + cls.hf_vis_proj = hf_embed_vision( + cls.hf_vis_tokens.unsqueeze(0).to(cls.device) + ).squeeze(0).cpu() + + del hf_audio_tower, hf_embed_audio, hf_vision_tower, hf_embed_vision + import gc + + gc.collect() + torch.cuda.empty_cache() + + # -- Run SGLang at 
TP=2 via mp.spawn ----------------------------------- + nccl_port = _find_free_port() + cls._result_file = tempfile.NamedTemporaryFile(suffix=".pt", delete=False) + cls._result_file.close() + + mp.spawn( + _tp2_encoder_worker, + args=( + 2, + nccl_port, + cls.MODEL_PATH, + cls.mel_bins, + cls.NUM_FRAMES, + cls._result_file.name, + ), + nprocs=2, + join=True, + ) + + cls.sg_results = torch.load(cls._result_file.name, weights_only=True) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "_result_file"): + os.unlink(cls._result_file.name) + + @staticmethod + def _cosine_stats(a: torch.Tensor, b: torch.Tensor): + cos = F.cosine_similarity(a.float(), b.float()) + return cos.mean().item(), cos.min().item() + + def _assert_cosine_close(self, hf: torch.Tensor, sg: torch.Tensor, label: str): + mean_cos, min_cos = self._cosine_stats(hf, sg) + print(f" {label}: mean_cos={mean_cos:.6f} min_cos={min_cos:.6f}") + self.assertGreater( + min_cos, + self.COSINE_THRESHOLD, + f"{label} min cosine {min_cos:.6f} < {self.COSINE_THRESHOLD}", + ) + + def test_audio_encoder_tp2(self): + """Audio tower + embed_audio at TP=2 should match HF reference.""" + sg_valid = self.sg_results["audio_valid"] + sg_proj = self.sg_results["audio_projected"] + self.assertEqual(self.hf_audio_valid.shape, sg_valid.shape) + print() + self._assert_cosine_close(self.hf_audio_valid, sg_valid, "audio tower TP=2") + self._assert_cosine_close(self.hf_audio_proj, sg_proj, "audio projected TP=2") + + def test_vision_encoder_tp2(self): + """Vision tower + embed_vision at TP=2 should match HF reference.""" + sg_tokens = self.sg_results["vision_tokens"] + sg_proj = self.sg_results["vision_projected"] + self.assertEqual(self.hf_vis_tokens.shape, sg_tokens.shape) + print() + self._assert_cosine_close(self.hf_vis_tokens, sg_tokens, "vision tower TP=2") + self._assert_cosine_close(self.hf_vis_proj, sg_proj, "vision projected TP=2") From 785b99ece20423069c5b0ff966ffe891161df32f Mon Sep 17 00:00:00 2001 From: 
kpham-sgl Date: Wed, 11 Mar 2026 22:24:00 +0000 Subject: [PATCH 021/112] fix swa memory pool indices to retrieve --- python/sglang/srt/layers/attention/triton_backend.py | 7 +++++-- python/sglang/srt/mem_cache/memory_pool.py | 8 ++++++++ python/sglang/srt/mem_cache/swa_memory_pool.py | 9 +++++++++ python/sglang/srt/models/gemma4_mm.py | 7 +++++++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 359882e72026..1c842fe446b8 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -820,9 +820,12 @@ def forward_extend( o = torch.empty_like(q) if k is None and v is None: + cache_loc = forward_batch.token_to_kv_pool.translate_loc( + layer.layer_id, forward_batch.out_cache_loc + ) k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - k = k_buffer[forward_batch.out_cache_loc] - v = v_buffer[forward_batch.out_cache_loc] + k = k_buffer[cache_loc] + v = v_buffer[cache_loc] elif k is None or v is None: raise ValueError("Both k and v should be None or not None") else: diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 9943e4715b05..895fc61627ec 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -708,6 +708,14 @@ def set_kv_buffer( ) -> None: raise NotImplementedError() + def translate_loc(self, layer_id: int, loc: torch.Tensor) -> torch.Tensor: + """Translate full-pool cache locations to the correct locations for a given layer. + + For pools with separate sub-pools (e.g. SWAKVPool), the indices used to + write/read may differ per layer type. Override this to perform the mapping. 
+ """ + return loc + def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter): self.layer_transfer_counter = layer_transfer_counter diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index 06b0cb01fc97..a133ad760ff1 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -136,6 +136,15 @@ def get_kv_buffer(self, layer_id: int): else: return self.full_kv_pool.get_kv_buffer(layer_id_pool) + def translate_loc(self, layer_id: int, loc: torch.Tensor) -> torch.Tensor: + _, is_swa_layer = self.layers_mapping[layer_id] + if is_swa_layer: + if self.swa_loc is not None: + return self.swa_loc + elif self.full_to_swa_index_mapping is not None: + return self.translate_loc_from_full_to_swa(loc) + return loc + def set_swa_loc(self, loc: torch.Tensor): self.swa_loc = loc diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index b97bd1393b5b..a62012bad943 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -244,6 +244,13 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: all_embeds = [] for pv in all_pixel_values: + if ( + pv.dim() in (2, 3) + and pv.shape[-1] == self.config.text_config.hidden_size + ): + all_embeds.append(pv.to(self.language_model.device)) + continue + if pv.dim() == 5: pv = pv.squeeze(0) if pv.dim() == 3: From 27687c965b25407c6f589072ff47bb27f3f05eec Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Wed, 11 Mar 2026 22:47:21 +0000 Subject: [PATCH 022/112] softmax_scale should not be kwargs --- python/sglang/srt/layers/attention/vision.py | 21 +++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 84d115e1bb18..38525c900a53 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ 
b/python/sglang/srt/layers/attention/vision.py @@ -241,6 +241,7 @@ def forward( bsz: int, cu_seqlens: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -329,6 +330,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -338,8 +340,6 @@ def forward( Returns: [b * s, h, head_size] """ - softmax_scale = kwargs.get("softmax_scale", None) - if envs.SGLANG_VIT_ENABLE_CUDA_GRAPH.get(): if "output_ws" not in kwargs: raise RuntimeError("output_ws should be prepared for cuda-graph mode") @@ -403,6 +403,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -411,8 +412,6 @@ def forward( Returns: [b * s, h, head_size] """ - softmax_scale = kwargs.get("softmax_scale", None) - if envs.SGLANG_VIT_ENABLE_CUDA_GRAPH.get(): max_seqlen = cu_seqlens[1] output = flash_attn_func( @@ -462,6 +461,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -479,8 +479,6 @@ def forward( ) cu_seqlens = cu_seqlens.get_data() - softmax_scale = kwargs.get("softmax_scale", None) - cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device) seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() @@ -520,6 +518,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -595,7 +594,7 @@ def forward( raise RuntimeError("offset + len out of bounds; packed indptr is wrong") _, _, head_size = q.shape - scale = head_size**-0.5 + scale = softmax_scale if softmax_scale is not None else head_size**-0.5 output, _ = 
cudnn_batch_prefill_with_kv_cache( q, @@ -647,10 +646,9 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: - softmax_scale = kwargs.get("softmax_scale", None) - cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device) cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device) @@ -687,6 +685,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -699,7 +698,6 @@ def forward( if "output_ws" not in kwargs: raise RuntimeError("output_ws should be prepared for npu-graph mode") output = kwargs["output_ws"] - # graph mode: runner already passes seq_lens (int32 on CPU) seq_len_arg = cu_seqlens else: cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device="cpu") @@ -709,11 +707,10 @@ def forward( output = torch.empty_like(q) seq_len_arg = seq_lens.to(torch.int32) - _, num_heads, head_size = q.shape num_kv_heads = k.shape[1] - scale_value = kwargs.get("softmax_scale") or head_size**-0.5 + scale_value = softmax_scale if softmax_scale is not None else head_size**-0.5 torch_npu._npu_flash_attention_unpad( query=q, From 278d9fcf3f14c804e5faf0a699be7a9d7028a45e Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Sun, 15 Mar 2026 20:51:28 +0000 Subject: [PATCH 023/112] fix misc bugs with SWA kv cache --- python/sglang/srt/configs/model_config.py | 2 +- .../srt/layers/attention/triton_backend.py | 56 +++++++++++++++++-- .../sglang/srt/mem_cache/swa_memory_pool.py | 5 ++ .../sglang/srt/utils/hf_transformers_utils.py | 1 + 4 files changed, 57 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 3c8816b03534..4dae87e5d739 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -418,7 +418,7 @@ def 
_derive_model_shapes(self): self.swa_v_head_dim = getattr( self.hf_text_config, "swa_v_head_dim", - self.v_head_dim, + self.swa_head_dim, ) # FIXME: temporary special judge for MLA architecture if ( diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 1c842fe446b8..37a2fd6ad61b 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -51,6 +51,8 @@ class ForwardMetadata: window_kv_indices: torch.Tensor window_num_kv_splits: torch.Tensor window_kv_offsets: torch.Tensor + # Separate attn_logits for SWA layers when v_head_dim differs + swa_attn_logits: Optional[torch.Tensor] = None class TritonAttnBackend(AttentionBackend): @@ -94,16 +96,27 @@ def __init__( self.num_kv_head = model_runner.model_config.get_num_kv_heads( get_attention_tp_size() ) - if ( + # The decode kernel's intermediate attn_logits buffer must match the + # exact v_head_dim of the layer being processed (the triton kernel uses + # a // Lv stride trick to derive attn_lse indices from attn_logits strides). + # When SWA and full attention layers have different v_head_dim (e.g. Gemma 4 + # with swa=256, full=512), we need two separate attn_logits buffers. 
+ full_v_head_dim = model_runner.model_config.v_head_dim + swa_v_head_dim = model_runner.model_config.swa_v_head_dim + if swa_v_head_dim != full_v_head_dim: + self.v_head_dim = full_v_head_dim + self.swa_v_head_dim = swa_v_head_dim + elif ( model_runner.hybrid_gdn_config is not None or model_runner.kimi_linear_config is not None ): - # For hybrid linear models, layer_id = 0 may not be full attention self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() + self.swa_v_head_dim = None else: - self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[ - -1 - ] + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer( + 0 + ).shape[-1] + self.swa_v_head_dim = None self.max_context_len = model_runner.model_config.context_len self.device = model_runner.device self.device_core_count = get_device_core_count(model_runner.gpu_id) @@ -242,6 +255,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): window_kv_indices = None window_num_kv_splits = None window_kv_offsets = None + swa_attn_logits = None spec_info = forward_batch.spec_info if forward_batch.forward_mode.is_decode_or_idle(): @@ -290,6 +304,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): dtype=torch.float32, device=self.device, ) + if self.swa_v_head_dim is not None: + swa_attn_logits = torch.empty( + (bs, self.num_head, self.max_kv_splits, self.swa_v_head_dim), + dtype=torch.float32, + device=self.device, + ) + else: + swa_attn_logits = None attn_lse = torch.empty( (bs, self.num_head, self.max_kv_splits), dtype=torch.float32, @@ -436,6 +458,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): window_kv_indices, window_num_kv_splits, window_kv_offsets, + swa_attn_logits=swa_attn_logits, ) def init_cuda_graph_state( @@ -450,6 +473,14 @@ def init_cuda_graph_state( dtype=torch.float32, device=self.device, ) + if self.swa_v_head_dim is not None: + self.cuda_graph_swa_attn_logits = torch.zeros( + (max_num_tokens, self.num_head, 
self.max_kv_splits, self.swa_v_head_dim), + dtype=torch.float32, + device=self.device, + ) + else: + self.cuda_graph_swa_attn_logits = None self.cuda_graph_attn_lse = torch.zeros( (max_num_tokens, self.num_head, self.max_kv_splits), dtype=torch.float32, @@ -520,6 +551,7 @@ def init_forward_metadata_capture_cuda_graph( window_kv_indices = None window_num_kv_splits = None window_kv_offsets = None + swa_attn_logits = None if forward_mode.is_decode_or_idle(): if spec_info is None: @@ -558,6 +590,7 @@ def init_forward_metadata_capture_cuda_graph( kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices attn_logits = self.cuda_graph_attn_logits + swa_attn_logits = self.cuda_graph_swa_attn_logits attn_lse = self.cuda_graph_attn_lse max_extend_len = None num_kv_splits = self.cuda_graph_num_kv_splits @@ -659,6 +692,7 @@ def init_forward_metadata_capture_cuda_graph( window_kv_indices, window_num_kv_splits, window_kv_offsets, + swa_attn_logits=swa_attn_logits, ) def init_forward_metadata_replay_cuda_graph( @@ -1099,6 +1133,16 @@ def forward_decode( k_descale = 1.0 v_descale = 1.0 + # Select the correctly-sized attn_logits buffer for this layer. + # The triton kernel's // Lv stride trick requires attn_logits.shape[-1] + # to exactly match the layer's v_head_dim. 
+ attn_logits = self.forward_metadata.attn_logits + if ( + self.forward_metadata.swa_attn_logits is not None + and layer.v_head_dim == self.swa_v_head_dim + ): + attn_logits = self.forward_metadata.swa_attn_logits + self.decode_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), @@ -1106,7 +1150,7 @@ def forward_decode( o.view(-1, layer.tp_q_head_num, layer.v_head_dim), kv_indptr, kv_indices, - self.forward_metadata.attn_logits, + attn_logits, self.forward_metadata.attn_lse, self.forward_metadata.num_kv_splits, self.max_kv_splits, diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index a133ad760ff1..eaff0f10e2dc 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -98,6 +98,11 @@ def get_kv_size_bytes(self): k_size_swa, v_size_swa = self.swa_kv_pool.get_kv_size_bytes() return k_size + k_size_swa, v_size + v_size_swa + def get_v_head_dim(self): + swa_v_dim = self.swa_kv_pool.get_value_buffer(0).shape[-1] + full_v_dim = self.full_kv_pool.get_value_buffer(0).shape[-1] + return max(swa_v_dim, full_v_dim) + def get_contiguous_buf_infos(self): full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = ( self.full_kv_pool.get_contiguous_buf_infos() diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 7627b73cc310..367b13b3314a 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -405,6 +405,7 @@ def get_config( if global_head_dim is not None: config.text_config.swa_head_dim = config.text_config.head_dim + config.text_config.swa_v_head_dim = config.text_config.head_dim config.text_config.head_dim = global_head_dim config.text_config.swa_num_key_value_heads = config.num_key_value_heads From 8b2dbe4a9da8c940df85abca0c199739b51750f4 Mon Sep 17 
00:00:00 2001 From: kpham-sgl Date: Tue, 17 Mar 2026 20:33:44 +0000 Subject: [PATCH 024/112] addressing comments --- python/sglang/srt/layers/attention/triton_backend.py | 11 +++++++---- python/sglang/srt/mem_cache/memory_pool.py | 8 -------- python/sglang/srt/mem_cache/swa_memory_pool.py | 9 --------- 3 files changed, 7 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 37a2fd6ad61b..93d81994384c 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -13,6 +13,7 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices from sglang.srt.utils import ( get_bool_env_var, @@ -110,6 +111,7 @@ def __init__( model_runner.hybrid_gdn_config is not None or model_runner.kimi_linear_config is not None ): + # For hybrid linear models, layer_id = 0 may not be full attention self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() self.swa_v_head_dim = None else: @@ -854,10 +856,11 @@ def forward_extend( o = torch.empty_like(q) if k is None and v is None: - cache_loc = forward_batch.token_to_kv_pool.translate_loc( - layer.layer_id, forward_batch.out_cache_loc - ) - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + pool = forward_batch.token_to_kv_pool + cache_loc = forward_batch.out_cache_loc + if isinstance(pool, SWAKVPool) and pool.layers_mapping[layer.layer_id][1]: + cache_loc = pool.translate_loc_from_full_to_swa(cache_loc) + k_buffer, v_buffer = pool.get_kv_buffer(layer.layer_id) k = k_buffer[cache_loc] v = v_buffer[cache_loc] elif k is None or v is None: diff --git 
a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 895fc61627ec..9943e4715b05 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -708,14 +708,6 @@ def set_kv_buffer( ) -> None: raise NotImplementedError() - def translate_loc(self, layer_id: int, loc: torch.Tensor) -> torch.Tensor: - """Translate full-pool cache locations to the correct locations for a given layer. - - For pools with separate sub-pools (e.g. SWAKVPool), the indices used to - write/read may differ per layer type. Override this to perform the mapping. - """ - return loc - def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter): self.layer_transfer_counter = layer_transfer_counter diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index eaff0f10e2dc..d1ccb3ceace9 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -141,15 +141,6 @@ def get_kv_buffer(self, layer_id: int): else: return self.full_kv_pool.get_kv_buffer(layer_id_pool) - def translate_loc(self, layer_id: int, loc: torch.Tensor) -> torch.Tensor: - _, is_swa_layer = self.layers_mapping[layer_id] - if is_swa_layer: - if self.swa_loc is not None: - return self.swa_loc - elif self.full_to_swa_index_mapping is not None: - return self.translate_loc_from_full_to_swa(loc) - return loc - def set_swa_loc(self, loc: torch.Tensor): self.swa_loc = loc From 95950dcd0477a3e25c7a111ca6f4beb6f3c6e558 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 17 Mar 2026 20:36:28 +0000 Subject: [PATCH 025/112] lint --- .../srt/layers/attention/triton_backend.py | 15 ++- python/sglang/srt/layers/attention/vision.py | 6 +- python/sglang/srt/layers/clippable_linear.py | 5 +- python/sglang/srt/models/gemma4_audio.py | 38 +++--- python/sglang/srt/models/gemma4_mm.py | 17 +-- python/sglang/srt/models/gemma4_vision.py | 114 
+++++++++++++----- .../srt/multimodal/processors/gemma4.py | 5 +- test/manual/test_vlm_accuracy.py | 44 +++---- 8 files changed, 150 insertions(+), 94 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 93d81994384c..3cba295b1f59 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -12,8 +12,8 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.radix_attention import AttentionType -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices from sglang.srt.utils import ( get_bool_env_var, @@ -115,9 +115,9 @@ def __init__( self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() self.swa_v_head_dim = None else: - self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer( - 0 - ).shape[-1] + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[ + -1 + ] self.swa_v_head_dim = None self.max_context_len = model_runner.model_config.context_len self.device = model_runner.device @@ -477,7 +477,12 @@ def init_cuda_graph_state( ) if self.swa_v_head_dim is not None: self.cuda_graph_swa_attn_logits = torch.zeros( - (max_num_tokens, self.num_head, self.max_kv_splits, self.swa_v_head_dim), + ( + max_num_tokens, + self.num_head, + self.max_kv_splits, + self.swa_v_head_dim, + ), dtype=torch.float32, device=self.device, ) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 38525c900a53..bbff58f65187 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ 
b/python/sglang/srt/layers/attention/vision.py @@ -175,7 +175,11 @@ def __init__( self.flatten_batch = flatten_batch self.softmax_in_single_precision = softmax_in_single_precision self.dropout = dropout - self.scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(self.head_size) + self.scale = ( + softmax_scale + if softmax_scale is not None + else 1.0 / math.sqrt(self.head_size) + ) @staticmethod @lru_cache(maxsize=128) diff --git a/python/sglang/srt/layers/clippable_linear.py b/python/sglang/srt/layers/clippable_linear.py index a010e8939474..a253bb42197a 100644 --- a/python/sglang/srt/layers/clippable_linear.py +++ b/python/sglang/srt/layers/clippable_linear.py @@ -37,7 +37,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import add_prefix - _INF = float("inf") @@ -268,7 +267,9 @@ def __init__( ) self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) - self.gate_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.gate_output_min = nn.parameter.Buffer( + torch.tensor(-_INF), persistent=False + ) self.gate_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) self.up_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) self.up_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) diff --git a/python/sglang/srt/models/gemma4_audio.py b/python/sglang/srt/models/gemma4_audio.py index e19ea9ac9fc5..3926a99f2cef 100644 --- a/python/sglang/srt/models/gemma4_audio.py +++ b/python/sglang/srt/models/gemma4_audio.py @@ -49,7 +49,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import add_prefix, make_layers, set_weight_attrs - # --------------------------------------------------------------------------- # Relative Position Embedding # 
--------------------------------------------------------------------------- @@ -249,7 +248,7 @@ def __init__( # softplus(0) = log(2); pre-fold into scale factors r_softplus_0 = 1.0 / math.log(2) - self.q_scale = (self.head_dim ** -0.5) * r_softplus_0 + self.q_scale = (self.head_dim**-0.5) * r_softplus_0 self.k_scale = r_softplus_0 self.register_buffer( @@ -306,10 +305,14 @@ def forward( per_dim_scale_sp = F.softplus(self.per_dim_scale) broadcast_shape = (1, 1, 1, self.head_dim) - query_states = query_states * self.q_scale * per_dim_scale_sp.view(broadcast_shape) + query_states = ( + query_states * self.q_scale * per_dim_scale_sp.view(broadcast_shape) + ) per_dim_key_scale_sp = F.softplus(self.per_dim_key_scale) - key_states = key_states * self.k_scale * per_dim_key_scale_sp.view(broadcast_shape) + key_states = ( + key_states * self.k_scale * per_dim_key_scale_sp.view(broadcast_shape) + ) batch_size, q_time = query_states.shape[:2] @@ -367,9 +370,9 @@ def forward( prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim) v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim) result_bmm = torch.bmm(prob_bun, v_bun) - context_vectors = result_bmm.reshape( - b_dim, u_dim, n_dim, w_dim, h_dim - ).permute(0, 1, 3, 2, 4) + context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute( + 0, 1, 3, 2, 4 + ) context_vectors = context_vectors.reshape( batch_size, num_query_blocks * self.chunk_size, @@ -404,7 +407,10 @@ def __init__( stride_t, stride_f = config.sscp_conv_stride_size[idx] self.time_stride = stride_t - if config.sscp_conv_time_pad_top is not None and config.sscp_conv_time_pad_bottom is not None: + if ( + config.sscp_conv_time_pad_top is not None + and config.sscp_conv_time_pad_bottom is not None + ): pad_t_top = config.sscp_conv_time_pad_top pad_t_bottom = config.sscp_conv_time_pad_bottom elif config.sscp_conv_padding_type == "semicausal": @@ -543,9 +549,7 @@ def __init__( persistent=False, ) - self.pre_attn_norm = 
Gemma4RMSNorm( - config.hidden_size, scale_shift=0.0 - ) + self.pre_attn_norm = Gemma4RMSNorm(config.hidden_size, scale_shift=0.0) self.attn = Gemma4AudioAttention( config, quant_config, prefix=add_prefix("attn", prefix) ) @@ -556,9 +560,7 @@ def __init__( quant_config=quant_config, prefix=add_prefix("post", prefix), ) - self.post_norm = Gemma4RMSNorm( - config.hidden_size, scale_shift=0.0 - ) + self.post_norm = Gemma4RMSNorm(config.hidden_size, scale_shift=0.0) def forward( self, @@ -603,9 +605,7 @@ def __init__( persistent=False, ) - self.pre_layer_norm = Gemma4RMSNorm( - config.hidden_size, scale_shift=0.0 - ) + self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size, scale_shift=0.0) self.ffw_layer_1 = ClippableColumnParallelLinear( config.hidden_size, config.hidden_size * 4, @@ -620,9 +620,7 @@ def __init__( quant_config=quant_config, prefix=add_prefix("ffw_layer_2", prefix), ) - self.post_layer_norm = Gemma4RMSNorm( - config.hidden_size, scale_shift=0.0 - ) + self.post_layer_norm = Gemma4RMSNorm(config.hidden_size, scale_shift=0.0) self.post_layer_scale = config.conf_residual_weight def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index a62012bad943..74ace8c403f1 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -28,9 +28,6 @@ PreTrainedModel, ) -from sglang.srt.models.gemma4_audio import Gemma4AudioEncoder -from sglang.srt.models.gemma4_vision import Gemma4VisionEncoder - from sglang.srt.layers.layernorm import Gemma4RMSNorm from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.logits_processor import LogitsProcessor @@ -50,7 +47,9 @@ default_weight_loader, maybe_remap_kv_scale_name, ) +from sglang.srt.models.gemma4_audio import Gemma4AudioEncoder from sglang.srt.models.gemma4_causal import Gemma4TextModel +from sglang.srt.models.gemma4_vision import Gemma4VisionEncoder from 
sglang.srt.utils import add_prefix from sglang.srt.utils.hf_transformers_utils import get_processor @@ -280,10 +279,14 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: if self.audio_tower is None: - raise ValueError("Audio inputs provided but the model does not have an audio tower.") + raise ValueError( + "Audio inputs provided but the model does not have an audio tower." + ) all_input_features = flatten_nested_list([item.feature for item in items]) - all_input_features_mask = flatten_nested_list([~item.input_features_mask for item in items]) + all_input_features_mask = flatten_nested_list( + [~item.input_features_mask for item in items] + ) all_embeds = [] for input_features, input_features_mask in zip( @@ -396,9 +399,7 @@ def tie_weights(self, recompute_mapping=False): r"(.+\.(?:self_attn|attn))\.(q_proj|k_proj|v_proj)\.(.*)" ) # Regex for fused GateUp in the vision tower MLP. 
- _RE_TOWER_GATE_UP = re.compile( - r"(.+\.mlp)\.(gate_proj|up_proj)\.(.*)" - ) + _RE_TOWER_GATE_UP = re.compile(r"(.+\.mlp)\.(gate_proj|up_proj)\.(.*)") @staticmethod def _remap_tower_name(name: str, params_dict: dict) -> str: diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index cb0a82d2938b..24988ca16dc3 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -21,18 +21,17 @@ from einops import rearrange from transformers import Gemma4VisionConfig -from sglang.srt.layers.attention.vision import QKV_BACKEND_IMPL, VisionAttention +from sglang.srt.layers.attention.vision import QKV_BACKEND_IMPL from sglang.srt.layers.clippable_linear import ( ClippableGateUpParallelLinear, ClippableQKVParallelLinear, ClippableRowParallelLinear, ) -from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size +from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.layernorm import Gemma4RMSNorm from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import add_prefix, get_device_capability, is_cuda - # --------------------------------------------------------------------------- # 2-D Multidimensional RoPE (matches HF Gemma4RotaryEmbedding for vision) # --------------------------------------------------------------------------- @@ -44,7 +43,9 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) -def _apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: +def _apply_rotary( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> torch.Tensor: return (x * cos) + (_rotate_half(x) * sin) @@ -76,7 +77,9 @@ def forward( dim_inv_freq = 1.0 / ( self.rope_theta ** ( - torch.arange(0, head_dim_per_dim, 2, device=x.device, dtype=torch.float) + torch.arange( + 0, head_dim_per_dim, 2, device=x.device, dtype=torch.float + ) / 
head_dim_per_dim ) ) @@ -111,7 +114,9 @@ def _apply_multidimensional_rope( x_parts = x.split(chunk_size, dim=-1) cos_parts = cos.split(chunk_size, dim=-1) sin_parts = sin.split(chunk_size, dim=-1) - y_parts = [_apply_rotary(x_parts[k], cos_parts[k], sin_parts[k]) for k in range(ndim)] + y_parts = [ + _apply_rotary(x_parts[k], cos_parts[k], sin_parts[k]) for k in range(ndim) + ] return torch.cat(y_parts, dim=-1) @@ -225,9 +230,12 @@ def forward( attn_mask_4d = None output = self.qkv_backend.forward( - q=q, k=k, v=v, + q=q, + k=k, + v=v, cu_seqlens=None, - bsz=bsz, seq_len=seq_len, + bsz=bsz, + seq_len=seq_len, attention_mask=attn_mask_4d, softmax_scale=1.0, ) @@ -292,11 +300,13 @@ def __init__( ): super().__init__() self.self_attn = Gemma4VisionAttention( - config, quant_config=quant_config, + config, + quant_config=quant_config, prefix=add_prefix("self_attn", prefix), ) self.mlp = Gemma4VisionMLP( - config, quant_config=quant_config, + config, + quant_config=quant_config, prefix=add_prefix("mlp", prefix), ) eps = config.rms_norm_eps @@ -346,13 +356,17 @@ def __init__( super().__init__() self.config = config self.rotary_emb = Gemma4VisionRotaryEmbedding(config) - self.layers = nn.ModuleList([ - Gemma4VisionEncoderLayer( - config, layer_idx=i, quant_config=quant_config, - prefix=add_prefix(f"layers.{i}", prefix), - ) - for i in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + Gemma4VisionEncoderLayer( + config, + layer_idx=i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), + ) + for i in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -387,7 +401,9 @@ def __init__(self, config: Gemma4VisionConfig): self.hidden_size = config.hidden_size self.position_embedding_size = config.position_embedding_size - self.input_proj = nn.Linear(3 * self.patch_size ** 2, self.hidden_size, bias=False) + self.input_proj = nn.Linear( + 3 * self.patch_size**2, self.hidden_size, bias=False + ) self.position_embedding_table = 
nn.Parameter( torch.ones(2, self.position_embedding_size, self.hidden_size) ) @@ -399,24 +415,46 @@ def _position_embeddings( one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table) position_embeddings = one_hot @ self.position_embedding_table position_embeddings = position_embeddings.sum(dim=1) - position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings) + position_embeddings = torch.where( + padding_positions.unsqueeze(-1), 0.0, position_embeddings + ) return position_embeddings def _patchify(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape patch_height = height // self.patch_size patch_width = width // self.patch_size - patchified_shape = (batch_size, num_channels, patch_height, self.patch_size, patch_width, self.patch_size) - consolidated_shape = (batch_size, patch_height * patch_width, num_channels * self.patch_size ** 2) - patches = pixel_values.reshape(patchified_shape).permute(0, 2, 4, 3, 5, 1).reshape(consolidated_shape) + patchified_shape = ( + batch_size, + num_channels, + patch_height, + self.patch_size, + patch_width, + self.patch_size, + ) + consolidated_shape = ( + batch_size, + patch_height * patch_width, + num_channels * self.patch_size**2, + ) + patches = ( + pixel_values.reshape(patchified_shape) + .permute(0, 2, 4, 3, 5, 1) + .reshape(consolidated_shape) + ) patches = 2 * (patches - 0.5) return self.input_proj(patches.to(self.input_proj.weight.dtype)) def forward( - self, pixel_values: torch.Tensor, patch_positions: torch.Tensor, padding_positions: torch.Tensor + self, + pixel_values: torch.Tensor, + patch_positions: torch.Tensor, + padding_positions: torch.Tensor, ) -> torch.Tensor: hidden_states = self._patchify(pixel_values) - position_embeddings = self._position_embeddings(patch_positions, padding_positions) + position_embeddings = self._position_embeddings( + patch_positions, padding_positions + ) return hidden_states + 
position_embeddings @@ -430,14 +468,14 @@ def __init__(self, config: Gemma4VisionConfig): super().__init__() self.hidden_size = config.hidden_size self.default_output_length = config.default_output_length - self.root_hidden_size = self.hidden_size ** 0.5 + self.root_hidden_size = self.hidden_size**0.5 def _avg_pool_by_positions( self, x: torch.Tensor, patch_positions: torch.Tensor, length: int ) -> Tuple[torch.Tensor, torch.Tensor]: input_seq_len = x.shape[1] k = int((input_seq_len // length) ** 0.5) - k_squared = k ** 2 + k_squared = k**2 if k_squared * length != input_seq_len: raise ValueError( f"Cannot pool {x.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}." @@ -493,11 +531,12 @@ def __init__( self.patch_size = config.patch_size self.pooling_kernel_size = config.pooling_kernel_size self.default_output_length = config.default_output_length - self.max_patches = self.default_output_length * self.pooling_kernel_size ** 2 + self.max_patches = self.default_output_length * self.pooling_kernel_size**2 self.patch_embedder = Gemma4VisionPatchEmbedder(config) self.encoder = Gemma4VisionTransformer( - config, quant_config=quant_config, + config, + quant_config=quant_config, prefix=add_prefix("encoder", prefix), ) self.pooler = Gemma4VisionPooler(config) @@ -526,7 +565,9 @@ def _patch_positions( indexing="xy", ) stacked_grid = torch.stack(patch_grid, dim=-1) - real_positions = stacked_grid.reshape(num_patches, 2).unsqueeze(0).repeat(batch_size, 1, 1) + real_positions = ( + stacked_grid.reshape(num_patches, 2).unsqueeze(0).repeat(batch_size, 1, 1) + ) if num_padding > 0: pad_positions = torch.full( @@ -536,7 +577,9 @@ def _patch_positions( else: patch_positions = real_positions - padding_positions = torch.zeros(batch_size, self.max_patches, device=device, dtype=torch.bool) + padding_positions = torch.zeros( + batch_size, self.max_patches, device=device, dtype=torch.bool + ) if num_padding > 0: padding_positions[:, num_patches:] = True @@ -564,8 +607,11 @@ 
def forward(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso num_padding = self.max_patches - num_real if num_padding > 0: pad_embeds = torch.zeros( - inputs_embeds.shape[0], num_padding, inputs_embeds.shape[2], - device=inputs_embeds.device, dtype=inputs_embeds.dtype, + inputs_embeds.shape[0], + num_padding, + inputs_embeds.shape[2], + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, ) inputs_embeds = torch.cat([inputs_embeds, pad_embeds], dim=1) @@ -575,5 +621,7 @@ def forward(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso patch_positions=patch_positions, ) - pooled, pooler_mask = self.pooler(last_hidden, patch_positions, padding_positions) + pooled, pooler_mask = self.pooler( + last_hidden, patch_positions, padding_positions + ) return pooled, pooler_mask diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index 861197b8347c..efea548680ba 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -22,6 +22,7 @@ from sglang.srt.models.gemma4_mm import Gemma4ForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens + class Gemma4SGLangProcessor(SGLangBaseProcessor): """Multimodal processor for Gemma4 supporting image and audio inputs.""" @@ -55,7 +56,9 @@ def _get_audio_pad_multiple(self) -> int: first_stride = ac.sscp_conv_stride_size[0][0] if ac is not None else 2 return hop * first_stride - def process_mm_data(self, input_text, images=None, videos=None, audios=None, **kwargs): + def process_mm_data( + self, input_text, images=None, videos=None, audios=None, **kwargs + ): if audios: pad_multiple = self._get_audio_pad_multiple() padded = [] diff --git a/test/manual/test_vlm_accuracy.py b/test/manual/test_vlm_accuracy.py index cb759115cd53..fde2579a1493 100644 --- a/test/manual/test_vlm_accuracy.py +++ b/test/manual/test_vlm_accuracy.py @@ 
-423,18 +423,12 @@ def test_vision_encoder(self): # HF: last_hidden_state is [1, num_real_tokens, hidden] (padding stripped) hf_out = self.hf_vision_tower(pixel_values) hf_tokens = hf_out.last_hidden_state.squeeze(0) - hf_projected = self.hf_embed_vision( - hf_tokens.unsqueeze(0) - ).squeeze(0) + hf_projected = self.hf_embed_vision(hf_tokens.unsqueeze(0)).squeeze(0) # SGLang: returns (pooled, pooler_mask) with mask True = valid sg_pooled, sg_mask = self.sg_model.vision_tower(pixel_values) - sg_tokens = torch.cat( - [hs[m] for hs, m in zip(sg_pooled, sg_mask)] - ) - sg_projected = self.sg_model.embed_vision( - sg_tokens.unsqueeze(0) - ).squeeze(0) + sg_tokens = torch.cat([hs[m] for hs, m in zip(sg_pooled, sg_mask)]) + sg_projected = self.sg_model.embed_vision(sg_tokens.unsqueeze(0)).squeeze(0) self.assertEqual(hf_tokens.shape, sg_tokens.shape) print() @@ -463,9 +457,7 @@ def test_audio_encoder(self): hf_valid = hf_enc[hf_valid_mask.unsqueeze(-1).expand_as(hf_enc)].reshape( -1, hf_enc.shape[-1] ) - hf_projected = self.hf_embed_audio( - hf_valid.unsqueeze(0) - ).squeeze(0) + hf_projected = self.hf_embed_audio(hf_valid.unsqueeze(0)).squeeze(0) # SGLang: returns (encodings, mask) — zero-fills padding positions sg_enc, sg_mask = self.sg_model.audio_tower(audio_mel, audio_mel_mask) @@ -473,9 +465,7 @@ def test_audio_encoder(self): sg_valid = sg_enc[sg_valid_mask.unsqueeze(-1).expand_as(sg_enc)].reshape( -1, sg_enc.shape[-1] ) - sg_projected = self.sg_model.embed_audio( - sg_valid.unsqueeze(0) - ).squeeze(0) + sg_projected = self.sg_model.embed_audio(sg_valid.unsqueeze(0)).squeeze(0) self.assertEqual(hf_valid.shape, sg_valid.shape) print() @@ -615,18 +605,24 @@ def setUpClass(cls): with torch.no_grad(): hf_enc, hf_mask = hf_audio_tower(audio_mel, audio_mel_mask) hf_valid_mask = ~hf_mask - cls.hf_audio_valid = hf_enc[ - hf_valid_mask.unsqueeze(-1).expand_as(hf_enc) - ].reshape(-1, hf_enc.shape[-1]).cpu() - cls.hf_audio_proj = hf_embed_audio( - 
cls.hf_audio_valid.unsqueeze(0).to(cls.device) - ).squeeze(0).cpu() + cls.hf_audio_valid = ( + hf_enc[hf_valid_mask.unsqueeze(-1).expand_as(hf_enc)] + .reshape(-1, hf_enc.shape[-1]) + .cpu() + ) + cls.hf_audio_proj = ( + hf_embed_audio(cls.hf_audio_valid.unsqueeze(0).to(cls.device)) + .squeeze(0) + .cpu() + ) hf_vis_out = hf_vision_tower(pixel_values) cls.hf_vis_tokens = hf_vis_out.last_hidden_state.squeeze(0).cpu() - cls.hf_vis_proj = hf_embed_vision( - cls.hf_vis_tokens.unsqueeze(0).to(cls.device) - ).squeeze(0).cpu() + cls.hf_vis_proj = ( + hf_embed_vision(cls.hf_vis_tokens.unsqueeze(0).to(cls.device)) + .squeeze(0) + .cpu() + ) del hf_audio_tower, hf_embed_audio, hf_vision_tower, hf_embed_vision import gc From f6b9759cc2828770e1447ac61e2ea3b2be3b9630 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 17 Mar 2026 21:32:37 +0000 Subject: [PATCH 026/112] nit --- python/sglang/srt/mem_cache/memory_pool.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 9943e4715b05..2a6bd9e4c81e 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -754,20 +754,6 @@ def __init__( ) self.head_num = swa_head_num if swa_head_num is not None else head_num self.head_dim = swa_head_dim if swa_head_dim is not None else head_dim - print( - "head_num: ", - self.head_num, - "head_dim: ", - self.head_dim, - "swa_head_num: ", - swa_head_num, - "swa_head_dim: ", - swa_head_dim, - "head_num: ", - head_num, - "head_dim: ", - head_dim, - ) self.v_head_dim = ( swa_v_head_dim if swa_v_head_dim is not None @@ -846,9 +832,6 @@ def _create_buffers(self): if self.enable_custom_mem_pool else nullcontext() ): - print( - f"Allocating KV cache buffers with size {self.size}, page_size {self.page_size}, head_num {self.head_num}, head_dim {self.head_dim}, v_head_dim {self.v_head_dim}, dtype {self.store_dtype}, device {self.device}" - ) # [size, 
head_num, head_dim] for each layer # The padded slot 0 is used for writing dummy outputs from padded tokens. # adjust for global From 879aaefd29cf7f9658a98ba1e361d6a42a04570e Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 17 Mar 2026 22:07:39 +0000 Subject: [PATCH 027/112] Fix layer_scalar to apply unconditionally on all decoder layers The HF reference applies layer_scalar to every Gemma4DecoderLayer, not just full-attention layers. New checkpoints have non-trivial scalar values on SWA layers that were being silently ignored. Made-with: Cursor --- python/sglang/srt/models/gemma4_causal.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 1219e76bb0d9..eb258758089c 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -456,9 +456,7 @@ def __init__( self.per_layer_projection = None self.post_per_layer_input_norm = None - # Layer scalar for full-attention layers only - if self.is_full_attention: - self.register_buffer("layer_scalar", torch.ones(1), persistent=True) + self.register_buffer("layer_scalar", torch.ones(1), persistent=True) self.prefix = prefix def forward( @@ -510,9 +508,7 @@ def forward( ) hidden_states = hidden_states + per_layer_contribution - # Apply layer scalar for full-attention layers - if self.is_full_attention and hasattr(self, "layer_scalar"): - hidden_states = hidden_states * self.layer_scalar + hidden_states = hidden_states * self.layer_scalar return hidden_states, None From e1f8f6100141d5c512b7194132447bca3e732b2f Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 17 Mar 2026 16:54:24 -0700 Subject: [PATCH 028/112] Clarify SWA attn_logits buffer condition in triton backend Gate the two-buffer path on sliding_window_size to make intent explicit, and rewrite comment to explain the kernel's // Lv stride constraint. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../sglang/srt/layers/attention/triton_backend.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 3cba295b1f59..198ebed4f6ec 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -97,14 +97,15 @@ def __init__( self.num_kv_head = model_runner.model_config.get_num_kv_heads( get_attention_tp_size() ) - # The decode kernel's intermediate attn_logits buffer must match the - # exact v_head_dim of the layer being processed (the triton kernel uses - # a // Lv stride trick to derive attn_lse indices from attn_logits strides). - # When SWA and full attention layers have different v_head_dim (e.g. Gemma 4 - # with swa=256, full=512), we need two separate attn_logits buffers. + # The decode triton kernel derives attn_lse offsets from attn_logits + # strides via integer division by v_head_dim (the "// Lv" trick in + # _fwd_kernel_stage1/stage2), so attn_logits.shape[-1] must exactly + # match the layer's v_head_dim. For hybrid SWA models where SWA and + # full-attention layers use different v_head_dim (e.g. Gemma 4: + # swa=256, full=512), we allocate a second buffer for SWA layers. 
full_v_head_dim = model_runner.model_config.v_head_dim swa_v_head_dim = model_runner.model_config.swa_v_head_dim - if swa_v_head_dim != full_v_head_dim: + if self.sliding_window_size is not None and swa_v_head_dim != full_v_head_dim: self.v_head_dim = full_v_head_dim self.swa_v_head_dim = swa_v_head_dim elif ( From 702a55e3b5af84dddfc070204652bfba7951dad5 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 16 Mar 2026 22:36:36 +0000 Subject: [PATCH 029/112] initial dense 31b support --- python/sglang/srt/models/gemma4_causal.py | 93 +++++++++++-------- python/sglang/srt/models/gemma4_mm.py | 27 ++++++ .../sglang/srt/utils/hf_transformers_utils.py | 28 +++--- 3 files changed, 99 insertions(+), 49 deletions(-) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index eb258758089c..13bf81944a8c 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -13,6 +13,7 @@ # ============================================================================== import logging +import re from typing import Iterable, Optional, Set, Tuple import torch @@ -203,10 +204,19 @@ def __init__( self.config = config tp_size = get_tensor_model_parallel_world_size() + layer_type = config.layer_types[layer_id] + self.sliding_window = config.sliding_window if layer_type == "sliding_attention" else None + self.total_num_heads = config.num_attention_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = config.num_key_value_heads + + if layer_type == "sliding_attention": + self.total_num_kv_heads = getattr( + config, "swa_num_key_value_heads", config.num_key_value_heads + ) + else: + self.total_num_kv_heads = config.num_key_value_heads self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) @@ -250,45 +260,32 @@ def __init__( self.head_dim, eps=config.rms_norm_eps, scale_shift=0.0, with_scale=False ) - # Determine if layer uses 
sliding window based on pattern - layer_type = config.layer_types[layer_id] - self.is_sliding = layer_type == "sliding_attention" - self.sliding_window = config.sliding_window if self.is_sliding else None - - # Initialize the rotary embedding based on layer type. - # Gemma 4 uses different RoPE parameters for sliding vs full attention. if layer_type in config.rope_parameters: rope_parameters = dict(config.rope_parameters[layer_type]) - # Fix: Use global_partial_rotary_factor for full_attention layers - # JAX reference uses global_rope_proportion=0.25 for global attention if layer_type == "full_attention": global_prf = getattr(config, "global_partial_rotary_factor", 0.25) rope_parameters["partial_rotary_factor"] = global_prf else: - # Fallback for older config format rope_parameters = dict( rope_type="default", rope_theta=getattr(config, "rope_theta", 10000.0), ) - # Check if this is a KV shared layer + # KV sharing logic + num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0) first_kv_shared_layer_idx = ( - config.num_hidden_layers - config.num_kv_shared_layers + config.num_hidden_layers - num_kv_shared_layers ) - self.is_kv_shared_layer = layer_id >= first_kv_shared_layer_idx + self.is_kv_shared_layer = layer_id >= first_kv_shared_layer_idx and num_kv_shared_layers > 0 + + self.kv_shared_layer_index = None + if num_kv_shared_layers > 0 and self.layer_id >= first_kv_shared_layer_idx: + prev_layers = config.layer_types[:first_kv_shared_layer_idx] + current_layer_type = config.layer_types[self.layer_id] + self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index( + current_layer_type + ) - # KV sharing logic for Gemma 4 - # kv_sharing_target_layer_name = None - num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0) - if num_kv_shared_layers > 0: - first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers - if self.layer_id >= first_kv_shared_layer_idx: - # Find the last non-shared layer of the same type 
(sliding/full) - prev_layers = config.layer_types[:first_kv_shared_layer_idx] - current_layer_type = config.layer_types[self.layer_id] - self.kv_shared_layer_index = ( - len(prev_layers) - 1 - prev_layers[::-1].index(current_layer_type) - ) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, @@ -378,8 +375,8 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size self.hidden_size_per_layer_input = getattr( - config, "hidden_size_per_layer_input", 0 - ) + config, "hidden_size_per_layer_input", None + ) or 0 self.layer_id = layer_id @@ -535,11 +532,11 @@ def __init__( # Per-layer input embeddings self.hidden_size = config.hidden_size self.hidden_size_per_layer_input = getattr( - config, "hidden_size_per_layer_input", 0 - ) + config, "hidden_size_per_layer_input", None + ) or 0 self.vocab_size_per_layer_input = getattr( - config, "vocab_size_per_layer_input", config.vocab_size - ) + config, "vocab_size_per_layer_input", None + ) or config.vocab_size if self.hidden_size_per_layer_input > 0: self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding( @@ -819,6 +816,16 @@ def forward( input_ids, hidden_states, self.lm_head, forward_batch ) + def _get_k_eq_v_layers(self) -> set: + """Return set of layer indices where attention_k_eq_v applies (full-attention layers).""" + if not getattr(self.config, "attention_k_eq_v", False): + return set() + return { + i + for i, lt in enumerate(self.config.layer_types) + if lt != "sliding_attention" + } + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -828,34 +835,44 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] + k_eq_v_layers = self._get_k_eq_v_layers() + params_dict = dict(self.named_parameters()) params_dict.update(dict(self.named_buffers())) loaded_params: Set[str] = set() for name, loaded_weight in weights: name = 
name.replace("model.language_model.", "model.") + + # attention_k_eq_v: full-attention layers have no v_proj in the + # checkpoint (K and V share weights). When we see a k_proj weight + # for one of these layers, load it into both the "k" and "v" shards + # of the fused QKV so the forward produces v_raw == k_raw. + should_dup_k_to_v = ( + ".k_proj." in name + and k_eq_v_layers + and (m := re.search(r"layers\.(\d+)\.", name)) is not None + and int(m.group(1)) in k_eq_v_layers + ) + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue if name not in params_dict: - # Skip loading weights that are not in the model continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) + if should_dup_k_to_v: + weight_loader(param, loaded_weight, "v") break else: - # lm_head is not used in vllm as it is tied with embed_token. - # To prevent errors, skip loading lm_head.weight. if "lm_head.weight" in name and self.config.tie_word_embeddings: continue - # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Remapping the name of FP8 kv-scale. 
name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 74ace8c403f1..1c506a83acce 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -454,7 +454,20 @@ def _remap_tower_name(name: str, params_dict: dict) -> str: return name + def _get_k_eq_v_layers(self) -> set: + """Return set of layer indices where attention_k_eq_v applies (full-attention layers).""" + text_config = self.config.text_config + if not getattr(text_config, "attention_k_eq_v", False): + return set() + return { + i + for i, lt in enumerate(text_config.layer_types) + if lt != "sliding_attention" + } + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + k_eq_v_layers = self._get_k_eq_v_layers() + params_dict = dict(self.named_parameters()) params_dict.update(dict(self.named_buffers())) loaded_params: Set[str] = set() @@ -473,6 +486,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if "vision_tower." in name or "audio_tower." in name: name = self._remap_tower_name(name, params_dict) + # attention_k_eq_v: full-attention layers have no v_proj in the + # checkpoint (K and V share weights). When we see a k_proj weight + # for one of these layers, load it into both the "k" and "v" shards + # of the fused QKV so the forward produces v_raw == k_raw. + should_dup_k_to_v = ( + ".k_proj." in name + and k_eq_v_layers + and "language_model." 
in name + and (m := re.search(r"layers\.(\d+)\.", name)) is not None + and int(m.group(1)) in k_eq_v_layers + ) + # Try stacked (fused) params first orig_name = name for param_name, weight_name, shard_id in self.stacked_params_mapping: @@ -485,6 +510,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) + if should_dup_k_to_v: + weight_loader(param, loaded_weight, "v") break else: if name.endswith(".bias") and name not in params_dict: diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 367b13b3314a..3d3ad13ea072 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -398,19 +398,25 @@ def get_config( config.update({"architectures": ["MultiModalityCausalLM"]}) if config.model_type == "gemma4": - global_head_dim = getattr(config.text_config, "global_head_dim", None) - num_global_key_value_heads = getattr( - config.text_config, "num_global_key_value_heads", None - ) + # Gemma4 configs use base attributes for SWA layers and `global_*` + # variants for full-attention layers. SGLang expects the opposite: + # base = full-attention, `swa_*` = sliding-window overrides. + # Remap here so the rest of the stack sees a uniform convention. 
+ text_config = config.text_config + global_head_dim = getattr(text_config, "global_head_dim", None) + global_kv_heads = getattr(text_config, "num_global_key_value_heads", None) - if global_head_dim is not None: - config.text_config.swa_head_dim = config.text_config.head_dim - config.text_config.swa_v_head_dim = config.text_config.head_dim - config.text_config.head_dim = global_head_dim + swa_head_dim = text_config.head_dim + swa_kv_heads = text_config.num_key_value_heads - config.text_config.swa_num_key_value_heads = config.num_key_value_heads - if num_global_key_value_heads is not None: - config.text_config.num_key_value_heads = num_global_key_value_heads + text_config.swa_head_dim = swa_head_dim + text_config.swa_v_head_dim = swa_head_dim + text_config.swa_num_key_value_heads = swa_kv_heads + + if global_head_dim is not None: + text_config.head_dim = global_head_dim + if global_kv_heads is not None: + text_config.num_key_value_heads = global_kv_heads if config.model_type == "longcat_flash": config.update({"architectures": ["LongcatFlashForCausalLM"]}) From 808860ebefe3519f75f1f5b2e1893d8daa51beae Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Wed, 18 Mar 2026 23:45:42 +0000 Subject: [PATCH 030/112] custom bidirectional mask for image tokens --- python/sglang/srt/models/gemma4_causal.py | 2 +- python/sglang/srt/models/gemma4_mm.py | 98 ++++++++++++++++++++++- 2 files changed, 97 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 13bf81944a8c..a546a2ba4de1 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -823,7 +823,7 @@ def _get_k_eq_v_layers(self) -> set: return { i for i, lt in enumerate(self.config.layer_types) - if lt != "sliding_attention" + if lt == "full_attention" } def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gemma4_mm.py 
b/python/sglang/srt/models/gemma4_mm.py index 1c506a83acce..c2f200c927a9 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -28,6 +28,7 @@ PreTrainedModel, ) +from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.layernorm import Gemma4RMSNorm from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.logits_processor import LogitsProcessor @@ -42,7 +43,8 @@ MultimodalInputs, flatten_nested_list, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.server_args import get_global_server_args from sglang.srt.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, @@ -237,6 +239,84 @@ def get_input_embeddings(self) -> nn.Embedding: def get_attention_sliding_window_size(self): return getattr(self.config.text_config, "sliding_window", -1) - 1 + def prepare_attn_masks( + self, + forward_batch: ForwardBatch, + input_ids: torch.Tensor, + mask_dtype: torch.dtype, + ): + """Prepare bidirectional attention masks for image tokens. + + Gemma 4 uses bidirectional attention for image soft tokens during prefill. + Following the HF implementation, bidirectional attention is only enabled + within each individual image group (same-image tokens), not across images. + Currently only the TritonAttnBackend supports this. + + TODO(kpham-sgl): Guard appropriately for gemma3_mm.py:prepare_attn_masks() + """ + if not isinstance(forward_batch.attn_backend, TritonAttnBackend): + logger.warning_once( + "Bidirectional attention for image tokens requires TritonAttnBackend. " + "Falling back to causal attention, which may degrade image quality." + ) + return + if get_global_server_args().chunked_prefill_size != -1: + logger.warning_once( + "Bidirectional attention for image tokens is not supported with chunked prefill. 
" + "Image token spans split across chunk boundaries will receive causal attention. " + "Disable chunked prefill (--chunked-prefill-size=-1) to fix this." + ) + return + assert forward_batch.forward_mode == ForwardMode.EXTEND + + bidirectional_attn_masks_list = [] + bidirectional_attn_mask_indptr = torch.zeros( + forward_batch.batch_size + 1, dtype=torch.int32, device=input_ids.device + ) + + for i in range(forward_batch.batch_size): + extend_seq_len = forward_batch.extend_seq_lens[i] + prefix_len = forward_batch.extend_prefix_lens[i] + bidirectional_attn_mask = torch.zeros( + extend_seq_len, + extend_seq_len + prefix_len, + dtype=mask_dtype, + device=input_ids.device, + ) + # Start with causal mask + bidirectional_attn_mask.fill_(1) + bidirectional_attn_mask = bidirectional_attn_mask.tril( + diagonal=prefix_len + ) + + # Enable bidirectional attention within each image group + mm_inputs = forward_batch.mm_inputs[i] + if mm_inputs is not None: + for mm_item in mm_inputs.mm_items: + if mm_item.is_image(): + for im_begin, im_end in mm_item.offsets: + # Only handle image tokens in the extend portion + # (compatible with radix cache) + if im_begin >= prefix_len: + bidirectional_attn_mask[ + im_begin - prefix_len : im_end + 1 - prefix_len, + im_begin : im_end + 1, + ] = 1 + + bidirectional_attn_masks_list.append(bidirectional_attn_mask.flatten()) + bidirectional_attn_mask_indptr[i + 1] = ( + bidirectional_attn_mask_indptr[i] + bidirectional_attn_mask.nelement() + ) + + if bidirectional_attn_masks_list: + bidirectional_attn_masks = torch.cat(bidirectional_attn_masks_list, dim=0) + forward_batch.attn_backend.forward_metadata.mask_indptr = ( + bidirectional_attn_mask_indptr + ) + forward_batch.attn_backend.forward_metadata.custom_mask = ( + bidirectional_attn_masks + ) + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: all_pixel_values = flatten_nested_list([item.feature for item in items]) vt = self.vision_tower @@ -360,6 +440,20 @@ def 
forward( ple_ids[input_ids == self.config.audio_token_id] = 0 per_layer_inputs = self.get_per_layer_inputs(ple_ids) + # Prepare bidirectional attention masks for image tokens during prefill. + # Gemma 4 uses bidirectional attention for image soft tokens. + # Only TritonAttnBackend supports this; incompatible with CUDA Graph and + # chunked prefill. + if ( + forward_batch.forward_mode == ForwardMode.EXTEND + and forward_batch.contains_image_inputs() + ): + self.prepare_attn_masks( + forward_batch, + input_ids, + mask_dtype=torch.bool, + ) + # Use general_mm_embed_routine for handling multimodal data hidden_states = general_mm_embed_routine( input_ids=input_ids, @@ -462,7 +556,7 @@ def _get_k_eq_v_layers(self) -> set: return { i for i, lt in enumerate(text_config.layer_types) - if lt != "sliding_attention" + if lt == "full_attention" } def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): From b8a8323fec3fe21985c13dae72ade52acde390de Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Thu, 19 Mar 2026 21:23:19 +0000 Subject: [PATCH 031/112] canonical warning for chunked prefill + bidirectional mask --- python/sglang/srt/models/gemma4_mm.py | 35 ++++++++++++++++++--------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index c2f200c927a9..86dcf9e6ec73 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -260,13 +260,6 @@ def prepare_attn_masks( "Falling back to causal attention, which may degrade image quality." ) return - if get_global_server_args().chunked_prefill_size != -1: - logger.warning_once( - "Bidirectional attention for image tokens is not supported with chunked prefill. " - "Image token spans split across chunk boundaries will receive causal attention. " - "Disable chunked prefill (--chunked-prefill-size=-1) to fix this." 
- ) - return assert forward_batch.forward_mode == ForwardMode.EXTEND bidirectional_attn_masks_list = [] @@ -274,6 +267,8 @@ def prepare_attn_masks( forward_batch.batch_size + 1, dtype=torch.int32, device=input_ids.device ) + split_images = [] + for i in range(forward_batch.batch_size): extend_seq_len = forward_batch.extend_seq_lens[i] prefix_len = forward_batch.extend_prefix_lens[i] @@ -295,19 +290,37 @@ def prepare_attn_masks( for mm_item in mm_inputs.mm_items: if mm_item.is_image(): for im_begin, im_end in mm_item.offsets: - # Only handle image tokens in the extend portion - # (compatible with radix cache) - if im_begin >= prefix_len: + # Note(kpham-sgl): We only apply bidirectional attention when the image token span + # is fully contained in the extend window. Otherwise, we silently fall back to + # causal attention. + # FIXME(kpham-sgl): This is a hack to work around the fact that the image token span + # might not be fully contained in the extend window during chunked prefill. + # We should fix this by properly making chunked prefill mask aware. + if im_begin >= prefix_len and im_end < prefix_len + extend_seq_len: bidirectional_attn_mask[ im_begin - prefix_len : im_end + 1 - prefix_len, im_begin : im_end + 1, ] = 1 + elif im_end >= prefix_len and im_begin < prefix_len + extend_seq_len: + split_images.append((i, im_begin, im_end)) bidirectional_attn_masks_list.append(bidirectional_attn_mask.flatten()) bidirectional_attn_mask_indptr[i + 1] = ( bidirectional_attn_mask_indptr[i] + bidirectional_attn_mask.nelement() ) - + if split_images: + num_split_images = len(split_images) + logger.warning_once( + f"{num_split_images} images are split across chunk boundaries. " + "Below are the first 5 images that are split across chunk boundaries: " + ) + for i, im_begin, im_end in split_images[:5]: + logger.warning_once( + f"Image {i}:{im_begin}-{im_end} is split across chunk boundaries.\n", + ) + logger.warning_once( + "Those images will receive causal attention. 
Disable chunked prefill (--chunked-prefill-size=-1) for full bidirectional attention.", + ) if bidirectional_attn_masks_list: bidirectional_attn_masks = torch.cat(bidirectional_attn_masks_list, dim=0) forward_batch.attn_backend.forward_metadata.mask_indptr = ( From 1d5117d38e86aff7bcbe76968c186f2e9f3798dc Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Tue, 17 Mar 2026 04:30:08 +0000 Subject: [PATCH 032/112] gemma4 moe --- .../srt/layers/moe/fused_moe_triton/layer.py | 25 ++ python/sglang/srt/mem_cache/memory_pool.py | 3 +- .../model_runner_kv_cache_mixin.py | 7 +- python/sglang/srt/models/gemma4_causal.py | 290 +++++++++++++++--- python/sglang/srt/models/gemma4_mm.py | 62 +++- .../sglang/srt/utils/hf_transformers_utils.py | 5 + 6 files changed, 331 insertions(+), 61 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 8da7d8eef330..071621872e97 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1059,6 +1059,31 @@ def make_expert_params_mapping_fused( ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"), ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"), ] + + @classmethod + def make_expert_params_mapping_gemma4( + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + ): + return [ + # (param_name, weight_name, shard_id) + ( + ( + "experts.w13_weight" + if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] + else "experts.w2_weight" + ), + f"experts.{weight_name}", + shard_id, + ) + for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] @classmethod def make_expert_params_mapping_fused_mxfp4( diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 2a6bd9e4c81e..a7587e8b42ae 100644 --- 
a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -978,14 +978,13 @@ def get_kv_buffer(self, layer_id: int): def set_kv_buffer( self, - layer: Optional[RadixAttention], + layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, k_scale: Optional[float] = None, v_scale: Optional[float] = None, layer_id_override: Optional[int] = None, - row_dim: Optional[int] = None, ): if layer_id_override is not None: layer_id = layer_id_override diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index cc4b7339e2c5..577508b1fc23 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -548,8 +548,8 @@ def _init_pools(self: ModelRunner): // get_attention_tp_size(), ), "swa_head_dim": self.model_config.hf_text_config.swa_head_dim, - "swa_v_head_dim": self.model_config.hf_text_config.swa_head_dim, - "v_head_dim": self.model_config.hf_text_config.head_dim, + "swa_v_head_dim": self.model_config.hf_text_config.swa_v_head_dim, + "v_head_dim": self.model_config.hf_text_config.v_head_dim, } self.token_to_kv_pool = SWAKVPool( size=self.full_max_total_num_tokens, @@ -619,8 +619,6 @@ def _init_pools(self: ModelRunner): ), ) else: - # self.max_total_num_tokens = self.max_total_num_tokens // 2 if global_head_dim is not None else self.max_total_num_tokens - # print(f"global_head_dim: {global_head_dim}, head_dim: {self.model_config.head_dim}, head_num: {self.model_config.get_total_num_kv_heads()}, max_total_num_tokens: {self.max_total_num_tokens}") self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, page_size=self.page_size, @@ -628,7 +626,6 @@ def _init_pools(self: ModelRunner): head_num=self.model_config.get_num_kv_heads( get_attention_tp_size() ), - # head_dim=self.model_config.head_dim if global_head_dim is None else 
global_head_dim, head_dim=self.model_config.head_dim, layer_num=self.num_effective_layers, device=self.device, diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index a546a2ba4de1..d837e5afdd90 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -25,6 +25,10 @@ PreTrainedModel, ) +from sglang.srt.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.layernorm import Gemma4RMSNorm, GemmaRMSNorm, RMSNorm from sglang.srt.layers.linear import ( @@ -37,6 +41,19 @@ from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +<<<<<<< HEAD +======= +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.utils import add_prefix, make_layers +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.server_args import get_global_server_args +>>>>>>> 8aaf187f9 (gemma4 moe) from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -142,51 +159,138 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: return per_layer_input -class Gemma4MoEBLock(nn.Module): +class Gemma4Router(nn.Module): + """Router for Gemma4 MoE that preprocesses input before projection. + + Applies RMSNorm (no learned weight), root_size scaling + (hidden_size^{-0.5}), then a learned per-dimension scale before + projecting to expert logits. 
+ + This preprocessing is applied ONLY to the router's input, not to + the expert MLPs' input. + """ + def __init__( self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + # RMSNorm without learned weight — pure normalization only + self.norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps, with_scale=False) + # Per-dimension learned scale, applied after norm + root_size + self.scale = nn.Parameter(torch.ones(self.hidden_size)) + # Constant 1/sqrt(hidden_size) scaling factor + self.register_buffer( + "root_size", + torch.tensor(self.hidden_size**-0.5), + persistent=False, + ) + # Project to expert logits; replicated across TP for consistent routing + self.proj = ReplicatedLinear( + self.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=add_prefix("proj", prefix), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Returns raw router logits [T, E].""" + x = self.norm(x) + x = x * self.root_size.to(x.dtype) + x = x * self.scale.to(x.dtype) + router_logits, _ = self.proj(x) + return router_logits + + +class Gemma4MoE(nn.Module): + """Mixture of Experts for Gemma4. + + Wraps MoE implementation with custom routing. The router projection is + external (Gemma4Router) — this class only handles expert dispatch. + + Gemma4 routing: softmax over ALL experts → top-k → renormalize. + per_expert_scale is folded into routing weights for mathematical + correctness with MoE's fused kernel. 
+ """ + + def __init__( + self, + hidden_size: int, layer_id: int, config: Gemma4TextConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - ): + ) -> None: super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() self.layer_id = layer_id - self.activation = config.hidden_act + self.hidden_size = hidden_size + self.num_experts = config.num_experts + self.tp_size = get_tensor_model_parallel_world_size() + + # Per-expert output scale folded into routing weights so that + # MoE's fused kernel computes: Σ_e (expert_e * w_e * scale_e) + self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts)) + + # Gemma4 routing: softmax over ALL experts → top-k → renormalize. + per_expert_scale = self.per_expert_scale + + def routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + ) -> tuple[torch.Tensor, torch.Tensor]: + _, topk_ids = torch.topk(gating_output, k=topk, dim=-1) + router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1) + indicator = torch.nn.functional.one_hot( + topk_ids, num_classes=gating_output.size(-1) + ).sum(dim=-2) + gate_weights = indicator * router_probabilities + renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True) + renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0) + dispatch_weights = gate_weights / renorm_factor + + topk_weights = dispatch_weights.gather(1, topk_ids) + + # Fold per_expert_scale into routing weights + expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype) + topk_weights = topk_weights * expert_scales + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) self.topk = TopK( - top_k=config.num_experts_per_tok, - renormalize=True, + top_k=config.top_k_experts, layer_id=layer_id, + custom_routing_function=routing_function, ) - self.top_k = config.num_experts_per_tok + experts_type = get_moe_impl_class(quant_config) self.experts = experts_type( - 
num_experts=config.num_local_experts - + get_global_server_args().ep_num_redundant_experts, - top_k=config.num_experts_per_tok, - layer_id=layer_id, + num_experts=config.num_experts + get_global_server_args().ep_num_redundant_experts, hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, + intermediate_size=config.expert_intermediate_size, + layer_id=layer_id, + top_k=config.top_k_experts, quant_config=quant_config, - activation=self.activation, - gemm1_alpha=self.gemm1_alpha, - gemm1_clamp_limit=self.gemm1_clamp_limit, - with_bias=True, prefix=add_prefix("experts", prefix), + activation="gelu", + reduce_results=True, ) - self.router = ReplicatedLinear( - config.hidden_size, - config.num_local_experts, - bias=True, - quant_config=None, - prefix=add_prefix("gate", prefix), - params_dtype=config.torch_dtype, - ) + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor, should_allreduce_fusion: bool = False) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + topk_output = self.topk(hidden_states, router_logits) + hidden_states = self.experts(hidden_states, topk_output) + if self.tp_size > 1 and not should_allreduce_fusion: + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + return hidden_states.view(num_tokens, hidden_dim) class Gemma4Attention(nn.Module): def __init__( @@ -194,6 +298,7 @@ def __init__( layer_id: int, config: Gemma4TextConfig, head_dim: int, + total_num_kv_heads: int, max_position_embeddings: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -388,11 +493,21 @@ def __init__( else: head_dim = getattr(config, "swa_head_dim", config.head_dim) + if self.is_full_attention and getattr(config, "attention_k_eq_v", False): + total_num_kv_heads = getattr( + config, "num_global_key_value_heads", + config.num_key_value_heads + ) + else: + total_num_kv_heads = getattr( + config, "swa_num_key_value_heads", config.num_key_value_heads + ) self.self_attn = 
Gemma4Attention( layer_id=layer_id, config=config, max_position_embeddings=config.max_position_embeddings, head_dim=head_dim, + total_num_kv_heads=total_num_kv_heads, quant_config=quant_config, prefix=add_prefix("self_attn", prefix), ) @@ -428,7 +543,7 @@ def __init__( ) # Per-Layer Embedding (PLE) components — present in each decoder layer - if self.hidden_size_per_layer_input > 0: + if self.hidden_size_per_layer_input and self.hidden_size_per_layer_input > 0: # Gate: projects hidden_states → per-layer dim for gating self.per_layer_input_gate = ReplicatedLinear( self.hidden_size, @@ -454,6 +569,40 @@ def __init__( self.post_per_layer_input_norm = None self.register_buffer("layer_scalar", torch.ones(1), persistent=True) + # Parallel MoE + self.enable_moe_block = getattr(config, "enable_moe_block", False) or getattr( + config, "use_second_mlp_block", False + ) + if self.enable_moe_block: + self.router = Gemma4Router( + config, + quant_config=quant_config, + prefix=add_prefix("router", prefix), + ) + self.moe = Gemma4MoE( + hidden_size=self.hidden_size, + layer_id=layer_id, + config=config, + quant_config=quant_config, + prefix=add_prefix("moe", prefix), + ) + + self.post_feedforward_layernorm_1 = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm_2 = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm_2 = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.router = None + self.moe = None + self.post_feedforward_layernorm_1 = None + self.post_feedforward_layernorm_2 = None + self.pre_feedforward_layernorm_2 = None + self.prefix = prefix def forward( @@ -482,8 +631,27 @@ def forward( hidden_states = hidden_states + residual residual = hidden_states - hidden_states = self.pre_feedforward_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) + + if self.enable_moe_block: + # Dense MLP branch + hidden_states_1 = 
self.pre_feedforward_layernorm(hidden_states) + hidden_states_1 = self.mlp(hidden_states_1) + hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states_1) + + # MoE branch: router sees raw hidden_states (applies its own + # norm + scale internally); experts see separately normed input + router_logits = self.router(hidden_states) + hidden_states_2 = self.pre_feedforward_layernorm_2(hidden_states) + hidden_states_2 = self.moe(hidden_states_2, router_logits) + hidden_states_2 = self.moe_out(hidden_states_2) + hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2) + + # Combine branches + hidden_states = hidden_states_1 + hidden_states_2 + else: + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) hidden_states = hidden_states + residual @@ -538,7 +706,7 @@ def __init__( config, "vocab_size_per_layer_input", None ) or config.vocab_size - if self.hidden_size_per_layer_input > 0: + if self.hidden_size_per_layer_input and self.hidden_size_per_layer_input > 0: self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding( self.vocab_size_per_layer_input, config.num_hidden_layers * self.hidden_size_per_layer_input, @@ -835,6 +1003,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=getattr(self.config, "num_experts", 0), + ) + k_eq_v_layers = self._get_k_eq_v_layers() params_dict = dict(self.named_parameters()) @@ -854,6 +1030,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): and int(m.group(1)) in k_eq_v_layers ) + + if ( + ".moe." 
in name + and "experts" not in name + and "per_expert_scale" not in name + ): + name = name.replace(".moe.", ".moe.experts.") + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue @@ -869,20 +1053,41 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, "v") break else: - if "lm_head.weight" in name and self.config.tie_word_embeddings: - continue - if name.endswith(".bias") and name not in params_dict: - continue - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if name not in params_dict: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name in params_dict.keys(): + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + else: + raise KeyError( + f"Parameter '{name}' not found in model." 
+ ) + break + else: + if "lm_head.weight" in name and self.config.tie_word_embeddings: + continue + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) loaded_params.add(name) unloaded_params = params_dict.keys() - loaded_params if unloaded_params: @@ -893,4 +1098,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): EntryClass = Gemma4ForCausalLM -AutoModel.register(Gemma4TextConfig, Gemma4ForCausalLM, exist_ok=True) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 86dcf9e6ec73..1b468e8db4a6 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -55,6 +55,9 @@ from sglang.srt.utils import add_prefix from sglang.srt.utils.hf_transformers_utils import get_processor +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.server_args import get_global_server_args + logger = logging.getLogger(__name__) cached_get_processor = lru_cache(get_processor) @@ -575,6 +578,13 @@ def _get_k_eq_v_layers(self) -> set: def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): k_eq_v_layers = self._get_k_eq_v_layers() + expert_params_mapping = FusedMoE.make_expert_params_mapping_gemma4( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + ) + num_experts = self.config.num_experts + params_dict = dict(self.named_parameters()) params_dict.update(dict(self.named_buffers())) loaded_params: Set[str] = set() @@ -589,6 +599,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = re.sub(r"^model\.", "", name) + # moe experts + if ( + ".moe." 
in name + and "experts" not in name + and "per_expert_scale" not in name + ): + name = name.replace(".moe.", ".moe.experts.") + # Remap vision / audio tower names (fused QKV/GateUp, clippable wrappers) if "vision_tower." in name or "audio_tower." in name: name = self._remap_tower_name(name, params_dict) @@ -608,6 +626,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Try stacked (fused) params first orig_name = name for param_name, weight_name, shard_id in self.stacked_params_mapping: + name = orig_name + m = re.search(r"language_model.layers\.(\d+)\.", name) + if "k_proj" in name and "k_proj" in weight_name and k_eq_v_layer_indices and m and int(m.group(1)) in k_eq_v_layer_indices: + n = name.replace(weight_name, param_name) + param = params_dict[n] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, "v") + loaded_params.add(n) + if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -621,16 +648,29 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, "v") break else: - if name.endswith(".bias") and name not in params_dict: - continue - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + for param_name, weight_name, shard_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + for i in range(num_experts): + weight_loader(param, loaded_weight[i].T, name, shard_id, i) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: 
+ continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) loaded_params.add(name) unloaded_params = params_dict.keys() - loaded_params if unloaded_params: @@ -660,7 +700,7 @@ def get_hidden_dim(self, module_name, layer_idx): ) elif module_name == "o_proj": return ( - self.config.head_dim * self.config.num_attention_heads, + self.head_dim * self.num_attention_heads, self.config.hidden_size, ) elif module_name == "gate_up_proj": diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 3d3ad13ea072..38c046665381 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -418,6 +418,11 @@ def get_config( if global_kv_heads is not None: text_config.num_key_value_heads = global_kv_heads + if not hasattr(config.text_config, "v_head_dim"): + config.text_config.v_head_dim = config.text_config.head_dim + if not hasattr(config.text_config, "swa_v_head_dim"): + config.text_config.swa_v_head_dim = config.text_config.swa_head_dim + if config.model_type == "longcat_flash": config.update({"architectures": ["LongcatFlashForCausalLM"]}) From 3bc956e512d98ee36eee038338ba86e24c87b2b1 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Tue, 17 Mar 2026 06:50:44 +0000 Subject: [PATCH 033/112] bench_hf + fixes for moe --- benchmark/mmlu/bench_hf.py | 151 ++++++++++++++++++++++ python/sglang/srt/models/gemma4_causal.py | 92 ++++--------- python/sglang/srt/models/gemma4_mm.py | 7 - 3 files changed, 177 insertions(+), 73 deletions(-) create mode 100644 benchmark/mmlu/bench_hf.py diff --git a/benchmark/mmlu/bench_hf.py b/benchmark/mmlu/bench_hf.py new file mode 100644 index 000000000000..ce7bc6075706 --- /dev/null +++ b/benchmark/mmlu/bench_hf.py @@ -0,0 +1,151 @@ +""" +Usage: +python3 bench_hf.py --model-path meta-llama/Llama-2-7b-hf 
--data-dir data --ntrain 5 +""" + +import argparse +import json +import os +import time + +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +choices = ["A", "B", "C", "D"] + + +def format_subject(subject): + l = subject.split("_") + s = "" + for entry in l: + s += " " + entry + return s + + +def format_example(df, idx, include_answer=True): + prompt = df.iloc[idx, 0] + k = df.shape[1] - 2 + for j in range(k): + prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1]) + prompt += "\nAnswer:" + if include_answer: + prompt += " {}\n\n".format(df.iloc[idx, k + 1]) + return prompt + + +def gen_prompt(train_df, subject, k=-1): + prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format( + format_subject(subject) + ) + if k == -1: + k = train_df.shape[0] + for i in range(k): + prompt += format_example(train_df, i) + return prompt + + +@torch.no_grad() +def main(args): + print(f"Loading model: {args.model_path}") + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + args.model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map="auto", + ).eval() + + subjects = sorted( + [ + f.split("_test.csv")[0] + for f in os.listdir(os.path.join(args.data_dir, "test")) + if "_test.csv" in f + ] + ) + + all_cors = [] + num_requests = 0 + total_latency = 0 + + for subject in tqdm(subjects[: args.nsub]): + dev_df = pd.read_csv( + os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None + )[: args.ntrain] + test_df = pd.read_csv( + os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None + ) + + k = args.ntrain + few_shot_examples = gen_prompt(dev_df, subject, k) + while len(tokenizer.encode(few_shot_examples)) > 1536: + k -= 1 + if k < 0: + break + few_shot_examples = gen_prompt(dev_df, subject, k) + + preds = [] + labels = [] + tic 
= time.perf_counter() + + for i in range(test_df.shape[0]): + prompt_end = format_example(test_df, i, include_answer=False) + prompt = few_shot_examples + prompt_end + + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) + output_ids = model.generate( + input_ids, + max_new_tokens=1, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + + output_str = tokenizer.decode( + output_ids[0][input_ids.shape[-1] :], skip_special_tokens=True + ) + preds.append(output_str.strip()[0] if len(output_str.strip()) > 0 else "") + labels.append(test_df.iloc[i, test_df.shape[1] - 1]) + + latency = time.perf_counter() - tic + total_latency += latency + + cors = [pred == label for pred, label in zip(preds, labels)] + all_cors.append(cors) + num_requests += len(test_df) + + print( + f"Subject: {subject}, Accuracy: {np.mean(cors):.3f}, Latency: {latency:.3f}s" + ) + + weighted_acc = np.mean(np.concatenate(all_cors)) + print(f"Total Latency: {total_latency:.3f}s") + print(f"Average Accuracy: {weighted_acc:.3f}") + + if args.output: + with open(args.output, "a") as fout: + value = { + "task": "mmlu", + "backend": "hf", + "model": args.model_path, + "latency": round(total_latency, 3), + "accuracy": round(weighted_acc, 3), + "num_requests": num_requests, + "other": { + "nsub": args.nsub, + "ntrain": args.ntrain, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, required=True) + parser.add_argument("--ntrain", type=int, default=5) + parser.add_argument("--data-dir", type=str, default="data") + parser.add_argument("--nsub", type=int, default=60) + parser.add_argument("--output", type=str, help="Output file path") + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index d837e5afdd90..70be12de6ef1 100644 --- 
a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -298,7 +298,6 @@ def __init__( layer_id: int, config: Gemma4TextConfig, head_dim: int, - total_num_kv_heads: int, max_position_embeddings: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -493,21 +492,11 @@ def __init__( else: head_dim = getattr(config, "swa_head_dim", config.head_dim) - if self.is_full_attention and getattr(config, "attention_k_eq_v", False): - total_num_kv_heads = getattr( - config, "num_global_key_value_heads", - config.num_key_value_heads - ) - else: - total_num_kv_heads = getattr( - config, "swa_num_key_value_heads", config.num_key_value_heads - ) self.self_attn = Gemma4Attention( layer_id=layer_id, config=config, max_position_embeddings=config.max_position_embeddings, head_dim=head_dim, - total_num_kv_heads=total_num_kv_heads, quant_config=quant_config, prefix=add_prefix("self_attn", prefix), ) @@ -643,7 +632,6 @@ def forward( router_logits = self.router(hidden_states) hidden_states_2 = self.pre_feedforward_layernorm_2(hidden_states) hidden_states_2 = self.moe(hidden_states_2, router_logits) - hidden_states_2 = self.moe_out(hidden_states_2) hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2) # Combine branches @@ -713,13 +701,7 @@ def __init__( self.padding_idx, embed_scale=self.hidden_size_per_layer_input**0.5, ) - - # Scaled embedding factor (from config, not hardcoded) - # self.embed_scale_per_layer = torch.tensor( - # self.hidden_size_per_layer_input**0.5, - # ) - - # FIXME: Use replicated for now. Use ColumnParallel?. 
+ self.per_layer_model_projection = ReplicatedLinear( self.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input, @@ -755,18 +737,6 @@ def __init__( ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - # self.per_layer_projection_scale = torch.tensor( - # config.hidden_size**-0.5, - # ) - # self.register_buffer( - # "per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False - # ) - # self.register_buffer( - # "normalizer", - # torch.tensor(config.hidden_size**0.5), - # persistent=False, - # ) self.post_init() def get_input_embeddings(self) -> nn.Embedding: @@ -1004,12 +974,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] - expert_params_mapping = FusedMoE.make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping_gemma4( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=getattr(self.config, "num_experts", 0), ) + num_experts = self.config.num_experts k_eq_v_layers = self._get_k_eq_v_layers() @@ -1019,6 +989,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: name = name.replace("model.language_model.", "model.") + if ( + ".moe." in name + and "experts" not in name + and "per_expert_scale" not in name + ): + name = name.replace(".moe.", ".moe.experts.") + # attention_k_eq_v: full-attention layers have no v_proj in the # checkpoint (K and V share weights). When we see a k_proj weight # for one of these layers, load it into both the "k" and "v" shards @@ -1030,21 +1007,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): and int(m.group(1)) in k_eq_v_layers ) - - if ( - ".moe." 
in name - and "experts" not in name - and "per_expert_scale" not in name - ): - name = name.replace(".moe.", ".moe.experts.") - - for param_name, shard_name, shard_id in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - if name.endswith(".bias") and name not in params_dict: + # Try stacked (fused) params first + orig_name = name + for param_name, weight_name, shard_id in self.stacked_params_mapping: + name = orig_name + m = re.search(r".layers\.(\d+)\.", name) + if weight_name not in name: continue + name = name.replace(weight_name, param_name) if name not in params_dict: + name = orig_name continue param = params_dict[name] weight_loader = param.weight_loader @@ -1053,38 +1025,26 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, "v") break else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping + for param_name, weight_name, shard_id in expert_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name in params_dict.keys(): - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( - param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id, - ) - else: - raise KeyError( - f"Parameter '{name}' not found in model." 
- ) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + for i in range(num_experts): + weight_loader(param, loaded_weight[i].T, name, shard_id, i) break else: - if "lm_head.weight" in name and self.config.tie_word_embeddings: - continue if name.endswith(".bias") and name not in params_dict: continue + name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue - if name not in params_dict: continue - param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 1b468e8db4a6..690632051f15 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -628,13 +628,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for param_name, weight_name, shard_id in self.stacked_params_mapping: name = orig_name m = re.search(r"language_model.layers\.(\d+)\.", name) - if "k_proj" in name and "k_proj" in weight_name and k_eq_v_layer_indices and m and int(m.group(1)) in k_eq_v_layer_indices: - n = name.replace(weight_name, param_name) - param = params_dict[n] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight, "v") - loaded_params.add(n) - if weight_name not in name: continue name = name.replace(weight_name, param_name) From 0ee8f82a51b376cfb7c1e0d7d332bc99982ae0cb Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Tue, 17 Mar 2026 07:12:18 +0000 Subject: [PATCH 034/112] format --- python/sglang/srt/models/gemma4_causal.py | 5 ----- python/sglang/srt/models/gemma4_mm.py | 6 +----- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 70be12de6ef1..1a9080e2c7e2 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ 
b/python/sglang/srt/models/gemma4_causal.py @@ -1008,15 +1008,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) # Try stacked (fused) params first - orig_name = name for param_name, weight_name, shard_id in self.stacked_params_mapping: - name = orig_name - m = re.search(r".layers\.(\d+)\.", name) if weight_name not in name: continue name = name.replace(weight_name, param_name) if name not in params_dict: - name = orig_name continue param = params_dict[name] weight_loader = param.weight_loader @@ -1039,7 +1035,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: if name.endswith(".bias") and name not in params_dict: continue - name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 690632051f15..78e5b4883f6f 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -578,6 +578,7 @@ def _get_k_eq_v_layers(self) -> set: def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): k_eq_v_layers = self._get_k_eq_v_layers() + # TODO(pyc96): revisit and simplify. 
expert_params_mapping = FusedMoE.make_expert_params_mapping_gemma4( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -624,15 +625,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) # Try stacked (fused) params first - orig_name = name for param_name, weight_name, shard_id in self.stacked_params_mapping: - name = orig_name - m = re.search(r"language_model.layers\.(\d+)\.", name) if weight_name not in name: continue name = name.replace(weight_name, param_name) if name not in params_dict: - name = orig_name continue param = params_dict[name] weight_loader = param.weight_loader @@ -655,7 +652,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: if name.endswith(".bias") and name not in params_dict: continue - name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue From 897fc8ba481dc1f09e9f1bbd94ceac9393a7e0ce Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Wed, 18 Mar 2026 07:15:13 +0000 Subject: [PATCH 035/112] address comments --- python/sglang/srt/models/gemma4_causal.py | 13 +++++++------ python/sglang/srt/models/gemma4_mm.py | 6 +++++- python/sglang/srt/utils/hf_transformers_utils.py | 8 ++++---- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 1a9080e2c7e2..229a9b9f3a7c 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -283,13 +283,10 @@ def routing_function( reduce_results=True, ) - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor, should_allreduce_fusion: bool = False) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape topk_output = self.topk(hidden_states, router_logits) hidden_states = self.experts(hidden_states, topk_output) - - if self.tp_size > 1 and not 
should_allreduce_fusion: - hidden_states = tensor_model_parallel_all_reduce(hidden_states) return hidden_states.view(num_tokens, hidden_dim) class Gemma4Attention(nn.Module): @@ -532,7 +529,7 @@ def __init__( ) # Per-Layer Embedding (PLE) components — present in each decoder layer - if self.hidden_size_per_layer_input and self.hidden_size_per_layer_input > 0: + if self.hidden_size_per_layer_input > 0: # Gate: projects hidden_states → per-layer dim for gating self.per_layer_input_gate = ReplicatedLinear( self.hidden_size, @@ -1008,7 +1005,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) # Try stacked (fused) params first - for param_name, weight_name, shard_id in self.stacked_params_mapping: + orig_name = name + for param_name, weight_name, shard_id in stacked_params_mapping: + name = orig_name if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -1022,6 +1021,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: for param_name, weight_name, shard_id in expert_params_mapping: + name = orig_name if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -1033,6 +1033,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight[i].T, name, shard_id, i) break else: + name = orig_name if name.endswith(".bias") and name not in params_dict: continue name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 78e5b4883f6f..dc1b77173255 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -625,7 +625,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) # Try stacked (fused) params first + orig_name = name for param_name, weight_name, shard_id in self.stacked_params_mapping: + name = orig_name if weight_name not in name: continue name = 
name.replace(weight_name, param_name) @@ -639,6 +641,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: for param_name, weight_name, shard_id in expert_params_mapping: + name = orig_name if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -650,6 +653,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight[i].T, name, shard_id, i) break else: + name = orig_name if name.endswith(".bias") and name not in params_dict: continue name = maybe_remap_kv_scale_name(name, params_dict) @@ -689,7 +693,7 @@ def get_hidden_dim(self, module_name, layer_idx): ) elif module_name == "o_proj": return ( - self.head_dim * self.num_attention_heads, + self.config.head_dim * self.config.num_attention_heads, self.config.hidden_size, ) elif module_name == "gate_up_proj": diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 38c046665381..646f490df29c 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -418,10 +418,10 @@ def get_config( if global_kv_heads is not None: text_config.num_key_value_heads = global_kv_heads - if not hasattr(config.text_config, "v_head_dim"): - config.text_config.v_head_dim = config.text_config.head_dim - if not hasattr(config.text_config, "swa_v_head_dim"): - config.text_config.swa_v_head_dim = config.text_config.swa_head_dim + if not hasattr(text_config, "v_head_dim"): + text_config.v_head_dim = text_config.head_dim + if not hasattr(text_config, "swa_v_head_dim"): + text_config.swa_v_head_dim = text_config.swa_head_dim if config.model_type == "longcat_flash": config.update({"architectures": ["LongcatFlashForCausalLM"]}) From 05057d3f3ae8622486cb39cf4dafaa76e353d912 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Thu, 19 Mar 2026 01:04:21 +0000 Subject: [PATCH 036/112] clean up + add chat template for sgl front lang 
--- python/sglang/lang/chat_template.py | 24 +++++++++++++------ .../srt/entrypoints/openai/serving_chat.py | 2 +- python/sglang/srt/layers/layernorm.py | 4 +--- .../srt/managers/detokenizer_manager.py | 1 - python/sglang/srt/mem_cache/memory_pool.py | 1 - .../sglang/srt/model_executor/model_runner.py | 2 -- scripts/playground/reference_hf.py | 11 ++++----- 7 files changed, 24 insertions(+), 21 deletions(-) diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 212d07e0bebd..98cf501f159a 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -402,6 +402,19 @@ def get_chat_template_by_model_path(model_path): ) ) +register_chat_template( + ChatTemplate( + name="gemma-4-it", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("", ""), + "user": ("<|turn>user\n", "\n"), + "assistant": ("<|turn>assistant\n", "\n"), + }, + style=ChatTemplateStyle.PLAIN, + ) +) + register_chat_template( ChatTemplate( name="dbrx-instruct", @@ -609,8 +622,10 @@ def match_chat_yi(model_path: str): @register_chat_template_matching_function -def match_gemma_it(model_path: str): - if re.search(r"gemma.*it", model_path, re.IGNORECASE): +def match_gemma(model_path: str): + if re.search(r"gemma-4.*it", model_path, re.IGNORECASE): + return "gemma-4-it" + if re.search(r"(gemma.*it)|(gemma-3)", model_path, re.IGNORECASE): return "gemma-it" @@ -634,11 +649,6 @@ def match_granite_instruct(model_path: str): return "granite-3-instruct" -@register_chat_template_matching_function -def match_gemma3_instruct(model_path: str): - if re.search(r"gemma-3", model_path, re.IGNORECASE): - return "gemma-it" - @register_chat_template_matching_function def match_internvl_chat(model_path: str): diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index ee8eae9cd0cf..405c9075a993 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ 
b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -1233,7 +1233,7 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: """Judge whether the request needs reasoning""" if not self.reasoning_parser: return False - # Do we want to think by default? + if self.reasoning_parser in ["deepseek-v3", "gemma4"]: # Models that require explicit enable thinking (thinking=True) return ( diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index b7ca8f7b334f..e6d71bcc6c62 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -121,8 +121,6 @@ def forward_cuda( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - # # TODO: fix cuda: having some shape issue with sgl kernel - return self.forward_native(x, residual, post_residual_addition) if x.numel() == 0: return x if self.variance_size_override is not None: @@ -478,7 +476,7 @@ def forward_cuda( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - return self.forward_native(x, residual, post_residual_addition) + return self._forward_impl(x, residual, post_residual_addition) def forward_cpu( self, diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 357f18e37b53..05cd787e6428 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -339,7 +339,6 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput): incremental_output = output_str[s.sent_offset :] s.sent_offset = len(output_str) output_strs.append(incremental_output) - # print(output_strs) return output_strs def _extract_routed_experts( diff --git a/python/sglang/srt/mem_cache/memory_pool.py 
b/python/sglang/srt/mem_cache/memory_pool.py index a7587e8b42ae..0b2383ed1cd9 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -834,7 +834,6 @@ def _create_buffers(self): ): # [size, head_num, head_dim] for each layer # The padded slot 0 is used for writing dummy outputs from padded tokens. - # adjust for global self.k_buffer = [ torch.zeros( (self.size + self.page_size, self.head_num, self.head_dim), diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e35af50a83fd..e42dbc556230 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1038,8 +1038,6 @@ def load_model(self): f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}" ) - # Note(pyc): gemma4 has different swa def - self.dtype = self.model_config.dtype after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py index ab6b31677b18..48b5762106c2 100644 --- a/scripts/playground/reference_hf.py +++ b/scripts/playground/reference_hf.py @@ -124,8 +124,8 @@ def normal_text(args): prompts = [ "The capital of France is", - # "The capital of the United Kindom is", - # "Today is a sunny day and I like", + "The capital of the United Kindom is", + "Today is a sunny day and I like", ] max_new_tokens = args.max_new_tokens @@ -164,10 +164,9 @@ def synthetic_tokens(args): for p in prompts: input_ids = p for i in range(output_len + 1): - output = m.forward( - torch.tensor([input_ids], device="cuda"), output_hidden_states=True - ).logits[0][-1] - prefill_logits = output + prefill_logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[ + 0 + ][-1] if i == 0: print("prefill logits", prefill_logits) else: From 66520efd97613a8b794ddfa2890580fa00c323e3 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: 
Thu, 19 Mar 2026 22:34:45 +0000 Subject: [PATCH 037/112] fix: gemma4 layer_scalar, num_experts guard, and RMSNorm 2D reshape - Register and apply layer_scalar unconditionally for all layers (not just full_attention), fixing garbage output on 26B-A4B MoE model - Guard num_experts access with getattr for non-MoE models (e2b crash) - Reshape higher-rank tensors to 2D before RMSNorm/GemmaRMSNorm kernel calls, following Gemma3nRMSNorm pattern (PLE 4D tensor compatibility) --- python/sglang/srt/layers/layernorm.py | 13 +++++++++++++ python/sglang/srt/models/gemma4_causal.py | 7 ++----- python/sglang/srt/models/gemma4_mm.py | 15 +++++++++------ 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index e6d71bcc6c62..74fc26eed5a1 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -123,6 +123,11 @@ def forward_cuda( ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if x.numel() == 0: return x + # sgl_kernel rmsnorm requires 2D input; reshape higher-rank tensors + needs_reshape = x.dim() != 2 and residual is None + if needs_reshape: + original_shape = x.shape + x = x.contiguous().reshape(-1, original_shape[-1]) if self.variance_size_override is not None: return self.forward_native(x, residual, post_residual_addition) if is_batch_invariant_mode_enabled(): @@ -146,6 +151,8 @@ def forward_cuda( fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) return x, residual out = rmsnorm(x, self.weight.data, self.variance_epsilon) + if needs_reshape: + out = out.reshape(original_shape) return out def forward_npu( @@ -440,6 +447,10 @@ def _forward_impl( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + needs_reshape = x.dim() != 2 and residual is None + if needs_reshape: + original_shape = x.shape + x = 
x.contiguous().reshape(-1, original_shape[-1]) if residual is not None: if post_residual_addition is not None: residual = residual + post_residual_addition @@ -448,6 +459,8 @@ def _forward_impl( ) return x, residual out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon) + if needs_reshape: + out = out.reshape(original_shape) return out def forward_native( diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 229a9b9f3a7c..cdb715bb90ef 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -41,8 +41,6 @@ from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -<<<<<<< HEAD -======= from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import ( @@ -53,7 +51,6 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.server_args import get_global_server_args ->>>>>>> 8aaf187f9 (gemma4 moe) from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -554,7 +551,6 @@ def __init__( self.per_layer_projection = None self.post_per_layer_input_norm = None - self.register_buffer("layer_scalar", torch.ones(1), persistent=True) # Parallel MoE self.enable_moe_block = getattr(config, "enable_moe_block", False) or getattr( config, "use_second_mlp_block", False @@ -589,6 +585,7 @@ def __init__( self.post_feedforward_layernorm_2 = None self.pre_feedforward_layernorm_2 = None + self.register_buffer("layer_scalar", torch.ones(1), persistent=True) self.prefix = prefix def forward( @@ -657,7 +654,7 @@ def forward( per_layer_contribution ) hidden_states = hidden_states + 
per_layer_contribution - + hidden_states = hidden_states * self.layer_scalar return hidden_states, None diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index dc1b77173255..b5eeaa1d5d9e 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -579,12 +579,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): k_eq_v_layers = self._get_k_eq_v_layers() # TODO(pyc96): revisit and simplify. - expert_params_mapping = FusedMoE.make_expert_params_mapping_gemma4( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - ) - num_experts = self.config.num_experts + num_experts = getattr(self.config.text_config, "num_experts", 0) or 0 + if num_experts > 0: + expert_params_mapping = FusedMoE.make_expert_params_mapping_gemma4( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + ) + else: + expert_params_mapping = [] params_dict = dict(self.named_parameters()) params_dict.update(dict(self.named_buffers())) From 977226f879d4ad61478d0beeacfc4bb2d0bf63d4 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 19 Mar 2026 22:40:01 +0000 Subject: [PATCH 038/112] lint Signed-off-by: Xinyuan Tong --- benchmark/mmlu/bench_hf.py | 4 +- python/sglang/lang/chat_template.py | 1 - python/sglang/srt/layers/layernorm.py | 2 +- .../srt/layers/moe/fused_moe_triton/layer.py | 2 +- python/sglang/srt/models/gemma4_causal.py | 70 +++++++++---------- python/sglang/srt/models/gemma4_mm.py | 8 +-- 6 files changed, 42 insertions(+), 45 deletions(-) diff --git a/benchmark/mmlu/bench_hf.py b/benchmark/mmlu/bench_hf.py index ce7bc6075706..c76a18db685b 100644 --- a/benchmark/mmlu/bench_hf.py +++ b/benchmark/mmlu/bench_hf.py @@ -89,7 +89,7 @@ def main(args): preds = [] labels = [] tic = time.perf_counter() - + for i in range(test_df.shape[0]): prompt_end = format_example(test_df, i, include_answer=False) prompt = 
few_shot_examples + prompt_end @@ -148,4 +148,4 @@ def main(args): parser.add_argument("--nsub", type=int, default=60) parser.add_argument("--output", type=str, help="Output file path") args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 98cf501f159a..ac4d8e655c76 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -649,7 +649,6 @@ def match_granite_instruct(model_path: str): return "granite-3-instruct" - @register_chat_template_matching_function def match_internvl_chat(model_path: str): if re.search(r"internvl2_5", model_path, re.IGNORECASE): diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 74fc26eed5a1..2164bccc3e68 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -597,7 +597,7 @@ def __repr__(self): def _norm(self, x): mean_squared = x.pow(2).mean(-1, keepdim=True) + self.eps - # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to addess compiler differences between Torch and JAX + # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to address compiler differences between Torch and JAX return x * torch.pow(mean_squared, -0.5) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 071621872e97..b5e9b5ebb746 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1059,7 +1059,7 @@ def make_expert_params_mapping_fused( ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"), ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"), ] - + @classmethod def make_expert_params_mapping_gemma4( cls, diff --git a/python/sglang/srt/models/gemma4_causal.py 
b/python/sglang/srt/models/gemma4_causal.py index cdb715bb90ef..dc464e00e233 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -19,7 +19,6 @@ import torch from torch import nn from transformers import ( - AutoModel, Gemma4TextConfig, PretrainedConfig, PreTrainedModel, @@ -27,9 +26,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, ) -from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.layernorm import Gemma4RMSNorm, GemmaRMSNorm, RMSNorm from sglang.srt.layers.linear import ( QKVParallelLinear, @@ -38,19 +35,10 @@ ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) -from sglang.srt.utils import add_prefix, make_layers -from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.server_args import get_global_server_args from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -177,7 +165,9 @@ def __init__( self.hidden_size = config.hidden_size # RMSNorm without learned weight — pure normalization only - self.norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps, with_scale=False) + self.norm = Gemma4RMSNorm( + self.hidden_size, eps=config.rms_norm_eps, with_scale=False 
+ ) # Per-dimension learned scale, applied after norm + root_size self.scale = nn.Parameter(torch.ones(self.hidden_size)) # Constant 1/sqrt(hidden_size) scaling factor @@ -269,7 +259,8 @@ def routing_function( experts_type = get_moe_impl_class(quant_config) self.experts = experts_type( - num_experts=config.num_experts + get_global_server_args().ep_num_redundant_experts, + num_experts=config.num_experts + + get_global_server_args().ep_num_redundant_experts, hidden_size=config.hidden_size, intermediate_size=config.expert_intermediate_size, layer_id=layer_id, @@ -280,12 +271,15 @@ def routing_function( reduce_results=True, ) - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor + ) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape topk_output = self.topk(hidden_states, router_logits) hidden_states = self.experts(hidden_states, topk_output) return hidden_states.view(num_tokens, hidden_dim) + class Gemma4Attention(nn.Module): def __init__( self, @@ -303,7 +297,9 @@ def __init__( tp_size = get_tensor_model_parallel_world_size() layer_type = config.layer_types[layer_id] - self.sliding_window = config.sliding_window if layer_type == "sliding_attention" else None + self.sliding_window = ( + config.sliding_window if layer_type == "sliding_attention" else None + ) self.total_num_heads = config.num_attention_heads assert self.total_num_heads % tp_size == 0 @@ -371,17 +367,17 @@ def __init__( # KV sharing logic num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0) - first_kv_shared_layer_idx = ( - config.num_hidden_layers - num_kv_shared_layers + first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers + self.is_kv_shared_layer = ( + layer_id >= first_kv_shared_layer_idx and num_kv_shared_layers > 0 ) - self.is_kv_shared_layer = layer_id >= first_kv_shared_layer_idx and num_kv_shared_layers > 0 
self.kv_shared_layer_index = None if num_kv_shared_layers > 0 and self.layer_id >= first_kv_shared_layer_idx: prev_layers = config.layer_types[:first_kv_shared_layer_idx] current_layer_type = config.layer_types[self.layer_id] - self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index( - current_layer_type + self.kv_shared_layer_index = ( + len(prev_layers) - 1 - prev_layers[::-1].index(current_layer_type) ) self.rotary_emb = get_rope( @@ -472,9 +468,9 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - self.hidden_size_per_layer_input = getattr( - config, "hidden_size_per_layer_input", None - ) or 0 + self.hidden_size_per_layer_input = ( + getattr(config, "hidden_size_per_layer_input", None) or 0 + ) self.layer_id = layer_id @@ -654,7 +650,7 @@ def forward( per_layer_contribution ) hidden_states = hidden_states + per_layer_contribution - + hidden_states = hidden_states * self.layer_scalar return hidden_states, None @@ -676,17 +672,17 @@ def __init__( config.vocab_size, config.hidden_size, self.padding_idx, - embed_scale=self.config.hidden_size**0.5, # embeded normalizer + embed_scale=self.config.hidden_size**0.5, # embedded normalizer ) # Per-layer input embeddings self.hidden_size = config.hidden_size - self.hidden_size_per_layer_input = getattr( - config, "hidden_size_per_layer_input", None - ) or 0 - self.vocab_size_per_layer_input = getattr( - config, "vocab_size_per_layer_input", None - ) or config.vocab_size + self.hidden_size_per_layer_input = ( + getattr(config, "hidden_size_per_layer_input", None) or 0 + ) + self.vocab_size_per_layer_input = ( + getattr(config, "vocab_size_per_layer_input", None) or config.vocab_size + ) if self.hidden_size_per_layer_input and self.hidden_size_per_layer_input > 0: self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding( @@ -695,7 +691,7 @@ def __init__( self.padding_idx, embed_scale=self.hidden_size_per_layer_input**0.5, ) - + self.per_layer_model_projection = 
ReplicatedLinear( self.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input, @@ -757,7 +753,7 @@ def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor: per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens) # Apply embed_scale (sqrt of per-layer hidden dim) - # Alreayd done in embedding layer + # Already done in embedding layer # per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer # Reshape to (num_tokens, num_layers, hidden_size_per_layer_input) @@ -1039,7 +1035,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) unloaded_params = params_dict.keys() - loaded_params diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index b5eeaa1d5d9e..8482e7d4f3ad 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -32,6 +32,7 @@ from sglang.srt.layers.layernorm import Gemma4RMSNorm from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternMultimodalTokens, @@ -55,9 +56,6 @@ from sglang.srt.utils import add_prefix from sglang.srt.utils.hf_transformers_utils import get_processor -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.server_args import get_global_server_args - logger = logging.getLogger(__name__) cached_get_processor = lru_cache(get_processor) @@ -665,7 +663,9 @@ def load_weights(self, weights: Iterable[Tuple[str, 
torch.Tensor]]): if name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) unloaded_params = params_dict.keys() - loaded_params From 959c42bb0bebd248561e9ebbb12cb8090e1de431 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Fri, 20 Mar 2026 20:07:09 +0000 Subject: [PATCH 039/112] nit: modify weight loader warning msg --- python/sglang/srt/models/gemma4_causal.py | 32 ++++++++++--- python/sglang/srt/models/gemma4_mm.py | 56 ++++++++++++++++------- 2 files changed, 65 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index dc464e00e233..79a43b4acfc1 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -949,9 +949,7 @@ def _get_k_eq_v_layers(self) -> set: if not getattr(self.config, "attention_k_eq_v", False): return set() return { - i - for i, lt in enumerate(self.config.layer_types) - if lt == "full_attention" + i for i, lt in enumerate(self.config.layer_types) if lt == "full_attention" } def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -975,6 +973,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) params_dict.update(dict(self.named_buffers())) + non_persistent_buffers: Set[str] = set() + for mod_name, mod in self.named_modules(): + for buf_name in getattr(mod, "_non_persistent_buffers_set", set()): + full = f"{mod_name}.{buf_name}" if mod_name else buf_name + non_persistent_buffers.add(full) + loaded_params: Set[str] = set() for name, loaded_weight in weights: name = name.replace("model.language_model.", "model.") @@ -1042,9 +1046,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_params.add(name) 
unloaded_params = params_dict.keys() - loaded_params if unloaded_params: - logger.warning( - "Some weights are not initialized from checkpoints: %s", unloaded_params - ) + param_names = set(dict(self.named_parameters()).keys()) + buckets = { + logging.WARNING: ( + "Some weights are not initialized from checkpoints", + lambda p: p in param_names, + ), + logging.INFO: ( + "Persistent buffers not in checkpoint (using default init)", + lambda p: p not in param_names and p not in non_persistent_buffers, + ), + logging.DEBUG: ( + "Non-persistent buffers not in checkpoint (expected)", + lambda p: p in non_persistent_buffers, + ), + } + for level, (msg, pred) in buckets.items(): + names = sorted(p for p in unloaded_params if pred(p)) + if names: + logger.log(level, "%s: %s", msg, names) return loaded_params diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 8482e7d4f3ad..e0ce7e4d1629 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -45,7 +45,6 @@ flatten_nested_list, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.server_args import get_global_server_args from sglang.srt.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, @@ -247,7 +246,7 @@ def prepare_attn_masks( mask_dtype: torch.dtype, ): """Prepare bidirectional attention masks for image tokens. - + Gemma 4 uses bidirectional attention for image soft tokens during prefill. Following the HF implementation, bidirectional attention is only enabled within each individual image group (same-image tokens), not across images. 
@@ -281,9 +280,7 @@ def prepare_attn_masks( ) # Start with causal mask bidirectional_attn_mask.fill_(1) - bidirectional_attn_mask = bidirectional_attn_mask.tril( - diagonal=prefix_len - ) + bidirectional_attn_mask = bidirectional_attn_mask.tril(diagonal=prefix_len) # Enable bidirectional attention within each image group mm_inputs = forward_batch.mm_inputs[i] @@ -291,18 +288,24 @@ def prepare_attn_masks( for mm_item in mm_inputs.mm_items: if mm_item.is_image(): for im_begin, im_end in mm_item.offsets: - # Note(kpham-sgl): We only apply bidirectional attention when the image token span - # is fully contained in the extend window. Otherwise, we silently fall back to + # Note(kpham-sgl): We only apply bidirectional attention when the image token span + # is fully contained in the extend window. Otherwise, we silently fall back to # causal attention. # FIXME(kpham-sgl): This is a hack to work around the fact that the image token span - # might not be fully contained in the extend window during chunked prefill. + # might not be fully contained in the extend window during chunked prefill. # We should fix this by properly making chunked prefill mask aware. 
- if im_begin >= prefix_len and im_end < prefix_len + extend_seq_len: + if ( + im_begin >= prefix_len + and im_end < prefix_len + extend_seq_len + ): bidirectional_attn_mask[ im_begin - prefix_len : im_end + 1 - prefix_len, im_begin : im_end + 1, ] = 1 - elif im_end >= prefix_len and im_begin < prefix_len + extend_seq_len: + elif ( + im_end >= prefix_len + and im_begin < prefix_len + extend_seq_len + ): split_images.append((i, im_begin, im_end)) bidirectional_attn_masks_list.append(bidirectional_attn_mask.flatten()) @@ -568,9 +571,7 @@ def _get_k_eq_v_layers(self) -> set: if not getattr(text_config, "attention_k_eq_v", False): return set() return { - i - for i, lt in enumerate(text_config.layer_types) - if lt == "full_attention" + i for i, lt in enumerate(text_config.layer_types) if lt == "full_attention" } def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -589,6 +590,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) params_dict.update(dict(self.named_buffers())) + non_persistent_buffers: Set[str] = set() + for mod_name, mod in self.named_modules(): + for buf_name in getattr(mod, "_non_persistent_buffers_set", set()): + full = f"{mod_name}.{buf_name}" if mod_name else buf_name + non_persistent_buffers.add(full) + loaded_params: Set[str] = set() for name, loaded_weight in weights: @@ -670,10 +677,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_params.add(name) unloaded_params = params_dict.keys() - loaded_params if unloaded_params: - logger.warning( - "Some weights are not initialized from checkpoints: %s", - unloaded_params, - ) + param_names = set(dict(self.named_parameters()).keys()) + buckets = { + logging.WARNING: ( + "Some weights are not initialized from checkpoints", + lambda p: p in param_names, + ), + logging.INFO: ( + "Persistent buffers not in checkpoint (using default init)", + lambda p: p not in param_names and p not in 
non_persistent_buffers, + ), + logging.DEBUG: ( + "Non-persistent buffers not in checkpoint (expected)", + lambda p: p in non_persistent_buffers, + ), + } + for level, (msg, pred) in buckets.items(): + names = sorted(p for p in unloaded_params if pred(p)) + if names: + logger.log(level, "%s: %s", msg, names) return loaded_params lora_pattern = re.compile( From ee43c619e0915899232b2a8e6f5ed74e33185008 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Sun, 22 Mar 2026 04:19:17 +0000 Subject: [PATCH 040/112] torch compile and tuning. --- .../kernels/fused_moe_triton/common_utils.py | 8 +- .../E=128,N=704,device_name=NVIDIA_B200.json | 146 ++++++++++++++++++ .../sglang/srt/mem_cache/swa_memory_pool.py | 4 +- 3 files changed, 156 insertions(+), 2 deletions(-) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_B200.json diff --git a/benchmark/kernels/fused_moe_triton/common_utils.py b/benchmark/kernels/fused_moe_triton/common_utils.py index adac313a11b1..c61d7bd54500 100644 --- a/benchmark/kernels/fused_moe_triton/common_utils.py +++ b/benchmark/kernels/fused_moe_triton/common_utils.py @@ -40,7 +40,10 @@ def get_model_config( # Replace config with text_config for encoder-decoder models after getting block_shape and architecture if hasattr(config, "text_config"): + architecture = config.architectures[0] config = config.get_text_config() + else: + architecture = config.architectures[0] block_shape = None if ( @@ -62,7 +65,6 @@ def get_model_config( block_shape = [0, group_size] assert len(block_shape) == 2 - architecture = config.architectures[0] hidden_size = config.hidden_size if architecture == "DbrxForCausalLM": @@ -133,6 +135,10 @@ def get_model_config( topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size hidden_size = getattr(config, "moe_latent_size", None) or hidden_size + elif architecture == "Gemma4ForConditionalGeneration": + E = config.num_experts // ep_size + topk = 
config.top_k_experts + intermediate_size = config.expert_intermediate_size else: # Default: Mixtral E = config.num_local_experts // ep_size diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_B200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_B200.json new file mode 100644 index 000000000000..8ff7c371dab5 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index d1ccb3ceace9..6c987f073677 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -301,7 +301,9 @@ def __init__( self.clear() self._kvcache = kvcache - self._kvcache.register_mapping(weakref.proxy(self.full_to_swa_index_mapping)) + # why do we need this? 
+ # self._kvcache.register_mapping(weakref.proxy(self.full_to_swa_index_mapping)) + self._kvcache.register_mapping(self.full_to_swa_index_mapping) def available_size(self): return min( From b74a59645849376c8f99637cb358c2b5418c2643 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Mon, 23 Mar 2026 20:51:53 +0000 Subject: [PATCH 041/112] clean up grpc gen code --- .../sglang/srt/grpc/sglang_scheduler_pb2.py | 134 ---- .../sglang/srt/grpc/sglang_scheduler_pb2.pyi | 632 ------------------ .../srt/grpc/sglang_scheduler_pb2_grpc.py | 368 ---------- 3 files changed, 1134 deletions(-) delete mode 100644 python/sglang/srt/grpc/sglang_scheduler_pb2.py delete mode 100644 python/sglang/srt/grpc/sglang_scheduler_pb2.pyi delete mode 100644 python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.py b/python/sglang/srt/grpc/sglang_scheduler_pb2.py deleted file mode 100644 index e99981e3702b..000000000000 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2.py +++ /dev/null @@ -1,134 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# NO CHECKED-IN PROTOBUF GENCODE -# source: sglang_scheduler.proto -# Protobuf Python Version: 6.31.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 6, - 31, - 1, - '', - 'sglang_scheduler.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 -from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xd0\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\t\n\x01n\x18\x11 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x12 \x01(\x05\x12\x12\n\nignore_eos\x18\x13 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x14 \x01(\x08\x12\x1c\n\x0fstream_interval\x18\x15 
\x01(\x05H\x02\x88\x01\x01\x12H\n\nlogit_bias\x18\x16 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x17 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokensB\x12\n\x10_stream_interval\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe2\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x0e\n\x06stream\x18\x11 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 
\x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\x95\x02\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12<\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x08 \x01(\r\"\x9b\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12<\n\x0einput_logprobs\x18\n \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x0b \x01(\rB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"u\n\x0eOutputLogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"\x9e\x01\n\rInputLogProbs\x12@\n\x0etoken_logprobs\x18\x01 
\x03(\x0b\x32(.sglang.grpc.scheduler.InputTokenLogProb\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"1\n\x11InputTokenLogProb\x12\x12\n\x05value\x18\x01 \x01(\x02H\x00\x88\x01\x01\x42\x08\n\x06_value\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\x14\n\x12HealthCheckRequest\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 
\x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x15\n\x13GetModelInfoRequest\"\xac\x03\n\x14GetModelInfoResponse\x12\x12\n\nmodel_path\x18\x01 \x01(\t\x12\x16\n\x0etokenizer_path\x18\x02 \x01(\t\x12\x15\n\ris_generation\x18\x03 \x01(\x08\x12!\n\x19preferred_sampling_params\x18\x04 \x01(\t\x12\x16\n\x0eweight_version\x18\x05 \x01(\t\x12\x19\n\x11served_model_name\x18\x06 \x01(\t\x12\x1a\n\x12max_context_length\x18\x07 \x01(\x05\x12\x12\n\nvocab_size\x18\x08 \x01(\x05\x12\x17\n\x0fsupports_vision\x18\t \x01(\x08\x12\x12\n\nmodel_type\x18\n \x01(\t\x12\x15\n\reos_token_ids\x18\x0b \x03(\x05\x12\x14\n\x0cpad_token_id\x18\x0c \x01(\x05\x12\x14\n\x0c\x62os_token_id\x18\r 
\x01(\x05\x12\x19\n\x11max_req_input_len\x18\x0e \x01(\x05\x12\x15\n\rarchitectures\x18\x0f \x03(\t\x12\x15\n\rid2label_json\x18\x10 \x01(\t\x12\x12\n\nnum_labels\x18\x11 \x01(\x05\"\x16\n\x14GetServerInfoRequest\"\xb7\x02\n\x15GetServerInfoResponse\x12,\n\x0bserver_args\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12/\n\x0escheduler_info\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x17\n\x0f\x61\x63tive_requests\x18\x03 \x01(\x05\x12\x11\n\tis_paused\x18\x04 \x01(\x08\x12\x1e\n\x16last_receive_timestamp\x18\x05 \x01(\x01\x12\x16\n\x0euptime_seconds\x18\x06 \x01(\x01\x12\x16\n\x0esglang_version\x18\x07 \x01(\t\x12\x13\n\x0bserver_type\x18\x08 \x01(\t\x12.\n\nstart_time\x18\t \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"D\n\x0fGetLoadsRequest\x12\x14\n\x07\x64p_rank\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x0f\n\x07include\x18\x02 \x03(\tB\n\n\x08_dp_rank\"\xbe\x01\n\x10GetLoadsResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x15\n\rdp_rank_count\x18\x03 \x01(\x05\x12\x33\n\x05loads\x18\x04 \x03(\x0b\x32$.sglang.grpc.scheduler.SchedulerLoad\x12:\n\taggregate\x18\x05 \x01(\x0b\x32\'.sglang.grpc.scheduler.AggregateMetrics\"\x99\x05\n\rSchedulerLoad\x12\x0f\n\x07\x64p_rank\x18\x01 \x01(\x05\x12\x18\n\x10num_running_reqs\x18\x02 \x01(\x05\x12\x18\n\x10num_waiting_reqs\x18\x03 \x01(\x05\x12\x16\n\x0enum_total_reqs\x18\x04 \x01(\x05\x12\x17\n\x0fnum_used_tokens\x18\x05 \x01(\x05\x12\x1c\n\x14max_total_num_tokens\x18\x06 \x01(\x05\x12\x13\n\x0btoken_usage\x18\x07 \x01(\x01\x12\x16\n\x0egen_throughput\x18\x08 \x01(\x01\x12\x16\n\x0e\x63\x61\x63he_hit_rate\x18\t \x01(\x01\x12\x13\n\x0butilization\x18\n \x01(\x01\x12\x1c\n\x14max_running_requests\x18\x0b \x01(\x05\x12\x39\n\x06memory\x18\x0c \x01(\x0b\x32$.sglang.grpc.scheduler.MemoryMetricsH\x00\x88\x01\x01\x12\x43\n\x0bspeculative\x18\r \x01(\x0b\x32).sglang.grpc.scheduler.SpeculativeMetricsH\x01\x88\x01\x01\x12\x35\n\x04lora\x18\x0e 
\x01(\x0b\x32\".sglang.grpc.scheduler.LoRAMetricsH\x02\x88\x01\x01\x12I\n\x0e\x64isaggregation\x18\x0f \x01(\x0b\x32,.sglang.grpc.scheduler.DisaggregationMetricsH\x03\x88\x01\x01\x12\x38\n\x06queues\x18\x10 \x01(\x0b\x32#.sglang.grpc.scheduler.QueueMetricsH\x04\x88\x01\x01\x42\t\n\x07_memoryB\x0e\n\x0c_speculativeB\x07\n\x05_loraB\x11\n\x0f_disaggregationB\t\n\x07_queues\"a\n\rMemoryMetrics\x12\x11\n\tweight_gb\x18\x01 \x01(\x01\x12\x13\n\x0bkv_cache_gb\x18\x02 \x01(\x01\x12\x10\n\x08graph_gb\x18\x03 \x01(\x01\x12\x16\n\x0etoken_capacity\x18\x04 \x01(\x05\"@\n\x12SpeculativeMetrics\x12\x15\n\raccept_length\x18\x01 \x01(\x01\x12\x13\n\x0b\x61\x63\x63\x65pt_rate\x18\x02 \x01(\x01\"K\n\x0bLoRAMetrics\x12\x12\n\nslots_used\x18\x01 \x01(\x05\x12\x13\n\x0bslots_total\x18\x02 \x01(\x05\x12\x13\n\x0butilization\x18\x03 \x01(\x01\"\x9c\x02\n\x15\x44isaggregationMetrics\x12\x0c\n\x04mode\x18\x01 \x01(\t\x12#\n\x1bprefill_prealloc_queue_reqs\x18\x02 \x01(\x05\x12#\n\x1bprefill_inflight_queue_reqs\x18\x03 \x01(\x05\x12\"\n\x1a\x64\x65\x63ode_prealloc_queue_reqs\x18\x04 \x01(\x05\x12\"\n\x1a\x64\x65\x63ode_transfer_queue_reqs\x18\x05 \x01(\x05\x12#\n\x1b\x64\x65\x63ode_retracted_queue_reqs\x18\x06 \x01(\x05\x12\x1e\n\x16kv_transfer_speed_gb_s\x18\x07 \x01(\x01\x12\x1e\n\x16kv_transfer_latency_ms\x18\x08 \x01(\x01\"S\n\x0cQueueMetrics\x12\x0f\n\x07waiting\x18\x01 \x01(\x05\x12\x0f\n\x07grammar\x18\x02 \x01(\x05\x12\x0e\n\x06paused\x18\x03 \x01(\x05\x12\x11\n\tretracted\x18\x04 \x01(\x05\"\xa8\x01\n\x10\x41ggregateMetrics\x12\x1a\n\x12total_running_reqs\x18\x01 \x01(\x05\x12\x1a\n\x12total_waiting_reqs\x18\x02 \x01(\x05\x12\x12\n\ntotal_reqs\x18\x03 \x01(\x05\x12\x17\n\x0f\x61vg_token_usage\x18\x04 \x01(\x01\x12\x16\n\x0e\x61vg_throughput\x18\x05 \x01(\x01\x12\x17\n\x0f\x61vg_utilization\x18\x06 
\x01(\x01\x32\xb0\x05\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponse\x12g\n\x0cGetModelInfo\x12*.sglang.grpc.scheduler.GetModelInfoRequest\x1a+.sglang.grpc.scheduler.GetModelInfoResponse\x12j\n\rGetServerInfo\x12+.sglang.grpc.scheduler.GetServerInfoRequest\x1a,.sglang.grpc.scheduler.GetServerInfoResponse\x12[\n\x08GetLoads\x12&.sglang.grpc.scheduler.GetLoadsRequest\x1a\'.sglang.grpc.scheduler.GetLoadsResponseb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sglang_scheduler_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._loaded_options = None - _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_options = b'8\001' - _globals['_SAMPLINGPARAMS']._serialized_start=113 - _globals['_SAMPLINGPARAMS']._serialized_end=833 - _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_start=732 - _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_end=780 - _globals['_DISAGGREGATEDPARAMS']._serialized_start=835 - _globals['_DISAGGREGATEDPARAMS']._serialized_end=928 - _globals['_GENERATEREQUEST']._serialized_start=931 - _globals['_GENERATEREQUEST']._serialized_end=1541 - _globals['_TOKENIZEDINPUT']._serialized_start=1543 - _globals['_TOKENIZEDINPUT']._serialized_end=1601 - _globals['_MULTIMODALINPUTS']._serialized_start=1604 - _globals['_MULTIMODALINPUTS']._serialized_end=1815 - _globals['_GENERATERESPONSE']._serialized_start=1818 - _globals['_GENERATERESPONSE']._serialized_end=2045 - 
_globals['_GENERATESTREAMCHUNK']._serialized_start=2048 - _globals['_GENERATESTREAMCHUNK']._serialized_end=2325 - _globals['_GENERATECOMPLETE']._serialized_start=2328 - _globals['_GENERATECOMPLETE']._serialized_end=2739 - _globals['_GENERATEERROR']._serialized_start=2741 - _globals['_GENERATEERROR']._serialized_end=2816 - _globals['_OUTPUTLOGPROBS']._serialized_start=2818 - _globals['_OUTPUTLOGPROBS']._serialized_end=2935 - _globals['_INPUTLOGPROBS']._serialized_start=2938 - _globals['_INPUTLOGPROBS']._serialized_end=3096 - _globals['_INPUTTOKENLOGPROB']._serialized_start=3098 - _globals['_INPUTTOKENLOGPROB']._serialized_end=3147 - _globals['_TOPLOGPROBS']._serialized_start=3149 - _globals['_TOPLOGPROBS']._serialized_end=3197 - _globals['_HIDDENSTATES']._serialized_start=3199 - _globals['_HIDDENSTATES']._serialized_end=3262 - _globals['_EMBEDREQUEST']._serialized_start=3265 - _globals['_EMBEDREQUEST']._serialized_end=3595 - _globals['_EMBEDRESPONSE']._serialized_start=3598 - _globals['_EMBEDRESPONSE']._serialized_end=3755 - _globals['_EMBEDCOMPLETE']._serialized_start=3758 - _globals['_EMBEDCOMPLETE']._serialized_end=3921 - _globals['_EMBEDDING']._serialized_start=3923 - _globals['_EMBEDDING']._serialized_end=3965 - _globals['_EMBEDERROR']._serialized_start=3967 - _globals['_EMBEDERROR']._serialized_end=4027 - _globals['_HEALTHCHECKREQUEST']._serialized_start=4029 - _globals['_HEALTHCHECKREQUEST']._serialized_end=4049 - _globals['_HEALTHCHECKRESPONSE']._serialized_start=4051 - _globals['_HEALTHCHECKRESPONSE']._serialized_end=4106 - _globals['_ABORTREQUEST']._serialized_start=4108 - _globals['_ABORTREQUEST']._serialized_end=4158 - _globals['_ABORTRESPONSE']._serialized_start=4160 - _globals['_ABORTRESPONSE']._serialized_end=4209 - _globals['_LOADLORAREQUEST']._serialized_start=4211 - _globals['_LOADLORAREQUEST']._serialized_end=4284 - _globals['_LOADLORARESPONSE']._serialized_start=4286 - _globals['_LOADLORARESPONSE']._serialized_end=4358 - 
_globals['_UNLOADLORAREQUEST']._serialized_start=4360 - _globals['_UNLOADLORAREQUEST']._serialized_end=4399 - _globals['_UNLOADLORARESPONSE']._serialized_start=4401 - _globals['_UNLOADLORARESPONSE']._serialized_end=4455 - _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4457 - _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4576 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4578 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4635 - _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4637 - _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4682 - _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4684 - _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4750 - _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4752 - _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4817 - _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4819 - _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4879 - _globals['_GETMODELINFOREQUEST']._serialized_start=4881 - _globals['_GETMODELINFOREQUEST']._serialized_end=4902 - _globals['_GETMODELINFORESPONSE']._serialized_start=4905 - _globals['_GETMODELINFORESPONSE']._serialized_end=5333 - _globals['_GETSERVERINFOREQUEST']._serialized_start=5335 - _globals['_GETSERVERINFOREQUEST']._serialized_end=5357 - _globals['_GETSERVERINFORESPONSE']._serialized_start=5360 - _globals['_GETSERVERINFORESPONSE']._serialized_end=5671 - _globals['_GETLOADSREQUEST']._serialized_start=5673 - _globals['_GETLOADSREQUEST']._serialized_end=5741 - _globals['_GETLOADSRESPONSE']._serialized_start=5744 - _globals['_GETLOADSRESPONSE']._serialized_end=5934 - _globals['_SCHEDULERLOAD']._serialized_start=5937 - _globals['_SCHEDULERLOAD']._serialized_end=6602 - _globals['_MEMORYMETRICS']._serialized_start=6604 - _globals['_MEMORYMETRICS']._serialized_end=6701 - _globals['_SPECULATIVEMETRICS']._serialized_start=6703 - _globals['_SPECULATIVEMETRICS']._serialized_end=6767 - _globals['_LORAMETRICS']._serialized_start=6769 - 
_globals['_LORAMETRICS']._serialized_end=6844 - _globals['_DISAGGREGATIONMETRICS']._serialized_start=6847 - _globals['_DISAGGREGATIONMETRICS']._serialized_end=7131 - _globals['_QUEUEMETRICS']._serialized_start=7133 - _globals['_QUEUEMETRICS']._serialized_end=7216 - _globals['_AGGREGATEMETRICS']._serialized_start=7219 - _globals['_AGGREGATEMETRICS']._serialized_end=7387 - _globals['_SGLANGSCHEDULER']._serialized_start=7390 - _globals['_SGLANGSCHEDULER']._serialized_end=8078 -# @@protoc_insertion_point(module_scope) diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi deleted file mode 100644 index 8d3e979aa4ad..000000000000 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi +++ /dev/null @@ -1,632 +0,0 @@ -import datetime - -from google.protobuf import timestamp_pb2 as _timestamp_pb2 -from google.protobuf import struct_pb2 as _struct_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from collections.abc import Iterable as _Iterable, Mapping as _Mapping -from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class SamplingParams(_message.Message): - __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "n", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params") - class LogitBiasEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: float - def __init__(self, key: _Optional[str] = ..., value: _Optional[float] = ...) -> None: ... 
- TEMPERATURE_FIELD_NUMBER: _ClassVar[int] - TOP_P_FIELD_NUMBER: _ClassVar[int] - TOP_K_FIELD_NUMBER: _ClassVar[int] - MIN_P_FIELD_NUMBER: _ClassVar[int] - FREQUENCY_PENALTY_FIELD_NUMBER: _ClassVar[int] - PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int] - REPETITION_PENALTY_FIELD_NUMBER: _ClassVar[int] - MAX_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int] - STOP_FIELD_NUMBER: _ClassVar[int] - STOP_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] - SKIP_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int] - SPACES_BETWEEN_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int] - REGEX_FIELD_NUMBER: _ClassVar[int] - JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int] - EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int] - STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int] - N_FIELD_NUMBER: _ClassVar[int] - MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int] - IGNORE_EOS_FIELD_NUMBER: _ClassVar[int] - NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int] - STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int] - LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int] - CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int] - temperature: float - top_p: float - top_k: int - min_p: float - frequency_penalty: float - presence_penalty: float - repetition_penalty: float - max_new_tokens: int - stop: _containers.RepeatedScalarFieldContainer[str] - stop_token_ids: _containers.RepeatedScalarFieldContainer[int] - skip_special_tokens: bool - spaces_between_special_tokens: bool - regex: str - json_schema: str - ebnf_grammar: str - structural_tag: str - n: int - min_new_tokens: int - ignore_eos: bool - no_stop_trim: bool - stream_interval: int - logit_bias: _containers.ScalarMap[str, float] - custom_params: _struct_pb2.Struct - def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: 
_Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., n: _Optional[int] = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... - -class DisaggregatedParams(_message.Message): - __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room") - BOOTSTRAP_HOST_FIELD_NUMBER: _ClassVar[int] - BOOTSTRAP_PORT_FIELD_NUMBER: _ClassVar[int] - BOOTSTRAP_ROOM_FIELD_NUMBER: _ClassVar[int] - bootstrap_host: str - bootstrap_port: int - bootstrap_room: int - def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ... 
- -class GenerateRequest(_message.Message): - __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "stream") - REQUEST_ID_FIELD_NUMBER: _ClassVar[int] - TOKENIZED_FIELD_NUMBER: _ClassVar[int] - MM_INPUTS_FIELD_NUMBER: _ClassVar[int] - SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int] - RETURN_LOGPROB_FIELD_NUMBER: _ClassVar[int] - LOGPROB_START_LEN_FIELD_NUMBER: _ClassVar[int] - TOP_LOGPROBS_NUM_FIELD_NUMBER: _ClassVar[int] - TOKEN_IDS_LOGPROB_FIELD_NUMBER: _ClassVar[int] - RETURN_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] - DISAGGREGATED_PARAMS_FIELD_NUMBER: _ClassVar[int] - CUSTOM_LOGIT_PROCESSOR_FIELD_NUMBER: _ClassVar[int] - TIMESTAMP_FIELD_NUMBER: _ClassVar[int] - LOG_METRICS_FIELD_NUMBER: _ClassVar[int] - INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int] - LORA_ID_FIELD_NUMBER: _ClassVar[int] - DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int] - STREAM_FIELD_NUMBER: _ClassVar[int] - request_id: str - tokenized: TokenizedInput - mm_inputs: MultimodalInputs - sampling_params: SamplingParams - return_logprob: bool - logprob_start_len: int - top_logprobs_num: int - token_ids_logprob: _containers.RepeatedScalarFieldContainer[int] - return_hidden_states: bool - disaggregated_params: DisaggregatedParams - custom_logit_processor: str - timestamp: _timestamp_pb2.Timestamp - log_metrics: bool - input_embeds: _containers.RepeatedScalarFieldContainer[float] - lora_id: str - data_parallel_rank: int - stream: bool - def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., 
top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., stream: bool = ...) -> None: ... - -class TokenizedInput(_message.Message): - __slots__ = ("original_text", "input_ids") - ORIGINAL_TEXT_FIELD_NUMBER: _ClassVar[int] - INPUT_IDS_FIELD_NUMBER: _ClassVar[int] - original_text: str - input_ids: _containers.RepeatedScalarFieldContainer[int] - def __init__(self, original_text: _Optional[str] = ..., input_ids: _Optional[_Iterable[int]] = ...) -> None: ... - -class MultimodalInputs(_message.Message): - __slots__ = ("image_urls", "video_urls", "audio_urls", "processed_features", "image_data", "video_data", "audio_data", "modalities") - IMAGE_URLS_FIELD_NUMBER: _ClassVar[int] - VIDEO_URLS_FIELD_NUMBER: _ClassVar[int] - AUDIO_URLS_FIELD_NUMBER: _ClassVar[int] - PROCESSED_FEATURES_FIELD_NUMBER: _ClassVar[int] - IMAGE_DATA_FIELD_NUMBER: _ClassVar[int] - VIDEO_DATA_FIELD_NUMBER: _ClassVar[int] - AUDIO_DATA_FIELD_NUMBER: _ClassVar[int] - MODALITIES_FIELD_NUMBER: _ClassVar[int] - image_urls: _containers.RepeatedScalarFieldContainer[str] - video_urls: _containers.RepeatedScalarFieldContainer[str] - audio_urls: _containers.RepeatedScalarFieldContainer[str] - processed_features: _struct_pb2.Struct - image_data: _containers.RepeatedScalarFieldContainer[bytes] - video_data: _containers.RepeatedScalarFieldContainer[bytes] - audio_data: _containers.RepeatedScalarFieldContainer[bytes] - modalities: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, image_urls: _Optional[_Iterable[str]] = ..., video_urls: _Optional[_Iterable[str]] = ..., 
audio_urls: _Optional[_Iterable[str]] = ..., processed_features: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., image_data: _Optional[_Iterable[bytes]] = ..., video_data: _Optional[_Iterable[bytes]] = ..., audio_data: _Optional[_Iterable[bytes]] = ..., modalities: _Optional[_Iterable[str]] = ...) -> None: ... - -class GenerateResponse(_message.Message): - __slots__ = ("request_id", "chunk", "complete", "error") - REQUEST_ID_FIELD_NUMBER: _ClassVar[int] - CHUNK_FIELD_NUMBER: _ClassVar[int] - COMPLETE_FIELD_NUMBER: _ClassVar[int] - ERROR_FIELD_NUMBER: _ClassVar[int] - request_id: str - chunk: GenerateStreamChunk - complete: GenerateComplete - error: GenerateError - def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ... - -class GenerateStreamChunk(_message.Message): - __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index") - TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] - PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] - COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int] - CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] - OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] - HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] - INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] - INDEX_FIELD_NUMBER: _ClassVar[int] - token_ids: _containers.RepeatedScalarFieldContainer[int] - prompt_tokens: int - completion_tokens: int - cached_tokens: int - output_logprobs: OutputLogProbs - hidden_states: _containers.RepeatedScalarFieldContainer[float] - input_logprobs: InputLogProbs - index: int - def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., 
hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ... - -class GenerateComplete(_message.Message): - __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index") - OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int] - FINISH_REASON_FIELD_NUMBER: _ClassVar[int] - PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] - COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int] - CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] - OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] - ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] - MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int] - MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int] - INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] - INDEX_FIELD_NUMBER: _ClassVar[int] - output_ids: _containers.RepeatedScalarFieldContainer[int] - finish_reason: str - prompt_tokens: int - completion_tokens: int - cached_tokens: int - output_logprobs: OutputLogProbs - all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates] - matched_token_id: int - matched_stop_str: str - input_logprobs: InputLogProbs - index: int - def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ... 
- -class GenerateError(_message.Message): - __slots__ = ("message", "http_status_code", "details") - MESSAGE_FIELD_NUMBER: _ClassVar[int] - HTTP_STATUS_CODE_FIELD_NUMBER: _ClassVar[int] - DETAILS_FIELD_NUMBER: _ClassVar[int] - message: str - http_status_code: str - details: str - def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ... - -class OutputLogProbs(_message.Message): - __slots__ = ("token_logprobs", "token_ids", "top_logprobs") - TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int] - TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] - TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int] - token_logprobs: _containers.RepeatedScalarFieldContainer[float] - token_ids: _containers.RepeatedScalarFieldContainer[int] - top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs] - def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ... - -class InputLogProbs(_message.Message): - __slots__ = ("token_logprobs", "token_ids", "top_logprobs") - TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int] - TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] - TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int] - token_logprobs: _containers.RepeatedCompositeFieldContainer[InputTokenLogProb] - token_ids: _containers.RepeatedScalarFieldContainer[int] - top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs] - def __init__(self, token_logprobs: _Optional[_Iterable[_Union[InputTokenLogProb, _Mapping]]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ... - -class InputTokenLogProb(_message.Message): - __slots__ = ("value",) - VALUE_FIELD_NUMBER: _ClassVar[int] - value: float - def __init__(self, value: _Optional[float] = ...) -> None: ... 
- -class TopLogProbs(_message.Message): - __slots__ = ("values", "token_ids") - VALUES_FIELD_NUMBER: _ClassVar[int] - TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] - values: _containers.RepeatedScalarFieldContainer[float] - token_ids: _containers.RepeatedScalarFieldContainer[int] - def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ... - -class HiddenStates(_message.Message): - __slots__ = ("values", "layer", "position") - VALUES_FIELD_NUMBER: _ClassVar[int] - LAYER_FIELD_NUMBER: _ClassVar[int] - POSITION_FIELD_NUMBER: _ClassVar[int] - values: _containers.RepeatedScalarFieldContainer[float] - layer: int - position: int - def __init__(self, values: _Optional[_Iterable[float]] = ..., layer: _Optional[int] = ..., position: _Optional[int] = ...) -> None: ... - -class EmbedRequest(_message.Message): - __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "log_metrics", "token_type_ids", "data_parallel_rank", "is_cross_encoder", "texts") - REQUEST_ID_FIELD_NUMBER: _ClassVar[int] - TOKENIZED_FIELD_NUMBER: _ClassVar[int] - MM_INPUTS_FIELD_NUMBER: _ClassVar[int] - SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int] - LOG_METRICS_FIELD_NUMBER: _ClassVar[int] - TOKEN_TYPE_IDS_FIELD_NUMBER: _ClassVar[int] - DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int] - IS_CROSS_ENCODER_FIELD_NUMBER: _ClassVar[int] - TEXTS_FIELD_NUMBER: _ClassVar[int] - request_id: str - tokenized: TokenizedInput - mm_inputs: MultimodalInputs - sampling_params: SamplingParams - log_metrics: bool - token_type_ids: _containers.RepeatedScalarFieldContainer[int] - data_parallel_rank: int - is_cross_encoder: bool - texts: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., log_metrics: bool = ..., 
token_type_ids: _Optional[_Iterable[int]] = ..., data_parallel_rank: _Optional[int] = ..., is_cross_encoder: bool = ..., texts: _Optional[_Iterable[str]] = ...) -> None: ... - -class EmbedResponse(_message.Message): - __slots__ = ("request_id", "complete", "error") - REQUEST_ID_FIELD_NUMBER: _ClassVar[int] - COMPLETE_FIELD_NUMBER: _ClassVar[int] - ERROR_FIELD_NUMBER: _ClassVar[int] - request_id: str - complete: EmbedComplete - error: EmbedError - def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ... - -class EmbedComplete(_message.Message): - __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings") - EMBEDDING_FIELD_NUMBER: _ClassVar[int] - PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] - CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] - EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int] - BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int] - embedding: _containers.RepeatedScalarFieldContainer[float] - prompt_tokens: int - cached_tokens: int - embedding_dim: int - batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding] - def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ... - -class Embedding(_message.Message): - __slots__ = ("values", "index") - VALUES_FIELD_NUMBER: _ClassVar[int] - INDEX_FIELD_NUMBER: _ClassVar[int] - values: _containers.RepeatedScalarFieldContainer[float] - index: int - def __init__(self, values: _Optional[_Iterable[float]] = ..., index: _Optional[int] = ...) -> None: ... 
- -class EmbedError(_message.Message): - __slots__ = ("message", "code", "details") - MESSAGE_FIELD_NUMBER: _ClassVar[int] - CODE_FIELD_NUMBER: _ClassVar[int] - DETAILS_FIELD_NUMBER: _ClassVar[int] - message: str - code: str - details: str - def __init__(self, message: _Optional[str] = ..., code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ... - -class HealthCheckRequest(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class HealthCheckResponse(_message.Message): - __slots__ = ("healthy", "message") - HEALTHY_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - healthy: bool - message: str - def __init__(self, healthy: bool = ..., message: _Optional[str] = ...) -> None: ... - -class AbortRequest(_message.Message): - __slots__ = ("request_id", "reason") - REQUEST_ID_FIELD_NUMBER: _ClassVar[int] - REASON_FIELD_NUMBER: _ClassVar[int] - request_id: str - reason: str - def __init__(self, request_id: _Optional[str] = ..., reason: _Optional[str] = ...) -> None: ... - -class AbortResponse(_message.Message): - __slots__ = ("success", "message") - SUCCESS_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - success: bool - message: str - def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... - -class LoadLoRARequest(_message.Message): - __slots__ = ("adapter_id", "adapter_path", "rank") - ADAPTER_ID_FIELD_NUMBER: _ClassVar[int] - ADAPTER_PATH_FIELD_NUMBER: _ClassVar[int] - RANK_FIELD_NUMBER: _ClassVar[int] - adapter_id: str - adapter_path: str - rank: int - def __init__(self, adapter_id: _Optional[str] = ..., adapter_path: _Optional[str] = ..., rank: _Optional[int] = ...) -> None: ... 
- -class LoadLoRAResponse(_message.Message): - __slots__ = ("success", "adapter_id", "message") - SUCCESS_FIELD_NUMBER: _ClassVar[int] - ADAPTER_ID_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - success: bool - adapter_id: str - message: str - def __init__(self, success: bool = ..., adapter_id: _Optional[str] = ..., message: _Optional[str] = ...) -> None: ... - -class UnloadLoRARequest(_message.Message): - __slots__ = ("adapter_id",) - ADAPTER_ID_FIELD_NUMBER: _ClassVar[int] - adapter_id: str - def __init__(self, adapter_id: _Optional[str] = ...) -> None: ... - -class UnloadLoRAResponse(_message.Message): - __slots__ = ("success", "message") - SUCCESS_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - success: bool - message: str - def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... - -class UpdateWeightsRequest(_message.Message): - __slots__ = ("disk_path", "tensor_data", "remote_url", "weight_name") - DISK_PATH_FIELD_NUMBER: _ClassVar[int] - TENSOR_DATA_FIELD_NUMBER: _ClassVar[int] - REMOTE_URL_FIELD_NUMBER: _ClassVar[int] - WEIGHT_NAME_FIELD_NUMBER: _ClassVar[int] - disk_path: str - tensor_data: bytes - remote_url: str - weight_name: str - def __init__(self, disk_path: _Optional[str] = ..., tensor_data: _Optional[bytes] = ..., remote_url: _Optional[str] = ..., weight_name: _Optional[str] = ...) -> None: ... - -class UpdateWeightsResponse(_message.Message): - __slots__ = ("success", "message") - SUCCESS_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - success: bool - message: str - def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... - -class GetInternalStateRequest(_message.Message): - __slots__ = ("state_keys",) - STATE_KEYS_FIELD_NUMBER: _ClassVar[int] - state_keys: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, state_keys: _Optional[_Iterable[str]] = ...) -> None: ... 
- -class GetInternalStateResponse(_message.Message): - __slots__ = ("state",) - STATE_FIELD_NUMBER: _ClassVar[int] - state: _struct_pb2.Struct - def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... - -class SetInternalStateRequest(_message.Message): - __slots__ = ("state",) - STATE_FIELD_NUMBER: _ClassVar[int] - state: _struct_pb2.Struct - def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... - -class SetInternalStateResponse(_message.Message): - __slots__ = ("success", "message") - SUCCESS_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - success: bool - message: str - def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... - -class GetModelInfoRequest(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class GetModelInfoResponse(_message.Message): - __slots__ = ("model_path", "tokenizer_path", "is_generation", "preferred_sampling_params", "weight_version", "served_model_name", "max_context_length", "vocab_size", "supports_vision", "model_type", "eos_token_ids", "pad_token_id", "bos_token_id", "max_req_input_len", "architectures", "id2label_json", "num_labels") - MODEL_PATH_FIELD_NUMBER: _ClassVar[int] - TOKENIZER_PATH_FIELD_NUMBER: _ClassVar[int] - IS_GENERATION_FIELD_NUMBER: _ClassVar[int] - PREFERRED_SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int] - WEIGHT_VERSION_FIELD_NUMBER: _ClassVar[int] - SERVED_MODEL_NAME_FIELD_NUMBER: _ClassVar[int] - MAX_CONTEXT_LENGTH_FIELD_NUMBER: _ClassVar[int] - VOCAB_SIZE_FIELD_NUMBER: _ClassVar[int] - SUPPORTS_VISION_FIELD_NUMBER: _ClassVar[int] - MODEL_TYPE_FIELD_NUMBER: _ClassVar[int] - EOS_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] - PAD_TOKEN_ID_FIELD_NUMBER: _ClassVar[int] - BOS_TOKEN_ID_FIELD_NUMBER: _ClassVar[int] - MAX_REQ_INPUT_LEN_FIELD_NUMBER: _ClassVar[int] - ARCHITECTURES_FIELD_NUMBER: _ClassVar[int] - ID2LABEL_JSON_FIELD_NUMBER: _ClassVar[int] - NUM_LABELS_FIELD_NUMBER: 
_ClassVar[int] - model_path: str - tokenizer_path: str - is_generation: bool - preferred_sampling_params: str - weight_version: str - served_model_name: str - max_context_length: int - vocab_size: int - supports_vision: bool - model_type: str - eos_token_ids: _containers.RepeatedScalarFieldContainer[int] - pad_token_id: int - bos_token_id: int - max_req_input_len: int - architectures: _containers.RepeatedScalarFieldContainer[str] - id2label_json: str - num_labels: int - def __init__(self, model_path: _Optional[str] = ..., tokenizer_path: _Optional[str] = ..., is_generation: bool = ..., preferred_sampling_params: _Optional[str] = ..., weight_version: _Optional[str] = ..., served_model_name: _Optional[str] = ..., max_context_length: _Optional[int] = ..., vocab_size: _Optional[int] = ..., supports_vision: bool = ..., model_type: _Optional[str] = ..., eos_token_ids: _Optional[_Iterable[int]] = ..., pad_token_id: _Optional[int] = ..., bos_token_id: _Optional[int] = ..., max_req_input_len: _Optional[int] = ..., architectures: _Optional[_Iterable[str]] = ..., id2label_json: _Optional[str] = ..., num_labels: _Optional[int] = ...) -> None: ... - -class GetServerInfoRequest(_message.Message): - __slots__ = () - def __init__(self) -> None: ... 
- -class GetServerInfoResponse(_message.Message): - __slots__ = ("server_args", "scheduler_info", "active_requests", "is_paused", "last_receive_timestamp", "uptime_seconds", "sglang_version", "server_type", "start_time") - SERVER_ARGS_FIELD_NUMBER: _ClassVar[int] - SCHEDULER_INFO_FIELD_NUMBER: _ClassVar[int] - ACTIVE_REQUESTS_FIELD_NUMBER: _ClassVar[int] - IS_PAUSED_FIELD_NUMBER: _ClassVar[int] - LAST_RECEIVE_TIMESTAMP_FIELD_NUMBER: _ClassVar[int] - UPTIME_SECONDS_FIELD_NUMBER: _ClassVar[int] - SGLANG_VERSION_FIELD_NUMBER: _ClassVar[int] - SERVER_TYPE_FIELD_NUMBER: _ClassVar[int] - START_TIME_FIELD_NUMBER: _ClassVar[int] - server_args: _struct_pb2.Struct - scheduler_info: _struct_pb2.Struct - active_requests: int - is_paused: bool - last_receive_timestamp: float - uptime_seconds: float - sglang_version: str - server_type: str - start_time: _timestamp_pb2.Timestamp - def __init__(self, server_args: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., scheduler_info: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., active_requests: _Optional[int] = ..., is_paused: bool = ..., last_receive_timestamp: _Optional[float] = ..., uptime_seconds: _Optional[float] = ..., sglang_version: _Optional[str] = ..., server_type: _Optional[str] = ..., start_time: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ... - -class GetLoadsRequest(_message.Message): - __slots__ = ("dp_rank", "include") - DP_RANK_FIELD_NUMBER: _ClassVar[int] - INCLUDE_FIELD_NUMBER: _ClassVar[int] - dp_rank: int - include: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, dp_rank: _Optional[int] = ..., include: _Optional[_Iterable[str]] = ...) -> None: ... 
- -class GetLoadsResponse(_message.Message): - __slots__ = ("timestamp", "version", "dp_rank_count", "loads", "aggregate") - TIMESTAMP_FIELD_NUMBER: _ClassVar[int] - VERSION_FIELD_NUMBER: _ClassVar[int] - DP_RANK_COUNT_FIELD_NUMBER: _ClassVar[int] - LOADS_FIELD_NUMBER: _ClassVar[int] - AGGREGATE_FIELD_NUMBER: _ClassVar[int] - timestamp: str - version: str - dp_rank_count: int - loads: _containers.RepeatedCompositeFieldContainer[SchedulerLoad] - aggregate: AggregateMetrics - def __init__(self, timestamp: _Optional[str] = ..., version: _Optional[str] = ..., dp_rank_count: _Optional[int] = ..., loads: _Optional[_Iterable[_Union[SchedulerLoad, _Mapping]]] = ..., aggregate: _Optional[_Union[AggregateMetrics, _Mapping]] = ...) -> None: ... - -class SchedulerLoad(_message.Message): - __slots__ = ("dp_rank", "num_running_reqs", "num_waiting_reqs", "num_total_reqs", "num_used_tokens", "max_total_num_tokens", "token_usage", "gen_throughput", "cache_hit_rate", "utilization", "max_running_requests", "memory", "speculative", "lora", "disaggregation", "queues") - DP_RANK_FIELD_NUMBER: _ClassVar[int] - NUM_RUNNING_REQS_FIELD_NUMBER: _ClassVar[int] - NUM_WAITING_REQS_FIELD_NUMBER: _ClassVar[int] - NUM_TOTAL_REQS_FIELD_NUMBER: _ClassVar[int] - NUM_USED_TOKENS_FIELD_NUMBER: _ClassVar[int] - MAX_TOTAL_NUM_TOKENS_FIELD_NUMBER: _ClassVar[int] - TOKEN_USAGE_FIELD_NUMBER: _ClassVar[int] - GEN_THROUGHPUT_FIELD_NUMBER: _ClassVar[int] - CACHE_HIT_RATE_FIELD_NUMBER: _ClassVar[int] - UTILIZATION_FIELD_NUMBER: _ClassVar[int] - MAX_RUNNING_REQUESTS_FIELD_NUMBER: _ClassVar[int] - MEMORY_FIELD_NUMBER: _ClassVar[int] - SPECULATIVE_FIELD_NUMBER: _ClassVar[int] - LORA_FIELD_NUMBER: _ClassVar[int] - DISAGGREGATION_FIELD_NUMBER: _ClassVar[int] - QUEUES_FIELD_NUMBER: _ClassVar[int] - dp_rank: int - num_running_reqs: int - num_waiting_reqs: int - num_total_reqs: int - num_used_tokens: int - max_total_num_tokens: int - token_usage: float - gen_throughput: float - cache_hit_rate: float - utilization: 
float - max_running_requests: int - memory: MemoryMetrics - speculative: SpeculativeMetrics - lora: LoRAMetrics - disaggregation: DisaggregationMetrics - queues: QueueMetrics - def __init__(self, dp_rank: _Optional[int] = ..., num_running_reqs: _Optional[int] = ..., num_waiting_reqs: _Optional[int] = ..., num_total_reqs: _Optional[int] = ..., num_used_tokens: _Optional[int] = ..., max_total_num_tokens: _Optional[int] = ..., token_usage: _Optional[float] = ..., gen_throughput: _Optional[float] = ..., cache_hit_rate: _Optional[float] = ..., utilization: _Optional[float] = ..., max_running_requests: _Optional[int] = ..., memory: _Optional[_Union[MemoryMetrics, _Mapping]] = ..., speculative: _Optional[_Union[SpeculativeMetrics, _Mapping]] = ..., lora: _Optional[_Union[LoRAMetrics, _Mapping]] = ..., disaggregation: _Optional[_Union[DisaggregationMetrics, _Mapping]] = ..., queues: _Optional[_Union[QueueMetrics, _Mapping]] = ...) -> None: ... - -class MemoryMetrics(_message.Message): - __slots__ = ("weight_gb", "kv_cache_gb", "graph_gb", "token_capacity") - WEIGHT_GB_FIELD_NUMBER: _ClassVar[int] - KV_CACHE_GB_FIELD_NUMBER: _ClassVar[int] - GRAPH_GB_FIELD_NUMBER: _ClassVar[int] - TOKEN_CAPACITY_FIELD_NUMBER: _ClassVar[int] - weight_gb: float - kv_cache_gb: float - graph_gb: float - token_capacity: int - def __init__(self, weight_gb: _Optional[float] = ..., kv_cache_gb: _Optional[float] = ..., graph_gb: _Optional[float] = ..., token_capacity: _Optional[int] = ...) -> None: ... - -class SpeculativeMetrics(_message.Message): - __slots__ = ("accept_length", "accept_rate") - ACCEPT_LENGTH_FIELD_NUMBER: _ClassVar[int] - ACCEPT_RATE_FIELD_NUMBER: _ClassVar[int] - accept_length: float - accept_rate: float - def __init__(self, accept_length: _Optional[float] = ..., accept_rate: _Optional[float] = ...) -> None: ... 
- -class LoRAMetrics(_message.Message): - __slots__ = ("slots_used", "slots_total", "utilization") - SLOTS_USED_FIELD_NUMBER: _ClassVar[int] - SLOTS_TOTAL_FIELD_NUMBER: _ClassVar[int] - UTILIZATION_FIELD_NUMBER: _ClassVar[int] - slots_used: int - slots_total: int - utilization: float - def __init__(self, slots_used: _Optional[int] = ..., slots_total: _Optional[int] = ..., utilization: _Optional[float] = ...) -> None: ... - -class DisaggregationMetrics(_message.Message): - __slots__ = ("mode", "prefill_prealloc_queue_reqs", "prefill_inflight_queue_reqs", "decode_prealloc_queue_reqs", "decode_transfer_queue_reqs", "decode_retracted_queue_reqs", "kv_transfer_speed_gb_s", "kv_transfer_latency_ms") - MODE_FIELD_NUMBER: _ClassVar[int] - PREFILL_PREALLOC_QUEUE_REQS_FIELD_NUMBER: _ClassVar[int] - PREFILL_INFLIGHT_QUEUE_REQS_FIELD_NUMBER: _ClassVar[int] - DECODE_PREALLOC_QUEUE_REQS_FIELD_NUMBER: _ClassVar[int] - DECODE_TRANSFER_QUEUE_REQS_FIELD_NUMBER: _ClassVar[int] - DECODE_RETRACTED_QUEUE_REQS_FIELD_NUMBER: _ClassVar[int] - KV_TRANSFER_SPEED_GB_S_FIELD_NUMBER: _ClassVar[int] - KV_TRANSFER_LATENCY_MS_FIELD_NUMBER: _ClassVar[int] - mode: str - prefill_prealloc_queue_reqs: int - prefill_inflight_queue_reqs: int - decode_prealloc_queue_reqs: int - decode_transfer_queue_reqs: int - decode_retracted_queue_reqs: int - kv_transfer_speed_gb_s: float - kv_transfer_latency_ms: float - def __init__(self, mode: _Optional[str] = ..., prefill_prealloc_queue_reqs: _Optional[int] = ..., prefill_inflight_queue_reqs: _Optional[int] = ..., decode_prealloc_queue_reqs: _Optional[int] = ..., decode_transfer_queue_reqs: _Optional[int] = ..., decode_retracted_queue_reqs: _Optional[int] = ..., kv_transfer_speed_gb_s: _Optional[float] = ..., kv_transfer_latency_ms: _Optional[float] = ...) -> None: ... 
- -class QueueMetrics(_message.Message): - __slots__ = ("waiting", "grammar", "paused", "retracted") - WAITING_FIELD_NUMBER: _ClassVar[int] - GRAMMAR_FIELD_NUMBER: _ClassVar[int] - PAUSED_FIELD_NUMBER: _ClassVar[int] - RETRACTED_FIELD_NUMBER: _ClassVar[int] - waiting: int - grammar: int - paused: int - retracted: int - def __init__(self, waiting: _Optional[int] = ..., grammar: _Optional[int] = ..., paused: _Optional[int] = ..., retracted: _Optional[int] = ...) -> None: ... - -class AggregateMetrics(_message.Message): - __slots__ = ("total_running_reqs", "total_waiting_reqs", "total_reqs", "avg_token_usage", "avg_throughput", "avg_utilization") - TOTAL_RUNNING_REQS_FIELD_NUMBER: _ClassVar[int] - TOTAL_WAITING_REQS_FIELD_NUMBER: _ClassVar[int] - TOTAL_REQS_FIELD_NUMBER: _ClassVar[int] - AVG_TOKEN_USAGE_FIELD_NUMBER: _ClassVar[int] - AVG_THROUGHPUT_FIELD_NUMBER: _ClassVar[int] - AVG_UTILIZATION_FIELD_NUMBER: _ClassVar[int] - total_running_reqs: int - total_waiting_reqs: int - total_reqs: int - avg_token_usage: float - avg_throughput: float - avg_utilization: float - def __init__(self, total_running_reqs: _Optional[int] = ..., total_waiting_reqs: _Optional[int] = ..., total_reqs: _Optional[int] = ..., avg_token_usage: _Optional[float] = ..., avg_throughput: _Optional[float] = ..., avg_utilization: _Optional[float] = ...) -> None: ... diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py b/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py deleted file mode 100644 index 99bf78bb4864..000000000000 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +++ /dev/null @@ -1,368 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - -from . 
import sglang_scheduler_pb2 as sglang__scheduler__pb2 - -GRPC_GENERATED_VERSION = '1.75.1' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in sglang_scheduler_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) - - -class SglangSchedulerStub(object): - """Service definition for SGLang scheduler communication - This protocol bridges the Rust router and Python scheduler - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.Generate = channel.unary_stream( - '/sglang.grpc.scheduler.SglangScheduler/Generate', - request_serializer=sglang__scheduler__pb2.GenerateRequest.SerializeToString, - response_deserializer=sglang__scheduler__pb2.GenerateResponse.FromString, - _registered_method=True) - self.Embed = channel.unary_unary( - '/sglang.grpc.scheduler.SglangScheduler/Embed', - request_serializer=sglang__scheduler__pb2.EmbedRequest.SerializeToString, - response_deserializer=sglang__scheduler__pb2.EmbedResponse.FromString, - _registered_method=True) - self.HealthCheck = channel.unary_unary( - '/sglang.grpc.scheduler.SglangScheduler/HealthCheck', - request_serializer=sglang__scheduler__pb2.HealthCheckRequest.SerializeToString, - response_deserializer=sglang__scheduler__pb2.HealthCheckResponse.FromString, - _registered_method=True) - self.Abort = channel.unary_unary( - '/sglang.grpc.scheduler.SglangScheduler/Abort', - request_serializer=sglang__scheduler__pb2.AbortRequest.SerializeToString, - 
response_deserializer=sglang__scheduler__pb2.AbortResponse.FromString, - _registered_method=True) - self.GetModelInfo = channel.unary_unary( - '/sglang.grpc.scheduler.SglangScheduler/GetModelInfo', - request_serializer=sglang__scheduler__pb2.GetModelInfoRequest.SerializeToString, - response_deserializer=sglang__scheduler__pb2.GetModelInfoResponse.FromString, - _registered_method=True) - self.GetServerInfo = channel.unary_unary( - '/sglang.grpc.scheduler.SglangScheduler/GetServerInfo', - request_serializer=sglang__scheduler__pb2.GetServerInfoRequest.SerializeToString, - response_deserializer=sglang__scheduler__pb2.GetServerInfoResponse.FromString, - _registered_method=True) - self.GetLoads = channel.unary_unary( - '/sglang.grpc.scheduler.SglangScheduler/GetLoads', - request_serializer=sglang__scheduler__pb2.GetLoadsRequest.SerializeToString, - response_deserializer=sglang__scheduler__pb2.GetLoadsResponse.FromString, - _registered_method=True) - - -class SglangSchedulerServicer(object): - """Service definition for SGLang scheduler communication - This protocol bridges the Rust router and Python scheduler - """ - - def Generate(self, request, context): - """Submit a generation request (supports streaming) - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Embed(self, request, context): - """Submit an embedding request - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def HealthCheck(self, request, context): - """Health check and metrics - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Abort(self, request, context): - """Abort a running request - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - 
context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetModelInfo(self, request, context): - """Get model information - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetServerInfo(self, request, context): - """Get server information - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetLoads(self, request, context): - """Get comprehensive load metrics - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_SglangSchedulerServicer_to_server(servicer, server): - rpc_method_handlers = { - 'Generate': grpc.unary_stream_rpc_method_handler( - servicer.Generate, - request_deserializer=sglang__scheduler__pb2.GenerateRequest.FromString, - response_serializer=sglang__scheduler__pb2.GenerateResponse.SerializeToString, - ), - 'Embed': grpc.unary_unary_rpc_method_handler( - servicer.Embed, - request_deserializer=sglang__scheduler__pb2.EmbedRequest.FromString, - response_serializer=sglang__scheduler__pb2.EmbedResponse.SerializeToString, - ), - 'HealthCheck': grpc.unary_unary_rpc_method_handler( - servicer.HealthCheck, - request_deserializer=sglang__scheduler__pb2.HealthCheckRequest.FromString, - response_serializer=sglang__scheduler__pb2.HealthCheckResponse.SerializeToString, - ), - 'Abort': grpc.unary_unary_rpc_method_handler( - servicer.Abort, - request_deserializer=sglang__scheduler__pb2.AbortRequest.FromString, - response_serializer=sglang__scheduler__pb2.AbortResponse.SerializeToString, - ), - 'GetModelInfo': grpc.unary_unary_rpc_method_handler( - servicer.GetModelInfo, - request_deserializer=sglang__scheduler__pb2.GetModelInfoRequest.FromString, - 
response_serializer=sglang__scheduler__pb2.GetModelInfoResponse.SerializeToString, - ), - 'GetServerInfo': grpc.unary_unary_rpc_method_handler( - servicer.GetServerInfo, - request_deserializer=sglang__scheduler__pb2.GetServerInfoRequest.FromString, - response_serializer=sglang__scheduler__pb2.GetServerInfoResponse.SerializeToString, - ), - 'GetLoads': grpc.unary_unary_rpc_method_handler( - servicer.GetLoads, - request_deserializer=sglang__scheduler__pb2.GetLoadsRequest.FromString, - response_serializer=sglang__scheduler__pb2.GetLoadsResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers) - - - # This class is part of an EXPERIMENTAL API. -class SglangScheduler(object): - """Service definition for SGLang scheduler communication - This protocol bridges the Rust router and Python scheduler - """ - - @staticmethod - def Generate(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_stream( - request, - target, - '/sglang.grpc.scheduler.SglangScheduler/Generate', - sglang__scheduler__pb2.GenerateRequest.SerializeToString, - sglang__scheduler__pb2.GenerateResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Embed(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/sglang.grpc.scheduler.SglangScheduler/Embed', - 
sglang__scheduler__pb2.EmbedRequest.SerializeToString, - sglang__scheduler__pb2.EmbedResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def HealthCheck(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/sglang.grpc.scheduler.SglangScheduler/HealthCheck', - sglang__scheduler__pb2.HealthCheckRequest.SerializeToString, - sglang__scheduler__pb2.HealthCheckResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Abort(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/sglang.grpc.scheduler.SglangScheduler/Abort', - sglang__scheduler__pb2.AbortRequest.SerializeToString, - sglang__scheduler__pb2.AbortResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def GetModelInfo(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/sglang.grpc.scheduler.SglangScheduler/GetModelInfo', - sglang__scheduler__pb2.GetModelInfoRequest.SerializeToString, - sglang__scheduler__pb2.GetModelInfoResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - 
timeout, - metadata, - _registered_method=True) - - @staticmethod - def GetServerInfo(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/sglang.grpc.scheduler.SglangScheduler/GetServerInfo', - sglang__scheduler__pb2.GetServerInfoRequest.SerializeToString, - sglang__scheduler__pb2.GetServerInfoResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def GetLoads(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/sglang.grpc.scheduler.SglangScheduler/GetLoads', - sglang__scheduler__pb2.GetLoadsRequest.SerializeToString, - sglang__scheduler__pb2.GetLoadsResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) From 57fa43f392d501cd05202c0fdf2bc17e28d1f6da Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Fri, 20 Mar 2026 20:07:09 +0000 Subject: [PATCH 042/112] nit: modify weight loader warning msg --- python/sglang/srt/models/gemma4_causal.py | 32 ++++++++++--- python/sglang/srt/models/gemma4_mm.py | 56 ++++++++++++++++------- 2 files changed, 65 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index dc464e00e233..79a43b4acfc1 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -949,9 +949,7 @@ def _get_k_eq_v_layers(self) -> set: if not getattr(self.config, "attention_k_eq_v", False): return set() return { - i - 
for i, lt in enumerate(self.config.layer_types) - if lt == "full_attention" + i for i, lt in enumerate(self.config.layer_types) if lt == "full_attention" } def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -975,6 +973,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) params_dict.update(dict(self.named_buffers())) + non_persistent_buffers: Set[str] = set() + for mod_name, mod in self.named_modules(): + for buf_name in getattr(mod, "_non_persistent_buffers_set", set()): + full = f"{mod_name}.{buf_name}" if mod_name else buf_name + non_persistent_buffers.add(full) + loaded_params: Set[str] = set() for name, loaded_weight in weights: name = name.replace("model.language_model.", "model.") @@ -1042,9 +1046,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_params.add(name) unloaded_params = params_dict.keys() - loaded_params if unloaded_params: - logger.warning( - "Some weights are not initialized from checkpoints: %s", unloaded_params - ) + param_names = set(dict(self.named_parameters()).keys()) + buckets = { + logging.WARNING: ( + "Some weights are not initialized from checkpoints", + lambda p: p in param_names, + ), + logging.INFO: ( + "Persistent buffers not in checkpoint (using default init)", + lambda p: p not in param_names and p not in non_persistent_buffers, + ), + logging.DEBUG: ( + "Non-persistent buffers not in checkpoint (expected)", + lambda p: p in non_persistent_buffers, + ), + } + for level, (msg, pred) in buckets.items(): + names = sorted(p for p in unloaded_params if pred(p)) + if names: + logger.log(level, "%s: %s", msg, names) return loaded_params diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 8482e7d4f3ad..e0ce7e4d1629 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -45,7 +45,6 @@ flatten_nested_list, ) from 
sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.server_args import get_global_server_args from sglang.srt.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, @@ -247,7 +246,7 @@ def prepare_attn_masks( mask_dtype: torch.dtype, ): """Prepare bidirectional attention masks for image tokens. - + Gemma 4 uses bidirectional attention for image soft tokens during prefill. Following the HF implementation, bidirectional attention is only enabled within each individual image group (same-image tokens), not across images. @@ -281,9 +280,7 @@ def prepare_attn_masks( ) # Start with causal mask bidirectional_attn_mask.fill_(1) - bidirectional_attn_mask = bidirectional_attn_mask.tril( - diagonal=prefix_len - ) + bidirectional_attn_mask = bidirectional_attn_mask.tril(diagonal=prefix_len) # Enable bidirectional attention within each image group mm_inputs = forward_batch.mm_inputs[i] @@ -291,18 +288,24 @@ def prepare_attn_masks( for mm_item in mm_inputs.mm_items: if mm_item.is_image(): for im_begin, im_end in mm_item.offsets: - # Note(kpham-sgl): We only apply bidirectional attention when the image token span - # is fully contained in the extend window. Otherwise, we silently fall back to + # Note(kpham-sgl): We only apply bidirectional attention when the image token span + # is fully contained in the extend window. Otherwise, we silently fall back to # causal attention. # FIXME(kpham-sgl): This is a hack to work around the fact that the image token span - # might not be fully contained in the extend window during chunked prefill. + # might not be fully contained in the extend window during chunked prefill. # We should fix this by properly making chunked prefill mask aware. 
- if im_begin >= prefix_len and im_end < prefix_len + extend_seq_len: + if ( + im_begin >= prefix_len + and im_end < prefix_len + extend_seq_len + ): bidirectional_attn_mask[ im_begin - prefix_len : im_end + 1 - prefix_len, im_begin : im_end + 1, ] = 1 - elif im_end >= prefix_len and im_begin < prefix_len + extend_seq_len: + elif ( + im_end >= prefix_len + and im_begin < prefix_len + extend_seq_len + ): split_images.append((i, im_begin, im_end)) bidirectional_attn_masks_list.append(bidirectional_attn_mask.flatten()) @@ -568,9 +571,7 @@ def _get_k_eq_v_layers(self) -> set: if not getattr(text_config, "attention_k_eq_v", False): return set() return { - i - for i, lt in enumerate(text_config.layer_types) - if lt == "full_attention" + i for i, lt in enumerate(text_config.layer_types) if lt == "full_attention" } def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -589,6 +590,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) params_dict.update(dict(self.named_buffers())) + non_persistent_buffers: Set[str] = set() + for mod_name, mod in self.named_modules(): + for buf_name in getattr(mod, "_non_persistent_buffers_set", set()): + full = f"{mod_name}.{buf_name}" if mod_name else buf_name + non_persistent_buffers.add(full) + loaded_params: Set[str] = set() for name, loaded_weight in weights: @@ -670,10 +677,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_params.add(name) unloaded_params = params_dict.keys() - loaded_params if unloaded_params: - logger.warning( - "Some weights are not initialized from checkpoints: %s", - unloaded_params, - ) + param_names = set(dict(self.named_parameters()).keys()) + buckets = { + logging.WARNING: ( + "Some weights are not initialized from checkpoints", + lambda p: p in param_names, + ), + logging.INFO: ( + "Persistent buffers not in checkpoint (using default init)", + lambda p: p not in param_names and p not in 
non_persistent_buffers, + ), + logging.DEBUG: ( + "Non-persistent buffers not in checkpoint (expected)", + lambda p: p in non_persistent_buffers, + ), + } + for level, (msg, pred) in buckets.items(): + names = sorted(p for p in unloaded_params if pred(p)) + if names: + logger.log(level, "%s: %s", msg, names) return loaded_params lora_pattern = re.compile( From f50a9fd3266a62740e38b5008f292863c91b3c6c Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 22 Mar 2026 05:15:29 +0000 Subject: [PATCH 043/112] perf: accelerate Gemma4RMSNorm with sgl_kernel CUDA kernels Replace pure Python/PyTorch forward in Gemma4RMSNorm with fused CUDA kernels from sgl_kernel, eliminating ~5 native kernels per norm call (float, pow, mean, rsqrt, mul, type_as). - Change Gemma4RMSNorm from nn.Module to MultiPlatformOp - scale_shift=1.0: use gemma_rmsnorm (same as GemmaRMSNorm) - scale_shift=0.0: use rmsnorm (standard RMSNorm) - with_scale=False: use gemma_rmsnorm with zero weight buffer - Handle >2D tensors via reshape (e.g. 
q/k/v norms on 3D input) --- python/sglang/srt/layers/layernorm.py | 34 +++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 2164bccc3e68..487662d12137 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -569,7 +569,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -class Gemma4RMSNorm(nn.Module): +class Gemma4RMSNorm(MultiPlatformOp): def __init__( self, dim: int, @@ -583,13 +583,14 @@ def __init__( if self.with_scale: self.weight = nn.Parameter(torch.zeros(dim)) else: - self.register_buffer("weight", torch.tensor(1.0), persistent=False) + # Zero buffer: gemma_rmsnorm(x, zeros) = norm(x) * (1+0) = norm(x) + self.register_buffer("weight", torch.zeros(dim), persistent=False) self.eps = eps self.scale_shift = scale_shift def __repr__(self): - dim = self.weight.shape[-1] if self.weight.shape else None + dim = self.weight.shape[0] return ( f"{self.__class__.__name__}(dim={dim}, eps={self.eps}, " f"with_scale={self.with_scale}, scale_shift={self.scale_shift})" @@ -597,15 +598,38 @@ def __repr__(self): def _norm(self, x): mean_squared = x.pow(2).mean(-1, keepdim=True) + self.eps - # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to address compiler differences between Torch and JAX return x * torch.pow(mean_squared, -0.5) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward_native(self, x: torch.Tensor) -> torch.Tensor: normed_output = self._norm(x.float()) if self.with_scale: normed_output = normed_output * (self.weight.float() + self.scale_shift) return normed_output.type_as(x) + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + if x.numel() == 0: + return x + needs_reshape = x.dim() != 2 + if needs_reshape: + original_shape = x.shape + x = x.contiguous().reshape(-1, original_shape[-1]) + if self.scale_shift == 1.0: + # gemma_rmsnorm: norm(x) * 
(1 + weight) + # When with_scale=False, weight is zeros → norm(x) * 1 = norm(x) + # When with_scale=True, weight is learned → norm(x) * (1 + w) + out = gemma_rmsnorm(x, self.weight.data, self.eps) + elif self.scale_shift == 0.0 and self.with_scale: + # rmsnorm: norm(x) * weight (standard RMSNorm without +1 shift) + out = rmsnorm(x, self.weight.data, self.eps) + else: + out = self.forward_native( + x.reshape(original_shape) if needs_reshape else x + ) + return out + if needs_reshape: + out = out.reshape(original_shape) + return out + class RMSNormWithoutScale(MultiPlatformOp): def __init__(self, hidden_size: int, eps=1e-6): From dd40b9d46fcb751e8de099fcf3f3c52204362aec Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 22 Mar 2026 05:15:35 +0000 Subject: [PATCH 044/112] perf: add fused triton kernel for gemma4 norm+residual+scalar Add gemma4_fused_ops.py with a single triton kernel that fuses: out = (gemma_rmsnorm(x, w) + residual) * scalar This eliminates 3 separate kernel launches (norm, add, mul) at the end of each decoder layer. Uses HAS_SCALAR constexpr to support both with-scalar and without-scalar variants from one kernel. --- python/sglang/srt/layers/gemma4_fused_ops.py | 111 +++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 python/sglang/srt/layers/gemma4_fused_ops.py diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py new file mode 100644 index 000000000000..ffb5ed48a1be --- /dev/null +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -0,0 +1,111 @@ +"""Fused triton kernels for Gemma4 decoder layer operations. + +Fuses post-norm + residual-add (+ optional scalar multiply) into a single +kernel pass to reduce kernel launch overhead. 
+""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _gemma_rmsnorm_residual_kernel( + X_ptr, + W_ptr, + Residual_ptr, + Scalar_ptr, + Out_ptr, + stride_x, + stride_r, + stride_o, + N, + eps, + HAS_SCALAR: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Fused kernel: out = gemma_rmsnorm(x, w) + residual [* scalar] + + When HAS_SCALAR is True, also multiplies by a scalar loaded from Scalar_ptr. + """ + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < N + + x = tl.load(X_ptr + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + w = tl.load(W_ptr + cols, mask=mask, other=0.0).to(tl.float32) + r = tl.load(Residual_ptr + row * stride_r + cols, mask=mask, other=0.0).to( + tl.float32 + ) + + var = tl.sum(x * x, axis=0) / N + rrms = tl.rsqrt(var + eps) + out = x * rrms * (1.0 + w) + r + + if HAS_SCALAR: + scalar = tl.load(Scalar_ptr).to(tl.float32) + out = out * scalar + + tl.store(Out_ptr + row * stride_o + cols, out.to(x.dtype), mask=mask) + + +def gemma_rmsnorm_residual( + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + eps: float = 1e-6, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused gemma_rmsnorm(x) + residual. + + Returns (output, new_residual) where new_residual = output. 
+ """ + assert x.dim() == 2 and x.stride(-1) == 1, "Expected contiguous 2D input" + M, N = x.shape + BLOCK_SIZE = triton.next_power_of_2(N) + out = torch.empty_like(x) + + _gemma_rmsnorm_residual_kernel[(M,)]( + x, + weight, + residual, + None, # no scalar + out, + x.stride(0), + residual.stride(0), + out.stride(0), + N, + eps, + HAS_SCALAR=False, + BLOCK_SIZE=BLOCK_SIZE, + ) + return out, out.clone() + + +def gemma_rmsnorm_residual_scalar( + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + scalar: torch.Tensor, + eps: float = 1e-6, +) -> torch.Tensor: + """Fused (gemma_rmsnorm(x) + residual) * scalar.""" + assert x.dim() == 2 and x.stride(-1) == 1, "Expected contiguous 2D input" + M, N = x.shape + BLOCK_SIZE = triton.next_power_of_2(N) + out = torch.empty_like(x) + + _gemma_rmsnorm_residual_kernel[(M,)]( + x, + weight, + residual, + scalar, + out, + x.stride(0), + residual.stride(0), + out.stride(0), + N, + eps, + HAS_SCALAR=True, + BLOCK_SIZE=BLOCK_SIZE, + ) + return out From 4bafc584f52f8d8c4284428e462302c790065fbb Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 22 Mar 2026 05:15:47 +0000 Subject: [PATCH 045/112] perf: fuse residual add and layer_scalar in Gemma4 decoder layer Two fusion optimizations in Gemma4DecoderLayer.forward(): 1. Fuse post_attn_norm + residual_add + pre_ff_norm: use existing gemma_fused_add_rmsnorm via GemmaRMSNorm's (x, residual) path, reducing 3 kernels to 2 per layer. 2. Fuse post_ff_norm + residual_add + layer_scalar: use the new gemma_rmsnorm_residual_scalar triton kernel when PLE is not active, reducing 3 kernels to 1 per layer. Also cache self.has_ple in __init__ to avoid repeated attribute lookups on the hot path. 
TP1 benchmark on Gemma4 26B-A4B MoE (B200): - Per-step kernel time: 4.96ms -> 2.00ms (-60%) - vs vLLM: 0.47x kernel time (was 1.17x slower) - Throughput (bs=1): 55.2 tok/s --- python/sglang/srt/models/gemma4_causal.py | 82 ++++++++++++++--------- 1 file changed, 51 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 79a43b4acfc1..0b5e6ec2043b 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -49,6 +49,7 @@ default_weight_loader, maybe_remap_kv_scale_name, ) +from sglang.srt.layers.gemma4_fused_ops import gemma_rmsnorm_residual_scalar from sglang.srt.models.gemma3_causal import Gemma3MLP, Gemma3TextScaledWordEmbedding from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, make_layers @@ -582,6 +583,7 @@ def __init__( self.pre_feedforward_layernorm_2 = None self.register_buffer("layer_scalar", torch.ones(1), persistent=True) + self.has_ple = self.hidden_size_per_layer_input > 0 self.prefix = prefix def forward( @@ -597,6 +599,12 @@ def forward( # Gemma4 residual pattern following JAX implementation: # 1. input_norm(x) -> attn -> post_attn_norm -> ADD residual # 2. 
pre_ff_norm -> mlp -> post_ff_norm -> ADD residual + # + # Optimization: fuse "post_attn_norm(h) + residual; pre_ff_norm(...)" + # into "post_attn_norm(h); pre_ff_norm(h, residual)" using + # gemma_fused_add_rmsnorm which computes: + # residual = h + residual (in-place) + # h = gemma_norm(residual) residual = hidden_states # Apply input layernorm @@ -607,51 +615,63 @@ def forward( forward_batch=forward_batch, ) hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = hidden_states + residual - - residual = hidden_states if self.enable_moe_block: + # Fuse: hidden_states + residual -> residual; pre_ff_norm(residual) -> hidden_states + # Also need raw (unfused) residual for router and pre_ff_norm_2 + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual + ) + # For MoE: router and pre_ff_norm_2 need the unfused residual + # (which is now updated to post_attn_out + old_residual) + moe_input = residual + # Dense MLP branch - hidden_states_1 = self.pre_feedforward_layernorm(hidden_states) - hidden_states_1 = self.mlp(hidden_states_1) + hidden_states_1 = self.mlp(hidden_states) hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states_1) - # MoE branch: router sees raw hidden_states (applies its own - # norm + scale internally); experts see separately normed input - router_logits = self.router(hidden_states) - hidden_states_2 = self.pre_feedforward_layernorm_2(hidden_states) + # MoE branch: router sees residual (= post_attn_out + old_residual) + router_logits = self.router(moe_input) + hidden_states_2 = self.pre_feedforward_layernorm_2(moe_input) hidden_states_2 = self.moe(hidden_states_2, router_logits) hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2) # Combine branches hidden_states = hidden_states_1 + hidden_states_2 else: - hidden_states = self.pre_feedforward_layernorm(hidden_states) + # Fuse: hidden_states + residual -> residual; pre_ff_norm(residual) -> hidden_states + hidden_states, 
residual = self.pre_feedforward_layernorm( + hidden_states, residual + ) hidden_states = self.mlp(hidden_states) - hidden_states = self.post_feedforward_layernorm(hidden_states) - hidden_states = hidden_states + residual - - if ( - per_layer_input is not None - and self.per_layer_input_gate is not None - and self.per_layer_projection is not None - and self.post_per_layer_input_norm is not None - ): - gate, _ = self.per_layer_input_gate(hidden_states) - # PLE uses gelu activation for the gate - # Note: GeluAndMul expects concatenated [gate, up] but here we - # only have a single projection. Use F.gelu directly. - gate = torch.nn.functional.gelu(gate, approximate="tanh") - gated_per_layer = gate * per_layer_input - per_layer_contribution, _ = self.per_layer_projection(gated_per_layer) - per_layer_contribution = self.post_per_layer_input_norm( - per_layer_contribution + if not self.has_ple and hidden_states.is_cuda and hidden_states.dim() == 2: + # Fused: (post_ff_norm(h) + residual) * layer_scalar in one kernel + norm = self.post_feedforward_layernorm + hidden_states = gemma_rmsnorm_residual_scalar( + hidden_states, + norm.weight.data, + residual, + self.layer_scalar, + norm.variance_epsilon, ) - hidden_states = hidden_states + per_layer_contribution - - hidden_states = hidden_states * self.layer_scalar + else: + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = hidden_states + residual + + if self.has_ple and per_layer_input is not None: + gate, _ = self.per_layer_input_gate(hidden_states) + gate = torch.nn.functional.gelu(gate, approximate="tanh") + gated_per_layer = gate * per_layer_input + per_layer_contribution, _ = self.per_layer_projection( + gated_per_layer + ) + per_layer_contribution = self.post_per_layer_input_norm( + per_layer_contribution + ) + hidden_states = hidden_states + per_layer_contribution + + hidden_states = hidden_states * self.layer_scalar return hidden_states, None From 31fc2e78f758446175dddc8f54be8f7d36754551 
Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Tue, 24 Mar 2026 01:19:30 +0000 Subject: [PATCH 046/112] attempt to fix rms norm accuracy --- python/sglang/srt/layers/layernorm.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 487662d12137..0d5e32f5ff9e 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -613,19 +613,18 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: if needs_reshape: original_shape = x.shape x = x.contiguous().reshape(-1, original_shape[-1]) - if self.scale_shift == 1.0: + if self.with_scale and self.scale_shift == 1.0: # gemma_rmsnorm: norm(x) * (1 + weight) # When with_scale=False, weight is zeros → norm(x) * 1 = norm(x) # When with_scale=True, weight is learned → norm(x) * (1 + w) out = gemma_rmsnorm(x, self.weight.data, self.eps) - elif self.scale_shift == 0.0 and self.with_scale: + elif not self.with_scale or self.scale_shift == 0.0: + # scale_shift == 0.0 # rmsnorm: norm(x) * weight (standard RMSNorm without +1 shift) out = rmsnorm(x, self.weight.data, self.eps) else: - out = self.forward_native( - x.reshape(original_shape) if needs_reshape else x - ) - return out + return self.forward_native(x.reshape(original_shape)) + if needs_reshape: out = out.reshape(original_shape) return out From b616a9e7833686d3a1dbeb13e963a79cc8b94397 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Tue, 24 Mar 2026 03:02:00 +0000 Subject: [PATCH 047/112] fix: Gemma4RMSNorm use rmsnorm(ones) for with_scale=False MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When with_scale=False (v_norm, router.norm), use rmsnorm(x, ones) instead of gemma_rmsnorm(x, zeros). The zeros buffer passed to gemma_rmsnorm causes quality regression under CUDA graph capture (MMLU 63→53). Using rmsnorm with ones buffer computes norm(x) * 1 = norm(x), correct and fully on CUDA. 
forward_cuda branches: - with_scale=True, scale_shift=1.0 → gemma_rmsnorm (norm * (1+w)) - else → rmsnorm (norm * w) Verified: MMLU 0.631 (baseline 0.632), no forward_native fallback. --- python/sglang/srt/layers/layernorm.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 0d5e32f5ff9e..2cc40dab1ad2 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -583,8 +583,8 @@ def __init__( if self.with_scale: self.weight = nn.Parameter(torch.zeros(dim)) else: - # Zero buffer: gemma_rmsnorm(x, zeros) = norm(x) * (1+0) = norm(x) - self.register_buffer("weight", torch.zeros(dim), persistent=False) + # Ones buffer: rmsnorm(x, ones) = norm(x) * 1 = norm(x) + self.register_buffer("weight", torch.ones(dim), persistent=False) self.eps = eps self.scale_shift = scale_shift @@ -615,15 +615,12 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: x = x.contiguous().reshape(-1, original_shape[-1]) if self.with_scale and self.scale_shift == 1.0: # gemma_rmsnorm: norm(x) * (1 + weight) - # When with_scale=False, weight is zeros → norm(x) * 1 = norm(x) - # When with_scale=True, weight is learned → norm(x) * (1 + w) out = gemma_rmsnorm(x, self.weight.data, self.eps) - elif not self.with_scale or self.scale_shift == 0.0: - # scale_shift == 0.0 - # rmsnorm: norm(x) * weight (standard RMSNorm without +1 shift) - out = rmsnorm(x, self.weight.data, self.eps) else: - return self.forward_native(x.reshape(original_shape)) + # rmsnorm: norm(x) * weight + # with_scale=False → weight is ones → norm(x) * 1 = norm(x) + # scale_shift=0.0 → standard RMSNorm without +1 shift + out = rmsnorm(x, self.weight.data, self.eps) if needs_reshape: out = out.reshape(original_shape) From b9d8667bb270433221892d8b7e490bbfc7cf6e86 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Tue, 24 Mar 2026 03:38:19 +0000 Subject: [PATCH 048/112] perf: simplify 
MoE routing, fuse router mul - Simplify MoE routing: replace one_hot+indicator+sum+where+div+gather with softmax(topk_logits). Mathematically equivalent when renormalizing. - Fuse router mul: cache scale * root_size on first forward to avoid recomputing per call. One elementwise mul instead of two per MoE layer. MMLU: 0.631 (baseline 0.632). Latency: 52.3s (baseline 55.3s, -5.4%). --- python/sglang/srt/models/gemma4_causal.py | 33 +++++++++++------------ 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 0b5e6ec2043b..a0e998f244e7 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -186,11 +186,16 @@ def __init__( prefix=add_prefix("proj", prefix), ) + def fuse_scale(self): + """Pre-compute scale * root_size. Call after weights are loaded.""" + self._fused_scale = (self.scale * self.root_size).to(self.scale.dtype) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Returns raw router logits [T, E].""" x = self.norm(x) - x = x * self.root_size.to(x.dtype) - x = x * self.scale.to(x.dtype) + if not hasattr(self, "_fused_scale"): + self.fuse_scale() + x = x * self._fused_scale.to(x.dtype) router_logits, _ = self.proj(x) return router_logits @@ -224,30 +229,24 @@ def __init__( # MoE's fused kernel computes: Σ_e (expert_e * w_e * scale_e) self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts)) - # Gemma4 routing: softmax over ALL experts → top-k → renormalize. + # Capture param directly to avoid closing over self in the routing closure. 
per_expert_scale = self.per_expert_scale def routing_function( hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, - renormalize: bool, + renormalize: bool, # always True for Gemma4; softmax identity only holds when renormalizing ) -> tuple[torch.Tensor, torch.Tensor]: - _, topk_ids = torch.topk(gating_output, k=topk, dim=-1) - router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1) - indicator = torch.nn.functional.one_hot( - topk_ids, num_classes=gating_output.size(-1) - ).sum(dim=-2) - gate_weights = indicator * router_probabilities - renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True) - renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0) - dispatch_weights = gate_weights / renorm_factor - - topk_weights = dispatch_weights.gather(1, topk_ids) + # softmax(all)[topk] / sum(softmax(all)[topk]) = softmax(topk_logits), + # so we softmax only the top-k logits (fewer kernel launches). + topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1) + topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1) # Fold per_expert_scale into routing weights - expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype) - topk_weights = topk_weights * expert_scales + topk_weights = topk_weights * per_expert_scale[topk_ids].to( + topk_weights.dtype + ) return topk_weights.to(torch.float32), topk_ids.to(torch.int32) From a2ea29c653e4ed4e8c9b9fe9e916dc2050d153d1 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Tue, 24 Mar 2026 03:42:15 +0000 Subject: [PATCH 049/112] cleanup --- python/sglang/srt/mem_cache/swa_memory_pool.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index 6c987f073677..615c48f31c38 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -301,8 +301,6 @@ def __init__( self.clear() self._kvcache = kvcache - # why do we need 
this? - # self._kvcache.register_mapping(weakref.proxy(self.full_to_swa_index_mapping)) self._kvcache.register_mapping(self.full_to_swa_index_mapping) def available_size(self): From 38332fb1d5f03cbd43b8ce414d9ea379229576c7 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Tue, 24 Mar 2026 03:44:26 +0000 Subject: [PATCH 050/112] lint --- benchmark/kernels/fused_moe_triton/common_utils.py | 1 - python/sglang/srt/function_call/gemma4_detector.py | 4 ++-- python/sglang/srt/mem_cache/swa_memory_pool.py | 1 - python/sglang/srt/models/gemma4_causal.py | 6 ++---- test/registered/function_call/test_function_call_parser.py | 2 +- 5 files changed, 5 insertions(+), 9 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/common_utils.py b/benchmark/kernels/fused_moe_triton/common_utils.py index c61d7bd54500..1af59f91945d 100644 --- a/benchmark/kernels/fused_moe_triton/common_utils.py +++ b/benchmark/kernels/fused_moe_triton/common_utils.py @@ -65,7 +65,6 @@ def get_model_config( block_shape = [0, group_size] assert len(block_shape) == 2 - hidden_size = config.hidden_size if architecture == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts // ep_size diff --git a/python/sglang/srt/function_call/gemma4_detector.py b/python/sglang/srt/function_call/gemma4_detector.py index abf890e7987e..e4d3a1916dec 100644 --- a/python/sglang/srt/function_call/gemma4_detector.py +++ b/python/sglang/srt/function_call/gemma4_detector.py @@ -73,8 +73,8 @@ def _parse_gemma4_array(arr_str: str) -> list: while i < n and depth > 0: if arr_str[i:].startswith(STRING_DELIM): i += len(STRING_DELIM) - nd = arr_str.find(STRING_DELIM, i) - i = nd + len(STRING_DELIM) if nd != -1 else n + next_delim = arr_str.find(STRING_DELIM, i) + i = next_delim + len(STRING_DELIM) if next_delim != -1 else n continue if arr_str[i] == "{": depth += 1 diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index 615c48f31c38..9d47f12482d1 100644 --- 
a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -1,5 +1,4 @@ import logging -import weakref from typing import Dict, List, Optional, Tuple import torch diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index a0e998f244e7..658d4a05d65d 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -27,6 +27,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, ) +from sglang.srt.layers.gemma4_fused_ops import gemma_rmsnorm_residual_scalar from sglang.srt.layers.layernorm import Gemma4RMSNorm, GemmaRMSNorm, RMSNorm from sglang.srt.layers.linear import ( QKVParallelLinear, @@ -49,7 +50,6 @@ default_weight_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.layers.gemma4_fused_ops import gemma_rmsnorm_residual_scalar from sglang.srt.models.gemma3_causal import Gemma3MLP, Gemma3TextScaledWordEmbedding from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, make_layers @@ -662,9 +662,7 @@ def forward( gate, _ = self.per_layer_input_gate(hidden_states) gate = torch.nn.functional.gelu(gate, approximate="tanh") gated_per_layer = gate * per_layer_input - per_layer_contribution, _ = self.per_layer_projection( - gated_per_layer - ) + per_layer_contribution, _ = self.per_layer_projection(gated_per_layer) per_layer_contribution = self.post_per_layer_input_norm( per_layer_contribution ) diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index 2c2be7e1ca46..0c57f07e42a2 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -3765,7 +3765,7 @@ def test_parse_streaming_increment(self): "Some text ", "before <|tool", "_call>call:get_we", - "ather{location:<|", + "ather{location:<|", # codespell:ignore 
'"|>Tokyo<|"|>} after", ] From e5ee7aaad32654683bf832d972e2b8c1314b23f9 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Tue, 24 Mar 2026 17:17:03 +0000 Subject: [PATCH 051/112] cleanup: deduplicate Gemma4 hybrid layer config and tidy comments Merge identical elif branches for Gemma4ForCausalLM and Gemma4ForConditionalGeneration in get_hybrid_layer_ids. Clean up redundant comment on dummy key in attention. Co-Authored-By: Claude Opus 4.6 (1M context) --- python/sglang/srt/configs/model_config.py | 13 ++++--------- python/sglang/srt/models/gemma4_causal.py | 6 ++---- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index f520cb3d44bd..f17a2fe404d4 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -1487,15 +1487,10 @@ def get_hybrid_layer_ids( elif "Step3p5MTP" in model_architectures: swa_attention_layer_ids = [0] full_attention_layer_ids = [] - elif "Gemma4ForCausalLM" in model_architectures: - layer_types = getattr(hf_text_config, "layer_types", None) - swa_attention_layer_ids = [ - i for i, x in enumerate(layer_types) if x == "sliding_attention" - ] - full_attention_layer_ids = [ - i for i, x in enumerate(layer_types) if x == "full_attention" - ] - elif "Gemma4ForConditionalGeneration" in model_architectures: + elif ( + "Gemma4ForCausalLM" in model_architectures + or "Gemma4ForConditionalGeneration" in model_architectures + ): layer_types = getattr(hf_text_config, "layer_types", None) swa_attention_layer_ids = [ i for i, x in enumerate(layer_types) if x == "sliding_attention" diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 658d4a05d65d..f668c37aed61 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -437,10 +437,8 @@ def forward( q, k = self.rotary_emb(positions, q, k) k = k.unflatten(-1, 
(self.num_kv_heads, self.head_dim)) else: - # For shared KV layers, create a dummy key for rotary embedding and discard it - dummy_k = torch.zeros_like( - q[:, : self.kv_size] - ) # Create dummy key with same shape as needed + # Rotary embedding requires a key input; use zeros since KV is shared from another layer + dummy_k = torch.zeros_like(q[:, : self.kv_size]) q, _ = self.rotary_emb(positions, q, dummy_k) q = q.unflatten(-1, (self.num_heads, self.head_dim)) From b5048e7e43ed4f98496599019ba5ee5e2cef0d90 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 24 Mar 2026 18:57:57 +0000 Subject: [PATCH 052/112] remove comments --- python/sglang/srt/utils/hf_transformers_utils.py | 6 +----- python/sglang/test/runners.py | 5 ----- test/manual/test_vlm_accuracy.py | 2 -- 3 files changed, 1 insertion(+), 12 deletions(-) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 511913ecd8e6..1d3e96e9540c 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -1182,11 +1182,7 @@ def get_processor( kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520} if config.model_type not in {"llava", "clip"}: - if config.model_type == "gemma4": - # TODO(kpham-sgl): revert this once we have a fast tokenizer for gemma4 - kwargs["use_fast"] = False - else: - kwargs["use_fast"] = use_fast + kwargs["use_fast"] = use_fast try: if "InternVL3_5" in tokenizer_name: processor = AutoTokenizer.from_pretrained( diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 490a12e15407..a9481ef7f54f 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -30,11 +30,6 @@ GenerationConfig, ) -try: - # TODO(kpham-sgl): For whatever reason the provided transformers package does not have this module - from transformers import AutoModelForVision2Seq -except ImportError: - AutoModelForVision2Seq = None from 
sglang.srt.entrypoints.engine import Engine from sglang.srt.model_loader.ci_weight_validation import ci_validate_and_clean_hf_cache diff --git a/test/manual/test_vlm_accuracy.py b/test/manual/test_vlm_accuracy.py index fde2579a1493..64bfe0f261ef 100644 --- a/test/manual/test_vlm_accuracy.py +++ b/test/manual/test_vlm_accuracy.py @@ -1,7 +1,5 @@ """Multimodal encoder accuracy tests: compare HF vs SGLang encoder outputs. -# TODO(kpham-sgl): Rename this file to test_mm_accuracy.py — it now covers both -# vision and audio encoder comparisons, not just VLM embeddings. """ import os From 18686395e399b70e212590d0333da572e5fb8e91 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Tue, 24 Mar 2026 18:56:48 +0000 Subject: [PATCH 053/112] Tuning fused moe for b200 TP2 --- .../E=128,N=352,device_name=NVIDIA_B200.json | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_B200.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_B200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_B200.json new file mode 100644 index 000000000000..f0eb57ab8dc0 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} From 
a5e08a3f534522eb56d7349ffed1b31ff4c2e9e0 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Tue, 24 Mar 2026 19:34:38 +0000 Subject: [PATCH 054/112] Tuning moe on h100. --- ...352,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++++++++++++++++++ ...704,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++++++++++++++++++ 2 files changed, 292 insertions(+) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000000..60adcf03cea9 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000000..48b07c17d5b7 --- /dev/null +++ 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} From a43ef6af45ec34247ae4d1e56af75567a4b7f0f2 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Thu, 26 Mar 2026 19:10:05 +0000 Subject: [PATCH 055/112] misc image processor change --- python/sglang/srt/models/gemma4_audio.py | 2 +- python/sglang/srt/models/gemma4_mm.py | 57 +++++--- python/sglang/srt/models/gemma4_vision.py | 129 ++++++------------ .../srt/multimodal/processors/gemma4.py | 6 + 4 files changed, 83 insertions(+), 111 deletions(-) diff --git a/python/sglang/srt/models/gemma4_audio.py b/python/sglang/srt/models/gemma4_audio.py index 3926a99f2cef..59fd08c9ab9e 100644 --- a/python/sglang/srt/models/gemma4_audio.py +++ b/python/sglang/srt/models/gemma4_audio.py @@ -439,7 +439,7 @@ def __init__( self.norm = nn.LayerNorm( [out_channels], - eps=config.sscp_conv_group_norm_eps, + eps=getattr(config, "sscp_conv_eps", getattr(config, "sscp_conv_group_norm_eps", 1e-3)), elementwise_affine=True, bias=False, ) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index e0ce7e4d1629..9e26acce3b99 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -335,34 +335,49 @@ def prepare_attn_masks( ) def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: - all_pixel_values = flatten_nested_list([item.feature 
for item in items]) vt = self.vision_tower all_embeds = [] - for pv in all_pixel_values: - if ( - pv.dim() in (2, 3) - and pv.shape[-1] == self.config.text_config.hidden_size - ): - all_embeds.append(pv.to(self.language_model.device)) - continue + for item in items: + all_pixel_values = flatten_nested_list([item.feature]) + all_position_ids = flatten_nested_list( + [getattr(item, "pixel_position_ids", None)] + ) + vol = getattr(item, "vision_output_length", None) + if isinstance(vol, torch.Tensor): + vol = vol.item() + + for pv_idx, pv in enumerate(all_pixel_values): + if ( + pv.dim() in (2, 3) + and pv.shape[-1] == self.config.text_config.hidden_size + ): + all_embeds.append(pv.to(self.language_model.device)) + continue - if pv.dim() == 5: - pv = pv.squeeze(0) - if pv.dim() == 3: - pv = pv.unsqueeze(0) - elif pv.dim() != 4: - raise ValueError(f"Unexpected pixel_values shape: {pv.shape}") + pp = ( + all_position_ids[pv_idx] + if pv_idx < len(all_position_ids) and all_position_ids[pv_idx] is not None + else None + ) - pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) + # Pre-patchified pixel_values: (num_images, num_patches, patch_pixels) + if pv.dim() == 2: + pv = pv.unsqueeze(0) + if pp is not None and pp.dim() == 2: + pp = pp.unsqueeze(0) - pooled, pooler_mask = vt(pv) + pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) + if pp is not None: + pp = pp.to(device=vt.device) - for hs, mask in zip(pooled, pooler_mask): - real_tokens = hs[mask] - all_embeds.append( - self.embed_vision(inputs_embeds=real_tokens.unsqueeze(0)).squeeze(0) - ) + pooled, pooler_mask = vt(pv, pp, output_length=vol) + + for hs, mask in zip(pooled, pooler_mask): + real_tokens = hs[mask] + all_embeds.append( + self.embed_vision(inputs_embeds=real_tokens.unsqueeze(0)).squeeze(0) + ) if all_embeds: return torch.cat(all_embeds, dim=0) diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index 24988ca16dc3..4e6246db8526 100644 
--- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -382,7 +382,7 @@ def forward( Returns: last_hidden_state: [batch, seq, hidden_size] """ - cos, sin = self.rotary_emb(inputs_embeds, patch_positions) + cos, sin = self.rotary_emb(inputs_embeds, patch_positions.clamp(min=0)) hidden_states = inputs_embeds for layer in self.layers: hidden_states = layer(hidden_states, cos, sin, attention_mask) @@ -420,40 +420,34 @@ def _position_embeddings( ) return position_embeddings - def _patchify(self, pixel_values: torch.Tensor) -> torch.Tensor: - batch_size, num_channels, height, width = pixel_values.shape - patch_height = height // self.patch_size - patch_width = width // self.patch_size - patchified_shape = ( - batch_size, - num_channels, - patch_height, - self.patch_size, - patch_width, - self.patch_size, - ) - consolidated_shape = ( - batch_size, - patch_height * patch_width, - num_channels * self.patch_size**2, - ) - patches = ( - pixel_values.reshape(patchified_shape) - .permute(0, 2, 4, 3, 5, 1) - .reshape(consolidated_shape) - ) - patches = 2 * (patches - 0.5) + def _patch_projection(self, pixel_values: torch.Tensor) -> torch.Tensor: + """Project pre-patchified pixels into model space. + + Args: + pixel_values: [batch, num_patches, patch_pixels] — already patchified + by the image processor, values in [0, 1]. + """ + patches = 2 * (pixel_values - 0.5) return self.input_proj(patches.to(self.input_proj.weight.dtype)) def forward( self, pixel_values: torch.Tensor, - patch_positions: torch.Tensor, + pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor, ) -> torch.Tensor: - hidden_states = self._patchify(pixel_values) + """Compute patch embeddings with positional information. + + Args: + pixel_values: [batch, num_patches, patch_pixels] — pre-patchified. + pixel_position_ids: [batch, num_patches, 2] — (x, y) positions, + -1 for padding patches. + padding_positions: [batch, num_patches] — True for padding patches. 
+ """ + hidden_states = self._patch_projection(pixel_values) + clamped_positions = pixel_position_ids.clamp(min=0) position_embeddings = self._position_embeddings( - patch_positions, padding_positions + clamped_positions, padding_positions ) return hidden_states + position_embeddings @@ -529,9 +523,7 @@ def __init__( super().__init__() self.config = config self.patch_size = config.patch_size - self.pooling_kernel_size = config.pooling_kernel_size self.default_output_length = config.default_output_length - self.max_patches = self.default_output_length * self.pooling_kernel_size**2 self.patch_embedder = Gemma4VisionPatchEmbedder(config) self.encoder = Gemma4VisionTransformer( @@ -545,83 +537,42 @@ def __init__( def device(self) -> torch.device: return self.patch_embedder.input_proj.weight.device - def _num_real_patches(self, pixel_values: torch.Tensor) -> int: - _, _, height, width = pixel_values.shape - return (height // self.patch_size) * (width // self.patch_size) - - def _patch_positions( - self, pixel_values: torch.Tensor + def forward( + self, + pixel_values: torch.Tensor, + pixel_position_ids: torch.Tensor, + output_length: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size, _, height, width = pixel_values.shape - device = pixel_values.device - patch_height = height // self.patch_size - patch_width = width // self.patch_size - num_patches = patch_height * patch_width - num_padding = self.max_patches - num_patches - - patch_grid = torch.meshgrid( - torch.arange(patch_width, device=device), - torch.arange(patch_height, device=device), - indexing="xy", - ) - stacked_grid = torch.stack(patch_grid, dim=-1) - real_positions = ( - stacked_grid.reshape(num_patches, 2).unsqueeze(0).repeat(batch_size, 1, 1) - ) - - if num_padding > 0: - pad_positions = torch.full( - (batch_size, num_padding, 2), -1, device=device, dtype=torch.long - ) - patch_positions = torch.cat([real_positions, pad_positions], dim=1) - else: - patch_positions = real_positions 
- - padding_positions = torch.zeros( - batch_size, self.max_patches, device=device, dtype=torch.bool - ) - if num_padding > 0: - padding_positions[:, num_patches:] = True - - return patch_positions.long(), padding_positions - - def forward(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Encode pixel_values into soft tokens. + """Encode pre-patchified pixel_values into soft tokens. Args: - pixel_values: [batch, channels, height, width] + pixel_values: [batch, num_patches, patch_pixels] — pre-patchified + by the image processor. + pixel_position_ids: [batch, num_patches, 2] — (x, y) positions, + -1 for padding patches. + output_length: target number of output soft tokens (optional, + defaults to config.default_output_length). Returns: (hidden_states, pooler_mask) — hidden_states [batch, output_len, hidden], pooler_mask [batch, output_len] True = valid. """ - patch_positions, padding_positions = self._patch_positions(pixel_values) + padding_positions = (pixel_position_ids == -1).all(dim=-1) inputs_embeds = self.patch_embedder( - pixel_values, - patch_positions[:, : self._num_real_patches(pixel_values)], - padding_positions[:, : self._num_real_patches(pixel_values)], + pixel_values, pixel_position_ids, padding_positions ) - num_real = inputs_embeds.shape[1] - num_padding = self.max_patches - num_real - if num_padding > 0: - pad_embeds = torch.zeros( - inputs_embeds.shape[0], - num_padding, - inputs_embeds.shape[2], - device=inputs_embeds.device, - dtype=inputs_embeds.dtype, - ) - inputs_embeds = torch.cat([inputs_embeds, pad_embeds], dim=1) - last_hidden = self.encoder( inputs_embeds=inputs_embeds, attention_mask=~padding_positions, - patch_positions=patch_positions, + patch_positions=pixel_position_ids, ) + if output_length is None: + output_length = self.default_output_length + pooled, pooler_mask = self.pooler( - last_hidden, patch_positions, padding_positions + last_hidden, pixel_position_ids, padding_positions ) return pooled, 
pooler_mask diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index efea548680ba..d510611b4bee 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -19,6 +19,7 @@ from sglang.srt.managers.multimodal_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, ) +from sglang.srt.managers.schedule_batch import Modality from sglang.srt.models.gemma4_mm import Gemma4ForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens @@ -41,6 +42,11 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): audio_token_id=hf_config.audio_token_id, ).build(_processor) + # Register new image-processor outputs so they are stored on + # MultimodalDataItem via collect_mm_items_from_processor_output. + self.ATTR_NAME_TO_MODALITY["pixel_position_ids"] = Modality.IMAGE + self.ATTR_NAME_TO_MODALITY["vision_output_length"] = Modality.IMAGE + def _get_audio_pad_multiple(self) -> int: """Derive the waveform padding alignment from processor config. 
From bbac4bfd3f7289aa69534d9d1ef101593e8d768c Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Thu, 26 Mar 2026 19:49:37 +0000 Subject: [PATCH 056/112] minor fix --- python/sglang/srt/models/gemma4_audio.py | 2 +- python/sglang/srt/models/gemma4_vision.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/models/gemma4_audio.py b/python/sglang/srt/models/gemma4_audio.py index 59fd08c9ab9e..6a0ae51cbc2e 100644 --- a/python/sglang/srt/models/gemma4_audio.py +++ b/python/sglang/srt/models/gemma4_audio.py @@ -439,7 +439,7 @@ def __init__( self.norm = nn.LayerNorm( [out_channels], - eps=getattr(config, "sscp_conv_eps", getattr(config, "sscp_conv_group_norm_eps", 1e-3)), + eps=config.sscp_conv_eps, elementwise_affine=True, bias=False, ) diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index 4e6246db8526..a8ebaf664f95 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -488,12 +488,13 @@ def forward( hidden_states: torch.Tensor, patch_positions: torch.Tensor, padding_positions: torch.Tensor, + output_length: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Returns: (pooled_hidden_states, mask) where mask is True for valid tokens. 
""" - length = self.default_output_length + length = self.default_output_length if output_length is None else output_length if isinstance(length, (list, tuple)): length = length[0] if hidden_states.shape[1] == length: @@ -569,10 +570,10 @@ def forward( patch_positions=pixel_position_ids, ) - if output_length is None: - output_length = self.default_output_length - pooled, pooler_mask = self.pooler( - last_hidden, pixel_position_ids, padding_positions + last_hidden, + pixel_position_ids, + padding_positions, + output_length=output_length, ) return pooled, pooler_mask From 24328259459f3b9eb9d9d4b23e6982dbc58163d8 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Thu, 26 Mar 2026 20:27:06 +0000 Subject: [PATCH 057/112] qkv weight name change for audio/vision tower --- python/sglang/srt/models/gemma4_mm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 9e26acce3b99..99fc08ec3805 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -552,8 +552,9 @@ def _remap_tower_name(name: str, params_dict: dict) -> str: m = Gemma4ForConditionalGeneration._RE_TOWER_QKV.match(name) if m: pfx, proj, attr = m.groups() - if attr in ("weight", "bias"): - return f"{pfx}.qkv.{proj}.{attr}" + if attr in ("weight", "bias", "linear.weight", "linear.bias"): + bare_attr = attr.rsplit(".", 1)[-1] + return f"{pfx}.qkv.{proj}.{bare_attr}" if attr.startswith("output_"): return f"{pfx}.qkv.{proj[0]}_{attr}" if attr.startswith("input_"): @@ -564,8 +565,9 @@ def _remap_tower_name(name: str, params_dict: dict) -> str: if m: pfx, proj, attr = m.groups() short = proj.split("_")[0] # "gate" or "up" - if attr in ("weight", "bias"): - return f"{pfx}.gate_up.{proj}.{attr}" + if attr in ("weight", "bias", "linear.weight", "linear.bias"): + bare_attr = attr.rsplit(".", 1)[-1] + return f"{pfx}.gate_up.{proj}.{bare_attr}" if attr.startswith("output_"): return 
f"{pfx}.gate_up.{short}_{attr}" if attr.startswith("input_"): From 456e98a2be36b08eb914b11e9d0b8d266fad6070 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Thu, 26 Mar 2026 20:28:34 +0000 Subject: [PATCH 058/112] lint --- python/sglang/srt/models/gemma4_mm.py | 7 +++++-- python/sglang/test/runners.py | 1 - test/manual/test_vlm_accuracy.py | 4 +--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 99fc08ec3805..481ea9ff4d20 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -357,7 +357,8 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: pp = ( all_position_ids[pv_idx] - if pv_idx < len(all_position_ids) and all_position_ids[pv_idx] is not None + if pv_idx < len(all_position_ids) + and all_position_ids[pv_idx] is not None else None ) @@ -376,7 +377,9 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: for hs, mask in zip(pooled, pooler_mask): real_tokens = hs[mask] all_embeds.append( - self.embed_vision(inputs_embeds=real_tokens.unsqueeze(0)).squeeze(0) + self.embed_vision( + inputs_embeds=real_tokens.unsqueeze(0) + ).squeeze(0) ) if all_embeds: diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index a9481ef7f54f..61781fea21de 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -30,7 +30,6 @@ GenerationConfig, ) - from sglang.srt.entrypoints.engine import Engine from sglang.srt.model_loader.ci_weight_validation import ci_validate_and_clean_hf_cache from sglang.srt.utils import get_device, is_npu, load_image diff --git a/test/manual/test_vlm_accuracy.py b/test/manual/test_vlm_accuracy.py index 64bfe0f261ef..010703da7f7c 100644 --- a/test/manual/test_vlm_accuracy.py +++ b/test/manual/test_vlm_accuracy.py @@ -1,6 +1,4 @@ -"""Multimodal encoder accuracy tests: compare HF vs SGLang encoder outputs. 
- -""" +1"""Multimodal encoder accuracy tests: compare HF vs SGLang encoder outputs.""" import os import socket From 7ed4e44ad92ce77b465aaa9b9eb33322be287f44 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Thu, 26 Mar 2026 20:35:22 +0000 Subject: [PATCH 059/112] nit --- python/sglang/srt/models/gemma4_vision.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index a8ebaf664f95..26eeadcdeb8b 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -382,7 +382,7 @@ def forward( Returns: last_hidden_state: [batch, seq, hidden_size] """ - cos, sin = self.rotary_emb(inputs_embeds, patch_positions.clamp(min=0)) + cos, sin = self.rotary_emb(inputs_embeds, patch_positions) hidden_states = inputs_embeds for layer in self.layers: hidden_states = layer(hidden_states, cos, sin, attention_mask) @@ -411,7 +411,8 @@ def __init__(self, config: Gemma4VisionConfig): def _position_embeddings( self, patch_positions: torch.Tensor, padding_positions: torch.Tensor ) -> torch.Tensor: - one_hot = F.one_hot(patch_positions, num_classes=self.position_embedding_size) + clamped_positions = patch_positions.clamp(min=0) + one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size) one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table) position_embeddings = one_hot @ self.position_embedding_table position_embeddings = position_embeddings.sum(dim=1) @@ -445,9 +446,8 @@ def forward( padding_positions: [batch, num_patches] — True for padding patches. 
""" hidden_states = self._patch_projection(pixel_values) - clamped_positions = pixel_position_ids.clamp(min=0) position_embeddings = self._position_embeddings( - clamped_positions, padding_positions + pixel_position_ids, padding_positions ) return hidden_states + position_embeddings From c646d087e422c0cb82c0a51c586e492acb2496d9 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Thu, 26 Mar 2026 21:58:19 +0000 Subject: [PATCH 060/112] new moe weight remapping --- .../srt/layers/moe/fused_moe_triton/layer.py | 25 --------------- python/sglang/srt/models/gemma4_causal.py | 31 ++++++++++-------- python/sglang/srt/models/gemma4_mm.py | 32 ++++++++++--------- 3 files changed, 34 insertions(+), 54 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index ef275e2a337c..80735f3da2c7 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1093,31 +1093,6 @@ def make_expert_params_mapping_fused( ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"), ] - @classmethod - def make_expert_params_mapping_gemma4( - cls, - ckpt_gate_proj_name: str, - ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - ): - return [ - # (param_name, weight_name, shard_id) - ( - ( - "experts.w13_weight" - if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] - else "experts.w2_weight" - ), - f"experts.{weight_name}", - shard_id, - ) - for shard_id, weight_name in [ - ("w1", ckpt_gate_proj_name), - ("w2", ckpt_down_proj_name), - ("w3", ckpt_up_proj_name), - ] - ] - @classmethod def make_expert_params_mapping_fused_mxfp4( cls, diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index f668c37aed61..591740ae6876 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -36,7 +36,6 @@ ) from sglang.srt.layers.logits_processor import 
LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention @@ -977,11 +976,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] - expert_params_mapping = FusedMoE.make_expert_params_mapping_gemma4( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - ) + expert_params_mapping = [ + # (param_name, ckpt_weight_name, shard_ids) + # gate_up_proj is fused [E, 2*I, H] — chunk into w1 (gate) + w3 (up) + ("experts.w13_weight", "experts.gate_up_proj", ("w1", "w3")), + ("experts.w2_weight", "experts.down_proj", ("w2",)), + ] num_experts = self.config.num_experts k_eq_v_layers = self._get_k_eq_v_layers() @@ -1016,9 +1016,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): and int(m.group(1)) in k_eq_v_layers ) - # Try stacked (fused) params first + # MoE expert weights checked first (gate_up_proj contains "up_proj" + # which would false-match the stacked dense MLP mapping). 
orig_name = name - for param_name, weight_name, shard_id in stacked_params_mapping: + for param_name, weight_name, shard_ids in expert_params_mapping: name = orig_name if weight_name not in name: continue @@ -1027,12 +1028,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - if should_dup_k_to_v: - weight_loader(param, loaded_weight, "v") + for i in range(num_experts): + chunks = loaded_weight[i].chunk(len(shard_ids), dim=0) + for chunk, sid in zip(chunks, shard_ids): + weight_loader(param, chunk, name, sid, i) break else: - for param_name, weight_name, shard_id in expert_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: name = orig_name if weight_name not in name: continue @@ -1041,8 +1043,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = params_dict[name] weight_loader = param.weight_loader - for i in range(num_experts): - weight_loader(param, loaded_weight[i].T, name, shard_id, i) + weight_loader(param, loaded_weight, shard_id) + if should_dup_k_to_v: + weight_loader(param, loaded_weight, "v") break else: name = orig_name diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 481ea9ff4d20..0ff546b3c1dd 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -32,7 +32,6 @@ from sglang.srt.layers.layernorm import Gemma4RMSNorm from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternMultimodalTokens, @@ -597,14 +596,14 @@ def _get_k_eq_v_layers(self) -> set: def load_weights(self, weights: 
Iterable[Tuple[str, torch.Tensor]]): k_eq_v_layers = self._get_k_eq_v_layers() - # TODO(pyc96): revisit and simplify. num_experts = getattr(self.config.text_config, "num_experts", 0) or 0 if num_experts > 0: - expert_params_mapping = FusedMoE.make_expert_params_mapping_gemma4( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - ) + expert_params_mapping = [ + # (param_name, ckpt_weight_name, shard_ids) + # gate_up_proj is fused [E, 2*I, H] — chunk into w1 (gate) + w3 (up) + ("experts.w13_weight", "experts.gate_up_proj", ("w1", "w3")), + ("experts.w2_weight", "experts.down_proj", ("w2",)), + ] else: expert_params_mapping = [] @@ -652,9 +651,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): and int(m.group(1)) in k_eq_v_layers ) - # Try stacked (fused) params first + # MoE expert weights checked first (gate_up_proj contains "up_proj" + # which would false-match the stacked dense MLP mapping). orig_name = name - for param_name, weight_name, shard_id in self.stacked_params_mapping: + for param_name, weight_name, shard_ids in expert_params_mapping: name = orig_name if weight_name not in name: continue @@ -663,12 +663,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - if should_dup_k_to_v: - weight_loader(param, loaded_weight, "v") + for i in range(num_experts): + chunks = loaded_weight[i].chunk(len(shard_ids), dim=0) + for chunk, sid in zip(chunks, shard_ids): + weight_loader(param, chunk, name, sid, i) break else: - for param_name, weight_name, shard_id in expert_params_mapping: + for param_name, weight_name, shard_id in self.stacked_params_mapping: name = orig_name if weight_name not in name: continue @@ -677,8 +678,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = params_dict[name] weight_loader = 
param.weight_loader - for i in range(num_experts): - weight_loader(param, loaded_weight[i].T, name, shard_id, i) + weight_loader(param, loaded_weight, shard_id) + if should_dup_k_to_v: + weight_loader(param, loaded_weight, "v") break else: name = orig_name From db7c50de7f600b042e5c64083ca053284682a1b0 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Thu, 26 Mar 2026 22:46:08 +0000 Subject: [PATCH 061/112] nit --- python/sglang/srt/models/gemma4_mm.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 0ff546b3c1dd..456976e9ea2f 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -597,15 +597,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): k_eq_v_layers = self._get_k_eq_v_layers() num_experts = getattr(self.config.text_config, "num_experts", 0) or 0 - if num_experts > 0: - expert_params_mapping = [ - # (param_name, ckpt_weight_name, shard_ids) - # gate_up_proj is fused [E, 2*I, H] — chunk into w1 (gate) + w3 (up) - ("experts.w13_weight", "experts.gate_up_proj", ("w1", "w3")), - ("experts.w2_weight", "experts.down_proj", ("w2",)), - ] - else: - expert_params_mapping = [] + expert_params_mapping = [ + # (param_name, ckpt_weight_name, shard_ids) + # gate_up_proj is fused [E, 2*I, H] — chunk into w1 (gate) + w3 (up) + ("experts.w13_weight", "experts.gate_up_proj", ("w1", "w3")), + ("experts.w2_weight", "experts.down_proj", ("w2",)), + ] params_dict = dict(self.named_parameters()) params_dict.update(dict(self.named_buffers())) From 0fd8fb20c6cc0329adcb26841d925f6cc928925d Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Fri, 27 Mar 2026 21:41:40 +0000 Subject: [PATCH 062/112] fix --- python/sglang/srt/models/gemma4_mm.py | 2 +- python/sglang/srt/models/gemma4_vision.py | 4 ++-- python/sglang/srt/multimodal/processors/gemma4.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff 
--git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 456976e9ea2f..52b38cc5885d 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -340,7 +340,7 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: for item in items: all_pixel_values = flatten_nested_list([item.feature]) all_position_ids = flatten_nested_list( - [getattr(item, "pixel_position_ids", None)] + [getattr(item, "image_position_ids", None)] ) vol = getattr(item, "vision_output_length", None) if isinstance(vol, torch.Tensor): diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index 26eeadcdeb8b..d83f9b2538a6 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -461,7 +461,7 @@ class Gemma4VisionPooler(nn.Module): def __init__(self, config: Gemma4VisionConfig): super().__init__() self.hidden_size = config.hidden_size - self.default_output_length = config.default_output_length + self.default_output_length = config.image_seq_length self.root_hidden_size = self.hidden_size**0.5 def _avg_pool_by_positions( @@ -524,7 +524,7 @@ def __init__( super().__init__() self.config = config self.patch_size = config.patch_size - self.default_output_length = config.default_output_length + self.default_output_length = config.image_seq_length self.patch_embedder = Gemma4VisionPatchEmbedder(config) self.encoder = Gemma4VisionTransformer( diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index d510611b4bee..31318a84b83e 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -44,7 +44,7 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): # Register new image-processor outputs so they are stored on # MultimodalDataItem via collect_mm_items_from_processor_output. 
- self.ATTR_NAME_TO_MODALITY["pixel_position_ids"] = Modality.IMAGE + self.ATTR_NAME_TO_MODALITY["image_position_ids"] = Modality.IMAGE self.ATTR_NAME_TO_MODALITY["vision_output_length"] = Modality.IMAGE def _get_audio_pad_multiple(self) -> int: From 9555512c6beed62fb76013b5bf91830a26519ef6 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Fri, 27 Mar 2026 21:55:55 +0000 Subject: [PATCH 063/112] better error msg --- python/sglang/srt/models/gemma4_mm.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 52b38cc5885d..10f91f3ce613 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -354,22 +354,25 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: all_embeds.append(pv.to(self.language_model.device)) continue - pp = ( - all_position_ids[pv_idx] - if pv_idx < len(all_position_ids) - and all_position_ids[pv_idx] is not None - else None - ) + if ( + pv_idx >= len(all_position_ids) + or all_position_ids[pv_idx] is None + ): + raise ValueError( + f"pixel_values[{pv_idx}] has no matching image_position_ids. " + "The HF image processor likely renamed this output — " + "update ATTR_NAME_TO_MODALITY in the Gemma4 processor." 
+ ) + pp = all_position_ids[pv_idx] # Pre-patchified pixel_values: (num_images, num_patches, patch_pixels) if pv.dim() == 2: pv = pv.unsqueeze(0) - if pp is not None and pp.dim() == 2: + if pp.dim() == 2: pp = pp.unsqueeze(0) pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) - if pp is not None: - pp = pp.to(device=vt.device) + pp = pp.to(device=vt.device) pooled, pooler_mask = vt(pv, pp, output_length=vol) From 80cffa1a226c80d97458c273a1b371355785c212 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Fri, 27 Mar 2026 22:13:28 +0000 Subject: [PATCH 064/112] lint --- python/sglang/srt/models/gemma4_mm.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 10f91f3ce613..817b3086cfab 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -354,10 +354,7 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: all_embeds.append(pv.to(self.language_model.device)) continue - if ( - pv_idx >= len(all_position_ids) - or all_position_ids[pv_idx] is None - ): + if pv_idx >= len(all_position_ids) or all_position_ids[pv_idx] is None: raise ValueError( f"pixel_values[{pv_idx}] has no matching image_position_ids. " "The HF image processor likely renamed this output — " From 363e7e710e1d53d6464e4fe4f0132107fdfa76ec Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 28 Mar 2026 03:29:54 +0000 Subject: [PATCH 065/112] fix: add post-pooling standardization to Gemma4 vision encoder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Gemma4 vision encoder was missing the `standardize` step that HF applies after pooling: `(hidden_states - std_bias) * std_scale`. This normalizes vision token magnitudes (std ~3000 → ~2) before they are projected into the text embedding space. Models with `standardize=True` in their vision config (e.g. 
gemma-4-25b-a4b-it) produced wildly wrong image descriptions (e.g. identifying a cow as "Pig" or "Goat") because un-normalized embeddings dominated the text signal. Models with `standardize=False` (e.g. gemma-4-e2b-it) were unaffected. --- python/sglang/srt/models/gemma4_vision.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index d83f9b2538a6..67c09b6f9ee2 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -478,6 +478,7 @@ def _avg_pool_by_positions( max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1 kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor") kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1] + weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared output = weights.transpose(1, 2).to(x.dtype) @ x mask = torch.logical_not((weights == 0).all(dim=1)) @@ -534,6 +535,16 @@ def __init__( ) self.pooler = Gemma4VisionPooler(config) + # Post-pooling standardization (normalizes vision tokens before projection) + self.standardize = getattr(config, "standardize", False) + if self.standardize: + self.register_buffer( + "std_bias", torch.zeros(config.hidden_size) + ) + self.register_buffer( + "std_scale", torch.ones(config.hidden_size) + ) + @property def device(self) -> torch.device: return self.patch_embedder.input_proj.weight.device @@ -576,4 +587,8 @@ def forward( padding_positions, output_length=output_length, ) + + if self.standardize: + pooled = (pooled - self.std_bias) * self.std_scale + return pooled, pooler_mask From d15c2fa5e00ce66a4690e9e02e7a072392b78159 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 30 Mar 2026 01:40:41 +0000 Subject: [PATCH 066/112] gemma 4 norm remove +1 shift --- python/sglang/srt/layers/gemma4_fused_ops.py | 12 ++++++------ python/sglang/srt/layers/layernorm.py | 5 ++--- 
python/sglang/srt/models/gemma4_causal.py | 18 +++++++++--------- python/sglang/srt/models/gemma4_mm.py | 9 +++++---- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py index ffb5ed48a1be..14092aa16f3d 100644 --- a/python/sglang/srt/layers/gemma4_fused_ops.py +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -1,7 +1,7 @@ """Fused triton kernels for Gemma4 decoder layer operations. -Fuses post-norm + residual-add (+ optional scalar multiply) into a single -kernel pass to reduce kernel launch overhead. +Fuses standard RMSNorm + residual-add (+ optional scalar multiply) into +a single kernel pass to reduce kernel launch overhead. """ import torch @@ -24,7 +24,7 @@ def _gemma_rmsnorm_residual_kernel( HAS_SCALAR: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - """Fused kernel: out = gemma_rmsnorm(x, w) + residual [* scalar] + """Fused kernel: out = rmsnorm(x, w) + residual [* scalar] When HAS_SCALAR is True, also multiplies by a scalar loaded from Scalar_ptr. """ @@ -40,7 +40,7 @@ def _gemma_rmsnorm_residual_kernel( var = tl.sum(x * x, axis=0) / N rrms = tl.rsqrt(var + eps) - out = x * rrms * (1.0 + w) + r + out = x * rrms * w + r if HAS_SCALAR: scalar = tl.load(Scalar_ptr).to(tl.float32) @@ -55,7 +55,7 @@ def gemma_rmsnorm_residual( residual: torch.Tensor, eps: float = 1e-6, ) -> tuple[torch.Tensor, torch.Tensor]: - """Fused gemma_rmsnorm(x) + residual. + """Fused rmsnorm(x) + residual. Returns (output, new_residual) where new_residual = output. 
""" @@ -88,7 +88,7 @@ def gemma_rmsnorm_residual_scalar( scalar: torch.Tensor, eps: float = 1e-6, ) -> torch.Tensor: - """Fused (gemma_rmsnorm(x) + residual) * scalar.""" + """Fused (rmsnorm(x) + residual) * scalar.""" assert x.dim() == 2 and x.stride(-1) == 1, "Expected contiguous 2D input" M, N = x.shape BLOCK_SIZE = triton.next_power_of_2(N) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 94b7740fab81..214af4fd2dcb 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -651,16 +651,15 @@ def __init__( self, dim: int, eps: float = 1e-6, - scale_shift: float = 1.0, + scale_shift: float = 0.0, with_scale: bool = True, ): super().__init__() self.with_scale = with_scale if self.with_scale: - self.weight = nn.Parameter(torch.zeros(dim)) + self.weight = nn.Parameter(torch.ones(dim)) else: - # Ones buffer: rmsnorm(x, ones) = norm(x) * 1 = norm(x) self.register_buffer("weight", torch.ones(dim), persistent=False) self.eps = eps diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 591740ae6876..5705b38f5b53 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -28,7 +28,7 @@ get_tensor_model_parallel_world_size, ) from sglang.srt.layers.gemma4_fused_ops import gemma_rmsnorm_residual_scalar -from sglang.srt.layers.layernorm import Gemma4RMSNorm, GemmaRMSNorm, RMSNorm +from sglang.srt.layers.layernorm import Gemma4RMSNorm, RMSNorm from sglang.srt.layers.linear import ( QKVParallelLinear, ReplicatedLinear, @@ -507,14 +507,14 @@ def __init__( prefix=add_prefix("mlp", prefix), ) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm( + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) - 
self.pre_feedforward_layernorm = GemmaRMSNorm( + self.pre_feedforward_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) - self.post_feedforward_layernorm = GemmaRMSNorm( + self.post_feedforward_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) @@ -562,13 +562,13 @@ def __init__( prefix=add_prefix("moe", prefix), ) - self.post_feedforward_layernorm_1 = GemmaRMSNorm( + self.post_feedforward_layernorm_1 = RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) - self.post_feedforward_layernorm_2 = GemmaRMSNorm( + self.post_feedforward_layernorm_2 = RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) - self.pre_feedforward_layernorm_2 = GemmaRMSNorm( + self.pre_feedforward_layernorm_2 = RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) else: @@ -740,7 +740,7 @@ def __init__( prefix=add_prefix("layers", prefix), ) - self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_init() def get_input_embeddings(self) -> nn.Embedding: diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 817b3086cfab..a4e83eaf551f 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -101,8 +101,8 @@ def __init__( prefix=add_prefix("embedding_projection", prefix), ) - self.embedding_post_projection_norm = Gemma4RMSNorm( - self.text_hidden_size, + self.embedding_pre_projection_norm = Gemma4RMSNorm( + embedding_dim, eps=self.eps, with_scale=False, ) @@ -112,8 +112,9 @@ def forward( inputs_embeds: torch.Tensor, ) -> torch.Tensor: """Project soft tokens from a multimodal tower into LM space.""" - embs_proj, _ = self.embedding_projection(inputs_embeds) - return self.embedding_post_projection_norm(embs_proj) + embs_normed = self.embedding_pre_projection_norm(inputs_embeds) + embs_proj, _ = self.embedding_projection(embs_normed) + return embs_proj class 
Gemma4ForConditionalGeneration(PreTrainedModel): From a3d86be9b02c06103494884b451505896ef10cb4 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 30 Mar 2026 02:07:47 +0000 Subject: [PATCH 067/112] lint --- python/sglang/srt/models/gemma4_vision.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index 67c09b6f9ee2..f9935f06e03d 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -538,12 +538,8 @@ def __init__( # Post-pooling standardization (normalizes vision tokens before projection) self.standardize = getattr(config, "standardize", False) if self.standardize: - self.register_buffer( - "std_bias", torch.zeros(config.hidden_size) - ) - self.register_buffer( - "std_scale", torch.ones(config.hidden_size) - ) + self.register_buffer("std_bias", torch.zeros(config.hidden_size)) + self.register_buffer("std_scale", torch.ones(config.hidden_size)) @property def device(self) -> torch.device: From aa6722dc9b17d6f6c5e9dd1f8904646e61e76935 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 30 Mar 2026 18:48:44 +0000 Subject: [PATCH 068/112] init video processor --- python/sglang/srt/models/gemma4_mm.py | 85 ++++++++++++++++--- .../srt/multimodal/processors/gemma4.py | 9 +- 2 files changed, 82 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index a4e83eaf551f..25d0cd9f1395 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -245,11 +245,12 @@ def prepare_attn_masks( input_ids: torch.Tensor, mask_dtype: torch.dtype, ): - """Prepare bidirectional attention masks for image tokens. + """Prepare bidirectional attention masks for image/video tokens. - Gemma 4 uses bidirectional attention for image soft tokens during prefill. 
- Following the HF implementation, bidirectional attention is only enabled - within each individual image group (same-image tokens), not across images. + Gemma 4 uses bidirectional attention for image and video soft tokens + during prefill. Following the HF implementation, bidirectional attention + is only enabled within each individual image/video group (same-item + tokens), not across items. Currently only the TritonAttnBackend supports this. TODO(kpham-sgl): Guard appropriately for gemma3_mm.py:prepare_attn_masks() @@ -286,7 +287,7 @@ def prepare_attn_masks( mm_inputs = forward_batch.mm_inputs[i] if mm_inputs is not None: for mm_item in mm_inputs.mm_items: - if mm_item.is_image(): + if mm_item.is_image() or mm_item.is_video(): for im_begin, im_end in mm_item.offsets: # Note(kpham-sgl): We only apply bidirectional attention when the image token span # is fully contained in the extend window. Otherwise, we silently fall back to @@ -392,6 +393,68 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: dtype=self.language_model.dtype(), ) + def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + """Encode video frames through the vision tower with video-specific pooling. + + Each video is (num_frames, num_patches, patch_pixels) with matching + position_ids (num_frames, num_patches, 2). Frames are flattened into + the batch dimension so each frame is encoded independently, then pooled + to video_seq_length tokens per frame (vs image_seq_length for images). 
+ """ + vt = self.vision_tower + video_seq_length = self.config.vision_config.video_seq_length + + all_embeds = [] + for item in items: + all_pixel_values = flatten_nested_list([item.feature]) + all_position_ids = flatten_nested_list( + [getattr(item, "video_position_ids", None)] + ) + + for pv_idx, pv in enumerate(all_pixel_values): + if ( + pv.dim() in (2, 3) + and pv.shape[-1] == self.config.text_config.hidden_size + ): + all_embeds.append(pv.to(self.language_model.device)) + continue + + if pv_idx >= len(all_position_ids) or all_position_ids[pv_idx] is None: + raise ValueError( + f"pixel_values_videos[{pv_idx}] has no matching video_position_ids." + ) + pp = all_position_ids[pv_idx] + + # pv: (num_frames, num_patches, patch_pixels) + # pp: (num_frames, num_patches, 2) + if pv.dim() == 2: + pv = pv.unsqueeze(0) + if pp.dim() == 2: + pp = pp.unsqueeze(0) + + pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) + pp = pp.to(device=vt.device) + + pooled, pooler_mask = vt(pv, pp, output_length=video_seq_length) + + for hs, mask in zip(pooled, pooler_mask): + real_tokens = hs[mask] + all_embeds.append( + self.embed_vision( + inputs_embeds=real_tokens.unsqueeze(0) + ).squeeze(0) + ) + + if all_embeds: + return torch.cat(all_embeds, dim=0) + else: + return torch.empty( + 0, + self.language_model.config.hidden_size, + device=next(self.parameters()).device, + dtype=self.language_model.dtype(), + ) + def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: if self.audio_tower is None: raise ValueError( @@ -472,16 +535,17 @@ def forward( if input_ids is not None: ple_ids = input_ids.clone() ple_ids[input_ids == self.config.image_token_id] = 0 + ple_ids[input_ids == self.config.video_token_id] = 0 ple_ids[input_ids == self.config.audio_token_id] = 0 per_layer_inputs = self.get_per_layer_inputs(ple_ids) - # Prepare bidirectional attention masks for image tokens during prefill. - # Gemma 4 uses bidirectional attention for image soft tokens. 
+ # Prepare bidirectional attention masks for image/video tokens during prefill. + # Gemma 4 uses bidirectional attention for image/video soft tokens. # Only TritonAttnBackend supports this; incompatible with CUDA Graph and # chunked prefill. - if ( - forward_batch.forward_mode == ForwardMode.EXTEND - and forward_batch.contains_image_inputs() + if forward_batch.forward_mode == ForwardMode.EXTEND and ( + forward_batch.contains_image_inputs() + or forward_batch.contains_video_inputs() ): self.prepare_attn_masks( forward_batch, @@ -496,6 +560,7 @@ def forward( language_model=self.language_model, data_embedding_funcs={ Modality.IMAGE: self.get_image_feature, + Modality.VIDEO: self.get_video_feature, Modality.AUDIO: self.get_audio_feature, }, positions=positions, diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index 31318a84b83e..ce5f387c0d0f 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -25,7 +25,7 @@ class Gemma4SGLangProcessor(SGLangBaseProcessor): - """Multimodal processor for Gemma4 supporting image and audio inputs.""" + """Multimodal processor for Gemma4 supporting image, video, and audio inputs.""" models = [Gemma4ForConditionalGeneration] @@ -39,14 +39,19 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id self.mm_tokens = MultimodalSpecialTokens( image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, audio_token_id=hf_config.audio_token_id, ).build(_processor) - # Register new image-processor outputs so they are stored on + # Register image-processor outputs so they are stored on # MultimodalDataItem via collect_mm_items_from_processor_output. 
self.ATTR_NAME_TO_MODALITY["image_position_ids"] = Modality.IMAGE self.ATTR_NAME_TO_MODALITY["vision_output_length"] = Modality.IMAGE + # Register video-processor outputs so they are stored on + # MultimodalDataItem via collect_mm_items_from_processor_output. + self.ATTR_NAME_TO_MODALITY["video_position_ids"] = Modality.VIDEO + def _get_audio_pad_multiple(self) -> int: """Derive the waveform padding alignment from processor config. From 66c01ecdf222d965edc594a8940845f96b7d655d Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 30 Mar 2026 20:34:39 +0000 Subject: [PATCH 069/112] fix --- python/sglang/srt/multimodal/processors/gemma4.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index ce5f387c0d0f..5f39edefe6b4 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -93,10 +93,11 @@ async def process_mm_data_async( *args, **kwargs, ): - """Process multimodal data including images and audio.""" + """Process multimodal data including images, video, and audio.""" base_output = self.load_mm_data( prompt=input_text, image_data=image_data, + video_data=request_obj.video_data if request_obj else None, audio_data=audio_data, multimodal_tokens=self.mm_tokens, ) @@ -109,5 +110,6 @@ async def process_mm_data_async( "input_ids": input_ids.tolist(), "mm_items": mm_items, "im_token_id": self.mm_tokens.image_token_id, + "video_token_id": self.mm_tokens.video_token_id, "audio_token_id": self.mm_tokens.audio_token_id, } From f2e15b020ef8b95f6e516ca7b1bf73ea05933fb9 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 30 Mar 2026 21:09:12 +0000 Subject: [PATCH 070/112] more vision changes --- python/sglang/srt/models/gemma4_mm.py | 5 ++--- python/sglang/srt/models/gemma4_vision.py | 15 ++++++++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git 
a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 25d0cd9f1395..16d561b04347 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -399,10 +399,9 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: Each video is (num_frames, num_patches, patch_pixels) with matching position_ids (num_frames, num_patches, 2). Frames are flattened into the batch dimension so each frame is encoded independently, then pooled - to video_seq_length tokens per frame (vs image_seq_length for images). + dynamically based on the input patch count and pooling_kernel_size. """ vt = self.vision_tower - video_seq_length = self.config.vision_config.video_seq_length all_embeds = [] for item in items: @@ -435,7 +434,7 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) pp = pp.to(device=vt.device) - pooled, pooler_mask = vt(pv, pp, output_length=video_seq_length) + pooled, pooler_mask = vt(pv, pp) for hs, mask in zip(pooled, pooler_mask): real_tokens = hs[mask] diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index f9935f06e03d..58c11d699b95 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -461,7 +461,6 @@ class Gemma4VisionPooler(nn.Module): def __init__(self, config: Gemma4VisionConfig): super().__init__() self.hidden_size = config.hidden_size - self.default_output_length = config.image_seq_length self.root_hidden_size = self.hidden_size**0.5 def _avg_pool_by_positions( @@ -495,7 +494,9 @@ def forward( Returns: (pooled_hidden_states, mask) where mask is True for valid tokens. 
""" - length = self.default_output_length if output_length is None else output_length + if output_length is None: + raise ValueError("output_length is required for Gemma4VisionPooler") + length = output_length if isinstance(length, (list, tuple)): length = length[0] if hidden_states.shape[1] == length: @@ -525,7 +526,7 @@ def __init__( super().__init__() self.config = config self.patch_size = config.patch_size - self.default_output_length = config.image_seq_length + self.pooling_kernel_size = config.pooling_kernel_size self.patch_embedder = Gemma4VisionPatchEmbedder(config) self.encoder = Gemma4VisionTransformer( @@ -558,13 +559,17 @@ def forward( by the image processor. pixel_position_ids: [batch, num_patches, 2] — (x, y) positions, -1 for padding patches. - output_length: target number of output soft tokens (optional, - defaults to config.default_output_length). + output_length: target number of output soft tokens. If None, + computed as num_patches // pooling_kernel_size^2. Returns: (hidden_states, pooler_mask) — hidden_states [batch, output_len, hidden], pooler_mask [batch, output_len] True = valid. 
""" + if output_length is None: + k2 = self.pooling_kernel_size * self.pooling_kernel_size + output_length = pixel_values.shape[-2] // k2 + padding_positions = (pixel_position_ids == -1).all(dim=-1) inputs_embeds = self.patch_embedder( From a7b20d714b034351c5f69d724ce48a5ce424b888 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 30 Mar 2026 23:20:15 +0000 Subject: [PATCH 071/112] misc bug fix for video pipeline --- python/sglang/srt/models/gemma4_mm.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 16d561b04347..3f8a753be466 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -364,7 +364,8 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: ) pp = all_position_ids[pv_idx] - # Pre-patchified pixel_values: (num_images, num_patches, patch_pixels) + # Vision tower expects 3-D (batch, num_patches, ...). + # A single image may arrive as 2-D; add the batch dim if needed. if pv.dim() == 2: pv = pv.unsqueeze(0) if pp.dim() == 2: @@ -424,12 +425,14 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: ) pp = all_position_ids[pv_idx] - # pv: (num_frames, num_patches, patch_pixels) - # pp: (num_frames, num_patches, 2) - if pv.dim() == 2: - pv = pv.unsqueeze(0) - if pp.dim() == 2: - pp = pp.unsqueeze(0) + # HF processor returns 4-D tensors + # (num_videos, num_frames, num_patches, ...) — collapse to + # 3-D (num_frames, num_patches, ...) so each frame is a + # batch element for the vision tower. 
+ if pv.dim() == 4: + pv = pv.reshape(-1, pv.shape[-2], pv.shape[-1]) + if pp.dim() == 4: + pp = pp.reshape(-1, pp.shape[-2], pp.shape[-1]) pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) pp = pp.to(device=vt.device) From cbe0a585190ad9b24dd338b9f4781ca4502db356 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 30 Mar 2026 23:44:20 +0000 Subject: [PATCH 072/112] remove output_length --- python/sglang/srt/models/gemma4_mm.py | 5 +---- python/sglang/srt/models/gemma4_vision.py | 8 ++------ python/sglang/srt/multimodal/processors/gemma4.py | 1 - 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 3f8a753be466..5aaef55bd79f 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -344,9 +344,6 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: all_position_ids = flatten_nested_list( [getattr(item, "image_position_ids", None)] ) - vol = getattr(item, "vision_output_length", None) - if isinstance(vol, torch.Tensor): - vol = vol.item() for pv_idx, pv in enumerate(all_pixel_values): if ( @@ -374,7 +371,7 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) pp = pp.to(device=vt.device) - pooled, pooler_mask = vt(pv, pp, output_length=vol) + pooled, pooler_mask = vt(pv, pp) for hs, mask in zip(pooled, pooler_mask): real_tokens = hs[mask] diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index 58c11d699b95..99d55db5c6c2 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -550,7 +550,6 @@ def forward( self, pixel_values: torch.Tensor, pixel_position_ids: torch.Tensor, - output_length: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Encode pre-patchified pixel_values into soft tokens. 
@@ -559,16 +558,13 @@ def forward( by the image processor. pixel_position_ids: [batch, num_patches, 2] — (x, y) positions, -1 for padding patches. - output_length: target number of output soft tokens. If None, - computed as num_patches // pooling_kernel_size^2. Returns: (hidden_states, pooler_mask) — hidden_states [batch, output_len, hidden], pooler_mask [batch, output_len] True = valid. """ - if output_length is None: - k2 = self.pooling_kernel_size * self.pooling_kernel_size - output_length = pixel_values.shape[-2] // k2 + k2 = self.pooling_kernel_size * self.pooling_kernel_size + output_length = pixel_values.shape[-2] // k2 padding_positions = (pixel_position_ids == -1).all(dim=-1) diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index 5f39edefe6b4..6b5e2016f3ef 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -46,7 +46,6 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): # Register image-processor outputs so they are stored on # MultimodalDataItem via collect_mm_items_from_processor_output. self.ATTR_NAME_TO_MODALITY["image_position_ids"] = Modality.IMAGE - self.ATTR_NAME_TO_MODALITY["vision_output_length"] = Modality.IMAGE # Register video-processor outputs so they are stored on # MultimodalDataItem via collect_mm_items_from_processor_output. 
From f533a290440d9ca5e889857acd9f6c27ee05486d Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 30 Mar 2026 23:50:41 +0000 Subject: [PATCH 073/112] bidirectional attention only applies to image not video --- python/sglang/srt/models/gemma4_mm.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 5aaef55bd79f..c4229fbc45ae 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -245,11 +245,11 @@ def prepare_attn_masks( input_ids: torch.Tensor, mask_dtype: torch.dtype, ): - """Prepare bidirectional attention masks for image/video tokens. + """Prepare bidirectional attention masks for image (NO video) tokens. - Gemma 4 uses bidirectional attention for image and video soft tokens + Gemma 4 uses bidirectional attention for image (NO video) soft tokens during prefill. Following the HF implementation, bidirectional attention - is only enabled within each individual image/video group (same-item + is only enabled within each individual image (NO video) group (same-item tokens), not across items. Currently only the TritonAttnBackend supports this. @@ -283,11 +283,12 @@ def prepare_attn_masks( bidirectional_attn_mask.fill_(1) bidirectional_attn_mask = bidirectional_attn_mask.tril(diagonal=prefix_len) - # Enable bidirectional attention within each image group + # HF only enables bidirectional attention for image tokens, + # not video or audio (see create_causal_mask_mapping). mm_inputs = forward_batch.mm_inputs[i] if mm_inputs is not None: for mm_item in mm_inputs.mm_items: - if mm_item.is_image() or mm_item.is_video(): + if mm_item.is_image(): for im_begin, im_end in mm_item.offsets: # Note(kpham-sgl): We only apply bidirectional attention when the image token span # is fully contained in the extend window. 
Otherwise, we silently fall back to @@ -542,10 +543,7 @@ def forward( # Gemma 4 uses bidirectional attention for image/video soft tokens. # Only TritonAttnBackend supports this; incompatible with CUDA Graph and # chunked prefill. - if forward_batch.forward_mode == ForwardMode.EXTEND and ( - forward_batch.contains_image_inputs() - or forward_batch.contains_video_inputs() - ): + if forward_batch.forward_mode == ForwardMode.EXTEND and forward_batch.contains_image_inputs(): self.prepare_attn_masks( forward_batch, input_ids, From 1bac1df900e7bebd9d81a589c0d9112b64bfe680 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 30 Mar 2026 23:54:55 +0000 Subject: [PATCH 074/112] change ple pad id --- python/sglang/srt/models/gemma4_mm.py | 7 ++++--- python/sglang/srt/models/gemma4_vision.py | 5 +++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index c4229fbc45ae..cfdc9dea7458 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -534,9 +534,10 @@ def forward( per_layer_inputs = None if input_ids is not None: ple_ids = input_ids.clone() - ple_ids[input_ids == self.config.image_token_id] = 0 - ple_ids[input_ids == self.config.video_token_id] = 0 - ple_ids[input_ids == self.config.audio_token_id] = 0 + pad_id = self.config.text_config.pad_token_id + ple_ids[input_ids == self.config.image_token_id] = pad_id + ple_ids[input_ids == self.config.video_token_id] = pad_id + ple_ids[input_ids == self.config.audio_token_id] = pad_id per_layer_inputs = self.get_per_layer_inputs(ple_ids) # Prepare bidirectional attention masks for image/video tokens during prefill. 
diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index 99d55db5c6c2..6c1cfa61b59e 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -496,6 +496,11 @@ def forward( """ if output_length is None: raise ValueError("output_length is required for Gemma4VisionPooler") + if output_length > hidden_states.shape[1]: + raise ValueError( + f"Cannot output more soft tokens (requested {output_length}) than there are patches" + f" ({hidden_states.shape[1]}). Change the value of `num_soft_tokens` when processing." + ) length = output_length if isinstance(length, (list, tuple)): length = length[0] From 59284714ee39696228b1ea68be284cff4620b939 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 30 Mar 2026 23:03:10 +0000 Subject: [PATCH 075/112] fix vision token id, add VDW to tensor step --- .../srt/multimodal/processors/gemma4.py | 32 +++++++++++++++++++ .../sglang/srt/utils/hf_transformers_utils.py | 7 ++++ 2 files changed, 39 insertions(+) diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index 6b5e2016f3ef..2550a48cca87 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -15,6 +15,7 @@ from typing import Dict, List, Optional, Union import numpy as np +import torch from sglang.srt.managers.multimodal_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, @@ -22,6 +23,7 @@ from sglang.srt.managers.schedule_batch import Modality from sglang.srt.models.gemma4_mm import Gemma4ForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens +from sglang.srt.utils.video_decoder import VideoDecoderWrapper class Gemma4SGLangProcessor(SGLangBaseProcessor): @@ -66,6 +68,28 @@ def _get_audio_pad_multiple(self) -> int: first_stride = ac.sscp_conv_stride_size[0][0] if ac is not None else 2 return hop 
* first_stride + def _video_decoder_to_tensor(self, vdw: VideoDecoderWrapper) -> torch.Tensor: + """Convert a VideoDecoderWrapper to a (sampled_frames, C, H, W) uint8 tensor. + + SGLang's load_video returns VideoDecoderWrapper which the HF + Gemma4VideoProcessor does not recognise (expects torch.Tensor or + np.ndarray). We replicate HF's uniform frame sampling here to + avoid materialising the entire video in memory, then delegate the + rest (resize, patchify, position IDs) to the HF video processor. + """ + total = len(vdw) + num_frames = getattr( + getattr(self._processor, "video_processor", None), + "num_frames", + 32, + ) + if total <= num_frames: + indices = list(range(total)) + else: + indices = torch.arange(0, total, total / num_frames).int().tolist() + frames_np = vdw.get_frames_at(indices) # (N, H, W, C) + return torch.from_numpy(frames_np).permute(0, 3, 1, 2).contiguous() + def process_mm_data( self, input_text, images=None, videos=None, audios=None, **kwargs ): @@ -79,6 +103,14 @@ def process_mm_data( a = np.pad(a, (0, pad_multiple - remainder), mode="constant") padded.append(a) audios = padded + if videos: + videos = [ + self._video_decoder_to_tensor(v) + if isinstance(v, VideoDecoderWrapper) + else v + for v in videos + ] + kwargs.setdefault("do_sample_frames", False) return super().process_mm_data( input_text, images=images, videos=videos, audios=audios, **kwargs ) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 1d3e96e9540c..88027d9444cb 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -644,6 +644,13 @@ def get_config( if not hasattr(text_config, "swa_v_head_dim"): text_config.swa_v_head_dim = text_config.swa_head_dim + # TODO(kpham-sgl): config.video_token_id is 262144 (== vocab_size, + # out of range) while tokenizer["<|video|>"] is 258884. 
Hardcode + # the correct value until the upstream HF config is fixed, then + # remove this override. + if getattr(config, "video_token_id", None) == 262144: + config.video_token_id = 258884 + if config.model_type == "longcat_flash": config.update({"architectures": ["LongcatFlashForCausalLM"]}) From d1318caf88cd001c4a8810472cbf747eb16f7423 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 31 Mar 2026 00:00:18 +0000 Subject: [PATCH 076/112] change how rope theta parameter is read --- python/sglang/srt/models/gemma4_vision.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index 6c1cfa61b59e..816dc1538886 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -55,8 +55,7 @@ class Gemma4VisionRotaryEmbedding(nn.Module): def __init__(self, config: Gemma4VisionConfig): super().__init__() self.head_dim = config.head_dim - rope_params = config.rope_parameters.get("full_attention", {}) - self.rope_theta: float = rope_params.get("rope_theta", 100.0) + self.rope_theta: float = config.rope_parameters["rope_theta"] @torch.no_grad() def forward( From b9e3aca87b5323ddcbb570cd7e88fc0ef48251f5 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 31 Mar 2026 00:01:35 +0000 Subject: [PATCH 077/112] lint --- python/sglang/srt/models/gemma4_mm.py | 5 ++++- python/sglang/srt/multimodal/processors/gemma4.py | 8 +++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index cfdc9dea7458..9cd822210b86 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -544,7 +544,10 @@ def forward( # Gemma 4 uses bidirectional attention for image/video soft tokens. # Only TritonAttnBackend supports this; incompatible with CUDA Graph and # chunked prefill. 
- if forward_batch.forward_mode == ForwardMode.EXTEND and forward_batch.contains_image_inputs(): + if ( + forward_batch.forward_mode == ForwardMode.EXTEND + and forward_batch.contains_image_inputs() + ): self.prepare_attn_masks( forward_batch, input_ids, diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index 2550a48cca87..0cfd2dcdcb4e 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -105,9 +105,11 @@ def process_mm_data( audios = padded if videos: videos = [ - self._video_decoder_to_tensor(v) - if isinstance(v, VideoDecoderWrapper) - else v + ( + self._video_decoder_to_tensor(v) + if isinstance(v, VideoDecoderWrapper) + else v + ) for v in videos ] kwargs.setdefault("do_sample_frames", False) From eaf827a9a514b5d8041e257fe33e0910770616fa Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 31 Mar 2026 00:31:29 +0000 Subject: [PATCH 078/112] misc comment fix --- python/sglang/srt/models/gemma4_mm.py | 10 +++++----- python/sglang/srt/multimodal/processors/gemma4.py | 5 +---- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 9cd822210b86..7402ab34eeb8 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -245,11 +245,11 @@ def prepare_attn_masks( input_ids: torch.Tensor, mask_dtype: torch.dtype, ): - """Prepare bidirectional attention masks for image (NO video) tokens. + """Prepare bidirectional attention masks for image tokens. - Gemma 4 uses bidirectional attention for image (NO video) soft tokens + Gemma 4 uses bidirectional attention for image soft tokens during prefill. Following the HF implementation, bidirectional attention - is only enabled within each individual image (NO video) group (same-item + is only enabled within each individual image group (same-item tokens), not across items. 
Currently only the TritonAttnBackend supports this. @@ -540,8 +540,8 @@ def forward( ple_ids[input_ids == self.config.audio_token_id] = pad_id per_layer_inputs = self.get_per_layer_inputs(ple_ids) - # Prepare bidirectional attention masks for image/video tokens during prefill. - # Gemma 4 uses bidirectional attention for image/video soft tokens. + # Prepare bidirectional attention masks for image tokens during prefill. + # Gemma 4 uses bidirectional attention for image soft tokens. # Only TritonAttnBackend supports this; incompatible with CUDA Graph and # chunked prefill. if ( diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index 0cfd2dcdcb4e..d8b795e646ff 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -45,12 +45,9 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): audio_token_id=hf_config.audio_token_id, ).build(_processor) - # Register image-processor outputs so they are stored on + # Register image-processor and video-processor outputs so they are stored on # MultimodalDataItem via collect_mm_items_from_processor_output. self.ATTR_NAME_TO_MODALITY["image_position_ids"] = Modality.IMAGE - - # Register video-processor outputs so they are stored on - # MultimodalDataItem via collect_mm_items_from_processor_output. 
self.ATTR_NAME_TO_MODALITY["video_position_ids"] = Modality.VIDEO def _get_audio_pad_multiple(self) -> int: From 7bf599b38e3bf4f9c29595e52bf6cc1ee46a584b Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 31 Mar 2026 02:36:36 +0000 Subject: [PATCH 079/112] config.json is fixed --- python/sglang/srt/utils/hf_transformers_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 88027d9444cb..77cc7282f96b 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -644,12 +644,6 @@ def get_config( if not hasattr(text_config, "swa_v_head_dim"): text_config.swa_v_head_dim = text_config.swa_head_dim - # TODO(kpham-sgl): config.video_token_id is 262144 (== vocab_size, - # out of range) while tokenizer["<|video|>"] is 258884. Hardcode - # the correct value until the upstream HF config is fixed, then - # remove this override. - if getattr(config, "video_token_id", None) == 262144: - config.video_token_id = 258884 if config.model_type == "longcat_flash": config.update({"architectures": ["LongcatFlashForCausalLM"]}) From 4e1385a361301ec9c94af981f474956bac3dbaaf Mon Sep 17 00:00:00 2001 From: Andy Luo Date: Tue, 31 Mar 2026 21:38:19 +0000 Subject: [PATCH 080/112] [ROCm] Fix Gemma4 MoE: AITER CK fallback + Gemma4RMSNorm forward_hip Gemma4 MoE models (e.g. 26B-A4B) crash on ROCm due to two issues: 1. AITER CK fused_moe GEMM has no tuned configs for Gemma4's unusual MoE dimensions (128 experts x 704 intermediate size). The CK GEMM crashes with a device_gemm error at runtime. Fix: Wrap the AITER CK fused_moe call in try/except RuntimeError so it gracefully falls back to the Triton MoE runner when CK doesn't support the dimensions. A warning is logged on the first fallback. 2. Gemma4RMSNorm only defines forward_cuda (which calls sgl_kernel's gemma_rmsnorm/rmsnorm), but sgl_kernel ops are not available on ROCm. 
MultiPlatformOp falls back to forward_cuda on HIP when no forward_hip is defined, causing a crash. Fix: Add forward_hip that delegates to forward_native (pure PyTorch). Tested on MI300X with Gemma4 26B-A4B (MoE) and 31B (dense): - AITER ON + CUDAGraph ON: correct output for both models - AITER OFF: correct output for both models - Text generation verified: math, knowledge, generation tasks --- python/sglang/srt/layers/layernorm.py | 5 ++ .../sglang/srt/layers/quantization/unquant.py | 70 +++++++++++-------- 2 files changed, 46 insertions(+), 29 deletions(-) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 214af4fd2dcb..0db6675e648f 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -702,6 +702,11 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: out = out.reshape(original_shape) return out + def forward_hip(self, x: torch.Tensor) -> torch.Tensor: + # sgl_kernel's gemma_rmsnorm is not available on ROCm; + # delegate to the pure-PyTorch implementation. 
+ return self.forward_native(x) + class RMSNormWithoutScale(MultiPlatformOp): def __init__(self, hidden_size: int, eps=1e-6): diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 9fcb88e541f7..f5686ff7b500 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -458,36 +458,48 @@ def forward_cuda( # Skip aiter fused_moe when using non-auto MoE backend (e.g., triton, triton_kernels) # because aiter CK kernels don't support all GEMM dimensions _should_use_aiter_moe = _use_aiter and get_moe_runner_backend().is_auto() + _aiter_moe_succeeded = False if _should_use_aiter_moe: - assert not moe_runner_config.no_combine, "unsupported" - topk_weights, topk_ids, _ = topk_output - if moe_runner_config.apply_router_weight_on_input: - assert ( - topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - x = x * topk_weights.to(x.dtype) - topk_weights = torch.ones_like( - topk_weights, dtype=torch.float32 - ) # topk_weights must be FP32 (float32) - output = fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=( - ActivationType.Silu - if moe_runner_config.activation == "silu" - else ActivationType.Gelu - ), - expert_mask=layer.expert_mask_gpu, - ) - return StandardCombineInput(hidden_states=output) - else: + try: + assert not moe_runner_config.no_combine, "unsupported" + topk_weights, topk_ids, _ = topk_output + if moe_runner_config.apply_router_weight_on_input: + assert ( + topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + x = x * topk_weights.to(x.dtype) + topk_weights = torch.ones_like( + 
topk_weights, dtype=torch.float32 + ) # topk_weights must be FP32 (float32) + output = fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=( + ActivationType.Silu + if moe_runner_config.activation == "silu" + else ActivationType.Gelu + ), + expert_mask=layer.expert_mask_gpu, + ) + _aiter_moe_succeeded = True + return StandardCombineInput(hidden_states=output) + except RuntimeError as e: + # AITER CK fused_moe may not support all GEMM dimensions + # (e.g. Gemma4 MoE with 128 experts × 704 intermediate size). + # Fall through to Triton MoE runner below. + import logging + logging.getLogger(__name__).warning_once( + f"AITER CK fused_moe failed ({e}), " + "falling back to Triton MoE runner." + ) + if not _aiter_moe_succeeded: quant_info = TritonMoeQuantInfo( w13_weight=layer.w13_weight, w2_weight=layer.w2_weight, From 8c69336aaef5da90c10566d87e52b7fb59698d58 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 31 Mar 2026 21:45:50 +0000 Subject: [PATCH 081/112] init config and param rename --- python/sglang/srt/models/gemma4_audio.py | 80 +++++++++---------- python/sglang/srt/models/gemma4_mm.py | 54 +++++++++++++ .../srt/multimodal/processors/gemma4.py | 3 +- test/manual/test_vlm_accuracy.py | 4 +- 4 files changed, 95 insertions(+), 46 deletions(-) diff --git a/python/sglang/srt/models/gemma4_audio.py b/python/sglang/srt/models/gemma4_audio.py index 6a0ae51cbc2e..1cca4a37db04 100644 --- a/python/sglang/srt/models/gemma4_audio.py +++ b/python/sglang/srt/models/gemma4_audio.py @@ -49,6 +49,12 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import add_prefix, make_layers, set_weight_attrs +# SSCP convolution constants (no longer in config.json, never varied across models) +_SSCP_INPUT_FEAT_SIZE = 128 +_SSCP_CONV_KERNEL_SIZES = ((3, 3), (3, 3)) +_SSCP_CONV_STRIDE_SIZES = ((2, 2), (2, 2)) +_SSCP_CONV_EPS = 0.001 + # 
--------------------------------------------------------------------------- # Relative Position Embedding # --------------------------------------------------------------------------- @@ -65,12 +71,12 @@ def __init__( self.config = config tp_size = get_attention_tp_size() - total_num_heads = config.conf_num_attention_heads + total_num_heads = config.num_attention_heads self.channels = config.hidden_size self.head_dim = self.channels // total_num_heads self.num_heads = total_num_heads // tp_size - self.max_backward = max(0, config.conf_attention_context_left - 1) - self.max_forward = config.conf_attention_context_right + self.max_backward = max(0, config.attention_context_left - 1) + self.max_forward = config.attention_context_right self.pos_proj = ColumnParallelLinear( self.channels, @@ -215,15 +221,15 @@ def __init__( self.config = config tp_size = get_attention_tp_size() - total_num_heads = config.conf_num_attention_heads + total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_dim = self.hidden_size // total_num_heads self.num_heads = total_num_heads // tp_size - self.chunk_size = config.conf_attention_chunk_size - self.max_future_horizon = config.conf_attention_context_right - self.max_past_horizon = max(0, config.conf_attention_context_left - 1) - self.attention_logits_soft_cap = config.conf_attention_logit_cap + self.chunk_size = config.attention_chunk_size + self.max_future_horizon = config.attention_context_right + self.max_past_horizon = max(0, config.attention_context_left - 1) + self.attention_logits_soft_cap = config.attention_logit_cap self.context_size = ( self.chunk_size + self.max_past_horizon + self.max_future_horizon ) @@ -358,7 +364,7 @@ def forward( logits = torch.where( final_condition_for_where, logits, - self.config.conf_attention_invalid_logits_value, + self.config.attention_invalid_logits_value, ) probabilities = F.softmax(logits, dim=-1, dtype=torch.float32).to( @@ -401,24 +407,16 @@ def __init__( 
super().__init__() self.config = config - in_channels = 1 if idx == 0 else config.sscp_conv_channel_size[idx - 1] - out_channels = config.sscp_conv_channel_size[idx] - kernel_t, kernel_f = config.sscp_conv_kernel_size[idx] - stride_t, stride_f = config.sscp_conv_stride_size[idx] + conv_channels = config.subsampling_conv_channels + in_channels = 1 if idx == 0 else conv_channels[idx - 1] + out_channels = conv_channels[idx] + kernel_t, kernel_f = _SSCP_CONV_KERNEL_SIZES[idx] + stride_t, stride_f = _SSCP_CONV_STRIDE_SIZES[idx] self.time_stride = stride_t - if ( - config.sscp_conv_time_pad_top is not None - and config.sscp_conv_time_pad_bottom is not None - ): - pad_t_top = config.sscp_conv_time_pad_top - pad_t_bottom = config.sscp_conv_time_pad_bottom - elif config.sscp_conv_padding_type == "semicausal": - pad_t_top = kernel_t // 2 - pad_t_bottom = 0 if config.streaming else kernel_t // 2 - else: - pad_t_top = 0 - pad_t_bottom = 0 if config.streaming else kernel_t - 1 + # Semicausal padding (hardcoded — streaming is not supported) + pad_t_top = kernel_t // 2 + pad_t_bottom = kernel_t // 2 pad_f_left = 1 pad_f_right = 1 @@ -439,7 +437,7 @@ def __init__( self.norm = nn.LayerNorm( [out_channels], - eps=config.sscp_conv_eps, + eps=_SSCP_CONV_EPS, elementwise_affine=True, bias=False, ) @@ -476,12 +474,14 @@ def __init__( super().__init__() self.config = config - current_f = config.input_feat_size + conv_channels = config.subsampling_conv_channels + + current_f = _SSCP_INPUT_FEAT_SIZE calculated_f_out_dims = [] for i in range(2): - kernel_h, kernel_w = config.sscp_conv_kernel_size[i] - stride_h, stride_w = config.sscp_conv_stride_size[i] + kernel_h, kernel_w = _SSCP_CONV_KERNEL_SIZES[i] + stride_h, stride_w = _SSCP_CONV_STRIDE_SIZES[i] pad_f_left = 1 pad_f_right = 1 @@ -492,7 +492,7 @@ def __init__( self.conv_0 = Gemma4AudioSSCPConvBlock( idx=0, - input_freq_dim=config.input_feat_size, + input_freq_dim=_SSCP_INPUT_FEAT_SIZE, config=config, ) self.conv_1 = 
Gemma4AudioSSCPConvBlock( @@ -501,7 +501,7 @@ def __init__( config=config, ) - final_c_out = config.sscp_conv_channel_size[-1] + final_c_out = conv_channels[-1] final_f_out = calculated_f_out_dims[-1] self.input_proj_in_features = final_c_out * final_f_out @@ -621,7 +621,7 @@ def __init__( prefix=add_prefix("ffw_layer_2", prefix), ) self.post_layer_norm = Gemma4RMSNorm(config.hidden_size, scale_shift=0.0) - self.post_layer_scale = config.conf_residual_weight + self.post_layer_scale = config.residual_weight def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: residual = audio_encodings @@ -648,7 +648,7 @@ def __init__( ): super().__init__() self.config = config - self.causal_padding = config.conf_conv_kernel_size - 1 + self.causal_padding = config.conv_kernel_size - 1 tp_size = get_attention_tp_size() hidden_per_tp = config.hidden_size // tp_size @@ -671,7 +671,7 @@ def __init__( self.depthwise_conv1d = nn.Conv1d( in_channels=hidden_per_tp, out_channels=hidden_per_tp, - kernel_size=config.conf_conv_kernel_size, + kernel_size=config.conv_kernel_size, stride=1, padding=0, groups=hidden_per_tp, @@ -794,7 +794,7 @@ def __init__( config, quant_config, prefix=add_prefix("subsample_conv_projection", prefix) ) self.conformer = make_layers( - config.conf_num_hidden_layers, + config.num_hidden_layers, lambda idx, prefix: Gemma4AudioConformerBlock( config=config, quant_config=quant_config, @@ -837,9 +837,9 @@ def forward( ) with torch.no_grad(): - chunk_size = self.config.conf_attention_chunk_size - max_future_horizon = self.config.conf_attention_context_right - max_past_horizon = max(0, self.config.conf_attention_context_left - 1) + chunk_size = self.config.attention_chunk_size + max_future_horizon = self.config.attention_context_right + max_past_horizon = max(0, self.config.attention_context_left - 1) upper_diagonal = max_past_horizon + max_future_horizon context_size = chunk_size + max_past_horizon + max_future_horizon @@ -861,10 +861,6 @@ def forward( for 
block in self.conformer: audio_encodings = block(audio_encodings, current_mask, causal_valid_mask) - if self.config.conf_reduction_factor > 1: - audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor] - current_mask = current_mask[:, :: self.config.conf_reduction_factor] - if self.output_proj is not None: audio_encodings, _ = self.output_proj(audio_encodings) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 7402ab34eeb8..85a801069fb8 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -596,6 +596,56 @@ def tie_weights(self, recompute_mapping=False): # Regex for fused GateUp in the vision tower MLP. _RE_TOWER_GATE_UP = re.compile(r"(.+\.mlp)\.(gate_proj|up_proj)\.(.*)") + _RE_AUDIO_LAYER = re.compile(r"(audio_tower)\.layers\.(\d+)\.(.*)") + + @staticmethod + def _remap_audio_tower_name(name: str) -> str: + """Remap audio tower checkpoint names to our module tree. + + Checkpoint naming (``layers``, ``self_attn``, ``feed_forward1/2``, etc.) + differs from our module tree (``conformer``, ``attention.attn``, + ``ffw_layer_start/end``, etc.). Applied before ``_remap_tower_name``. + """ + if "audio_tower." not in name: + return name + + # SSCP conv block: layer0/layer1 → conv_0/conv_1 + name = name.replace( + "subsample_conv_projection.layer0.", + "subsample_conv_projection.conv_0.", + ) + name = name.replace( + "subsample_conv_projection.layer1.", + "subsample_conv_projection.conv_1.", + ) + + # Conformer layers: audio_tower.layers.{i} → audio_tower.conformer.{i} + m = Gemma4ForConditionalGeneration._RE_AUDIO_LAYER.match(name) + if m: + tower, layer_idx, suffix = m.groups() + + # Order matters: more specific patterns first. 
+ # relative_k_proj → relative_position_embedding.pos_proj + suffix = suffix.replace( + "self_attn.relative_k_proj.", + "attention.attn.relative_position_embedding.pos_proj.", + ) + # self_attn.post → attention.post (the output projection) + suffix = suffix.replace("self_attn.post.", "attention.post.") + # general self_attn → attention.attn + suffix = suffix.replace("self_attn.", "attention.attn.") + # norms + suffix = suffix.replace("norm_pre_attn.", "attention.pre_attn_norm.") + suffix = suffix.replace("norm_post_attn.", "attention.post_norm.") + suffix = suffix.replace("norm_out.", "norm.") + # feed-forward blocks + suffix = suffix.replace("feed_forward1.", "ffw_layer_start.") + suffix = suffix.replace("feed_forward2.", "ffw_layer_end.") + + name = f"{tower}.conformer.{layer_idx}.{suffix}" + + return name + @staticmethod def _remap_tower_name(name: str, params_dict: dict) -> str: """Remap a vision/audio tower checkpoint name to our module tree. @@ -699,6 +749,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ): name = name.replace(".moe.", ".moe.experts.") + # Remap audio tower checkpoint names to our module tree + if "audio_tower." in name: + name = self._remap_audio_tower_name(name) + # Remap vision / audio tower names (fused QKV/GateUp, clippable wrappers) if "vision_tower." in name or "audio_tower." 
in name: name = self._remap_tower_name(name, params_dict) diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index d8b795e646ff..2b6607208056 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -61,8 +61,7 @@ def _get_audio_pad_multiple(self) -> int: """ fe = getattr(self._processor, "feature_extractor", None) hop = getattr(fe, "hop_length", 160) - ac = getattr(self.hf_config, "audio_config", None) - first_stride = ac.sscp_conv_stride_size[0][0] if ac is not None else 2 + first_stride = 2 # SSCP first conv stride (constant, not in config) return hop * first_stride def _video_decoder_to_tensor(self, vdw: VideoDecoderWrapper) -> torch.Tensor: diff --git a/test/manual/test_vlm_accuracy.py b/test/manual/test_vlm_accuracy.py index bb2cae545751..3f0ba42b1be5 100644 --- a/test/manual/test_vlm_accuracy.py +++ b/test/manual/test_vlm_accuracy.py @@ -366,7 +366,7 @@ def setUpClass(cls): cls.hf_audio_tower = hf_full.model.audio_tower.eval().to(cls.device) cls.hf_embed_audio = hf_full.model.embed_audio.eval().to(cls.device) config = AutoConfig.from_pretrained(cls.MODEL_PATH) - cls.mel_bins = config.audio_config.input_feat_size + cls.mel_bins = 128 del hf_full torch.cuda.empty_cache() @@ -570,7 +570,7 @@ def setUpClass(cls): cls.device = torch.device("cuda:0") config = AutoConfig.from_pretrained(cls.MODEL_PATH) - cls.mel_bins = config.audio_config.input_feat_size + cls.mel_bins = 128 # -- HF reference (run on GPU 0, then free) ---------------------------- from transformers import ( From 96b5c156aa380ae41c612560973e59d972cae246 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 31 Mar 2026 21:58:02 +0000 Subject: [PATCH 082/112] more config rename and per key dim scale fix --- python/sglang/srt/models/gemma4_audio.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/models/gemma4_audio.py 
b/python/sglang/srt/models/gemma4_audio.py index 1cca4a37db04..ee11312bc1c0 100644 --- a/python/sglang/srt/models/gemma4_audio.py +++ b/python/sglang/srt/models/gemma4_audio.py @@ -53,7 +53,6 @@ _SSCP_INPUT_FEAT_SIZE = 128 _SSCP_CONV_KERNEL_SIZES = ((3, 3), (3, 3)) _SSCP_CONV_STRIDE_SIZES = ((2, 2), (2, 2)) -_SSCP_CONV_EPS = 0.001 # --------------------------------------------------------------------------- # Relative Position Embedding @@ -240,7 +239,6 @@ def __init__( prefix=add_prefix("relative_position_embedding", prefix), ) self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,))) - self.per_dim_key_scale = nn.Parameter(torch.ones((self.head_dim,))) self.qkv = ClippableQKVParallelLinear( hidden_size=self.hidden_size, @@ -252,10 +250,8 @@ def __init__( prefix=prefix, ) - # softplus(0) = log(2); pre-fold into scale factors - r_softplus_0 = 1.0 / math.log(2) - self.q_scale = (self.head_dim**-0.5) * r_softplus_0 - self.k_scale = r_softplus_0 + self.q_scale = (self.head_dim**-0.5) / math.log(2) + self.k_scale = math.log(1 + math.e) / math.log(2) self.register_buffer( "softcap", @@ -315,10 +311,7 @@ def forward( query_states * self.q_scale * per_dim_scale_sp.view(broadcast_shape) ) - per_dim_key_scale_sp = F.softplus(self.per_dim_key_scale) - key_states = ( - key_states * self.k_scale * per_dim_key_scale_sp.view(broadcast_shape) - ) + key_states = key_states * self.k_scale batch_size, q_time = query_states.shape[:2] @@ -437,7 +430,7 @@ def __init__( self.norm = nn.LayerNorm( [out_channels], - eps=_SSCP_CONV_EPS, + eps=config.rms_norm_eps, elementwise_affine=True, bias=False, ) From 248a945f06dc861cd026f490be3066be38a2b473 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 31 Mar 2026 21:58:27 +0000 Subject: [PATCH 083/112] update vlm test --- test/manual/test_vlm_accuracy.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/test/manual/test_vlm_accuracy.py b/test/manual/test_vlm_accuracy.py index 
3f0ba42b1be5..f94225383335 100644 --- a/test/manual/test_vlm_accuracy.py +++ b/test/manual/test_vlm_accuracy.py @@ -341,7 +341,7 @@ class TestGemma4EncoderAccuracy(unittest.TestCase): """ MODEL_PATH = "gg-hf-gg/gemma-4-e4b-it" - COSINE_THRESHOLD = 0.99 + COSINE_THRESHOLD = 0.98 @classmethod def setUpClass(cls): @@ -447,15 +447,18 @@ def test_audio_encoder(self): ) with torch.no_grad(): - # HF: returns (encodings, mask) — does NOT zero-fill padding - hf_enc, hf_mask = self.hf_audio_tower(audio_mel, audio_mel_mask) - hf_valid_mask = ~hf_mask - hf_valid = hf_enc[hf_valid_mask.unsqueeze(-1).expand_as(hf_enc)].reshape( - -1, hf_enc.shape[-1] - ) + # HF: attention_mask convention is True=valid. + # SGLang: audio_mel_mask convention is True=padding. + hf_attention_mask = ~audio_mel_mask + hf_out = self.hf_audio_tower(audio_mel, hf_attention_mask) + hf_enc = hf_out.last_hidden_state + hf_output_mask = hf_out.attention_mask # True=valid + hf_valid = hf_enc[ + hf_output_mask.unsqueeze(-1).expand_as(hf_enc) + ].reshape(-1, hf_enc.shape[-1]) hf_projected = self.hf_embed_audio(hf_valid.unsqueeze(0)).squeeze(0) - # SGLang: returns (encodings, mask) — zero-fills padding positions + # SGLang: returns (encodings, mask) where mask True=padding sg_enc, sg_mask = self.sg_model.audio_tower(audio_mel, audio_mel_mask) sg_valid_mask = ~sg_mask sg_valid = sg_enc[sg_valid_mask.unsqueeze(-1).expand_as(sg_enc)].reshape( @@ -599,10 +602,13 @@ def setUpClass(cls): ) with torch.no_grad(): - hf_enc, hf_mask = hf_audio_tower(audio_mel, audio_mel_mask) - hf_valid_mask = ~hf_mask + # HF attention_mask: True=valid; SGLang audio_mel_mask: True=padding + hf_attention_mask = ~audio_mel_mask + hf_out = hf_audio_tower(audio_mel, hf_attention_mask) + hf_enc = hf_out.last_hidden_state + hf_output_mask = hf_out.attention_mask # True=valid cls.hf_audio_valid = ( - hf_enc[hf_valid_mask.unsqueeze(-1).expand_as(hf_enc)] + hf_enc[hf_output_mask.unsqueeze(-1).expand_as(hf_enc)] .reshape(-1, hf_enc.shape[-1]) 
.cpu() ) From 4f4a3053a00af79215c9479c38922abc2b45f562 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 31 Mar 2026 21:59:56 +0000 Subject: [PATCH 084/112] lint --- python/sglang/srt/utils/hf_transformers_utils.py | 1 - test/manual/test_vlm_accuracy.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 77cc7282f96b..1d3e96e9540c 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -644,7 +644,6 @@ def get_config( if not hasattr(text_config, "swa_v_head_dim"): text_config.swa_v_head_dim = text_config.swa_head_dim - if config.model_type == "longcat_flash": config.update({"architectures": ["LongcatFlashForCausalLM"]}) diff --git a/test/manual/test_vlm_accuracy.py b/test/manual/test_vlm_accuracy.py index f94225383335..e533e9f891c0 100644 --- a/test/manual/test_vlm_accuracy.py +++ b/test/manual/test_vlm_accuracy.py @@ -453,9 +453,9 @@ def test_audio_encoder(self): hf_out = self.hf_audio_tower(audio_mel, hf_attention_mask) hf_enc = hf_out.last_hidden_state hf_output_mask = hf_out.attention_mask # True=valid - hf_valid = hf_enc[ - hf_output_mask.unsqueeze(-1).expand_as(hf_enc) - ].reshape(-1, hf_enc.shape[-1]) + hf_valid = hf_enc[hf_output_mask.unsqueeze(-1).expand_as(hf_enc)].reshape( + -1, hf_enc.shape[-1] + ) hf_projected = self.hf_embed_audio(hf_valid.unsqueeze(0)).squeeze(0) # SGLang: returns (encodings, mask) where mask True=padding From 11ca0dcc924f52fc8e67b9e4f3e8ba3f00186c48 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 31 Mar 2026 22:03:42 +0000 Subject: [PATCH 085/112] nit constant --- python/sglang/srt/multimodal/processors/gemma4.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index 2b6607208056..e97885002342 100644 --- 
a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -21,6 +21,7 @@ BaseMultimodalProcessor as SGLangBaseProcessor, ) from sglang.srt.managers.schedule_batch import Modality +from sglang.srt.models.gemma4_audio import _SSCP_CONV_STRIDE_SIZES from sglang.srt.models.gemma4_mm import Gemma4ForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens from sglang.srt.utils.video_decoder import VideoDecoderWrapper @@ -61,7 +62,7 @@ def _get_audio_pad_multiple(self) -> int: """ fe = getattr(self._processor, "feature_extractor", None) hop = getattr(fe, "hop_length", 160) - first_stride = 2 # SSCP first conv stride (constant, not in config) + first_stride = _SSCP_CONV_STRIDE_SIZES[0][0] return hop * first_stride def _video_decoder_to_tensor(self, vdw: VideoDecoderWrapper) -> torch.Tensor: From e3930574a845f93ef63dd1596a67a1b17a9910d5 Mon Sep 17 00:00:00 2001 From: Andy Luo Date: Tue, 31 Mar 2026 22:09:36 +0000 Subject: [PATCH 086/112] [ROCm] Fix vision attention backend selection for Gemma4 select_backend() falls through to 'sdpa' on ROCm because is_cuda() returns False on HIP. SDPA with flatten_batch=True breaks multi-image and video inputs. Fix: Add is_hip() check to return 'triton_attn' on ROCm, which correctly handles batched vision inputs regardless of flatten_batch. This addresses the root cause identified in PR #13. 
--- python/sglang/srt/models/gemma4_vision.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/gemma4_vision.py b/python/sglang/srt/models/gemma4_vision.py index 816dc1538886..f0c49cbc68b8 100644 --- a/python/sglang/srt/models/gemma4_vision.py +++ b/python/sglang/srt/models/gemma4_vision.py @@ -30,7 +30,7 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.layernorm import Gemma4RMSNorm from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.utils import add_prefix, get_device_capability, is_cuda +from sglang.srt.utils import add_prefix, get_device_capability, is_cuda, is_hip # --------------------------------------------------------------------------- # 2-D Multidimensional RoPE (matches HF Gemma4RotaryEmbedding for vision) @@ -195,6 +195,10 @@ def _select_backend() -> str: return "triton_attn" return "fa3" return "triton_attn" + if is_hip(): + # ROCm: use triton_attn to avoid SDPA flatten_batch issues + # with multi-image/video inputs + return "triton_attn" return "sdpa" def forward( From 89cf215ab84e5b9d9d83f49f0f2269a68a844058 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 31 Mar 2026 23:40:22 +0000 Subject: [PATCH 087/112] fix vlm accuracy test for vision --- test/manual/test_vlm_accuracy.py | 54 +++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/test/manual/test_vlm_accuracy.py b/test/manual/test_vlm_accuracy.py index e533e9f891c0..94edaa4c535a 100644 --- a/test/manual/test_vlm_accuracy.py +++ b/test/manual/test_vlm_accuracy.py @@ -329,6 +329,30 @@ async def test_vlm_embedding_output(self): # --------------------------------------------------------------------------- +def _make_patchified_vision_inputs( + device: torch.device, + dtype: torch.dtype = torch.bfloat16, + side_patches: int = 48, + patch_size: int = 16, +) -> tuple: + """Create synthetic patchified vision inputs matching the HF 
image-processor format. + + Returns (pixel_values, pixel_position_ids) with no padding. + """ + num_patches = side_patches * side_patches + patch_pixels = 3 * patch_size**2 + pixel_values = torch.randn( + 1, num_patches, patch_pixels, device=device, dtype=dtype + ) + ys, xs = torch.meshgrid( + torch.arange(side_patches), torch.arange(side_patches), indexing="ij" + ) + pixel_position_ids = ( + torch.stack([xs.flatten(), ys.flatten()], dim=-1).unsqueeze(0).to(device) + ) + return pixel_values, pixel_position_ids + + class TestGemma4EncoderAccuracy(unittest.TestCase): """Compare Gemma 4 vision and audio encoder outputs between HF and SGLang. @@ -410,19 +434,21 @@ def _assert_cosine_close(self, hf: torch.Tensor, sg: torch.Tensor, label: str): # -- vision --------------------------------------------------------------- def test_vision_encoder(self): - """Vision tower + embed_vision should match HF on random pixels.""" - pixel_values = torch.randn( - 1, 3, 768, 768, device=self.device, dtype=torch.bfloat16 + """Vision tower + embed_vision should match HF on patchified pixels.""" + pixel_values, pixel_position_ids = _make_patchified_vision_inputs( + self.device ) with torch.no_grad(): - # HF: last_hidden_state is [1, num_real_tokens, hidden] (padding stripped) - hf_out = self.hf_vision_tower(pixel_values) - hf_tokens = hf_out.last_hidden_state.squeeze(0) + # HF: last_hidden_state contains only valid (non-padding) tokens + hf_out = self.hf_vision_tower(pixel_values, pixel_position_ids) + hf_tokens = hf_out.last_hidden_state hf_projected = self.hf_embed_vision(hf_tokens.unsqueeze(0)).squeeze(0) # SGLang: returns (pooled, pooler_mask) with mask True = valid - sg_pooled, sg_mask = self.sg_model.vision_tower(pixel_values) + sg_pooled, sg_mask = self.sg_model.vision_tower( + pixel_values, pixel_position_ids + ) sg_tokens = torch.cat([hs[m] for hs, m in zip(sg_pooled, sg_mask)]) sg_projected = self.sg_model.embed_vision(sg_tokens.unsqueeze(0)).squeeze(0) @@ -524,7 +550,7 @@ def 
_tp2_encoder_worker( 1, num_frames, mel_bins, device=device, dtype=torch.bfloat16 ) audio_mel_mask = torch.zeros(1, num_frames, device=device, dtype=torch.bool) - pixel_values = torch.randn(1, 3, 768, 768, device=device, dtype=torch.bfloat16) + pixel_values, pixel_position_ids = _make_patchified_vision_inputs(device) with torch.no_grad(): # Audio @@ -536,7 +562,9 @@ def _tp2_encoder_worker( sg_audio_proj = sg_model.embed_audio(sg_audio_valid.unsqueeze(0)).squeeze(0) # Vision - sg_vis_pooled, sg_vis_mask = sg_model.vision_tower(pixel_values) + sg_vis_pooled, sg_vis_mask = sg_model.vision_tower( + pixel_values, pixel_position_ids + ) sg_vis_tokens = torch.cat([hs[m] for hs, m in zip(sg_vis_pooled, sg_vis_mask)]) sg_vis_proj = sg_model.embed_vision(sg_vis_tokens.unsqueeze(0)).squeeze(0) @@ -597,9 +625,7 @@ def setUpClass(cls): audio_mel_mask = torch.zeros( 1, cls.NUM_FRAMES, device=cls.device, dtype=torch.bool ) - pixel_values = torch.randn( - 1, 3, 768, 768, device=cls.device, dtype=torch.bfloat16 - ) + pixel_values, pixel_position_ids = _make_patchified_vision_inputs(cls.device) with torch.no_grad(): # HF attention_mask: True=valid; SGLang audio_mel_mask: True=padding @@ -618,8 +644,8 @@ def setUpClass(cls): .cpu() ) - hf_vis_out = hf_vision_tower(pixel_values) - cls.hf_vis_tokens = hf_vis_out.last_hidden_state.squeeze(0).cpu() + hf_vis_out = hf_vision_tower(pixel_values, pixel_position_ids) + cls.hf_vis_tokens = hf_vis_out.last_hidden_state.cpu() cls.hf_vis_proj = ( hf_embed_vision(cls.hf_vis_tokens.unsqueeze(0).to(cls.device)) .squeeze(0) From 7ab05a91f5a3ba2358ddabca742ba761ec2087f6 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 31 Mar 2026 23:41:46 +0000 Subject: [PATCH 088/112] lint --- test/manual/test_vlm_accuracy.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/manual/test_vlm_accuracy.py b/test/manual/test_vlm_accuracy.py index 94edaa4c535a..e35cb17233b6 100644 --- a/test/manual/test_vlm_accuracy.py +++ 
b/test/manual/test_vlm_accuracy.py @@ -341,9 +341,7 @@ def _make_patchified_vision_inputs( """ num_patches = side_patches * side_patches patch_pixels = 3 * patch_size**2 - pixel_values = torch.randn( - 1, num_patches, patch_pixels, device=device, dtype=dtype - ) + pixel_values = torch.randn(1, num_patches, patch_pixels, device=device, dtype=dtype) ys, xs = torch.meshgrid( torch.arange(side_patches), torch.arange(side_patches), indexing="ij" ) @@ -435,9 +433,7 @@ def _assert_cosine_close(self, hf: torch.Tensor, sg: torch.Tensor, label: str): def test_vision_encoder(self): """Vision tower + embed_vision should match HF on patchified pixels.""" - pixel_values, pixel_position_ids = _make_patchified_vision_inputs( - self.device - ) + pixel_values, pixel_position_ids = _make_patchified_vision_inputs(self.device) with torch.no_grad(): # HF: last_hidden_state contains only valid (non-padding) tokens From eaf75b2893d96c9ed93815a15ba2a50f8e4e9f9c Mon Sep 17 00:00:00 2001 From: Andy Luo Date: Wed, 1 Apr 2026 01:54:39 +0000 Subject: [PATCH 089/112] Address review: narrow try/except to fused_moe() call only Per JustinTong0323's review: remove _aiter_moe_succeeded flag, narrow the try block to just the fused_moe() call, and let the Triton fallback naturally fall through. Cleaner control flow. 
--- .../sglang/srt/layers/quantization/unquant.py | 46 +++++++++---------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index f5686ff7b500..c7f6c1166045 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -458,23 +458,22 @@ def forward_cuda( # Skip aiter fused_moe when using non-auto MoE backend (e.g., triton, triton_kernels) # because aiter CK kernels don't support all GEMM dimensions _should_use_aiter_moe = _use_aiter and get_moe_runner_backend().is_auto() - _aiter_moe_succeeded = False if _should_use_aiter_moe: + assert not moe_runner_config.no_combine, "unsupported" + topk_weights, topk_ids, _ = topk_output + if moe_runner_config.apply_router_weight_on_input: + assert ( + topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + x = x * topk_weights.to(x.dtype) + topk_weights = torch.ones_like( + topk_weights, dtype=torch.float32 + ) # topk_weights must be FP32 (float32) try: - assert not moe_runner_config.no_combine, "unsupported" - topk_weights, topk_ids, _ = topk_output - if moe_runner_config.apply_router_weight_on_input: - assert ( - topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - x = x * topk_weights.to(x.dtype) - topk_weights = torch.ones_like( - topk_weights, dtype=torch.float32 - ) # topk_weights must be FP32 (float32) output = fused_moe( x, layer.w13_weight, @@ -488,7 +487,6 @@ def forward_cuda( ), expert_mask=layer.expert_mask_gpu, ) - _aiter_moe_succeeded = True return StandardCombineInput(hidden_states=output) except RuntimeError as e: # AITER CK fused_moe 
may not support all GEMM dimensions @@ -499,14 +497,14 @@ def forward_cuda( f"AITER CK fused_moe failed ({e}), " "falling back to Triton MoE runner." ) - if not _aiter_moe_succeeded: - quant_info = TritonMoeQuantInfo( - w13_weight=layer.w13_weight, - w2_weight=layer.w2_weight, - b13=getattr(layer, "w13_weight_bias", None), - b2=getattr(layer, "w2_weight_bias", None), - ) - return self.runner.run(dispatch_output, quant_info) + + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + b13=getattr(layer, "w13_weight_bias", None), + b2=getattr(layer, "w2_weight_bias", None), + ) + return self.runner.run(dispatch_output, quant_info) def forward_cpu( self, From 37fe6020a3809ec4027f9dfa3f42a0d7aa17733b Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Wed, 1 Apr 2026 02:39:35 +0000 Subject: [PATCH 090/112] nit: use module-level logger instead of inline import in except block --- python/sglang/srt/layers/quantization/unquant.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index c7f6c1166045..d8b4d4819278 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -1,7 +1,10 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, List, Optional +logger = logging.getLogger(__name__) + import torch import torch.nn.functional as F from torch.nn.parameter import Parameter @@ -492,8 +495,7 @@ def forward_cuda( # AITER CK fused_moe may not support all GEMM dimensions # (e.g. Gemma4 MoE with 128 experts × 704 intermediate size). # Fall through to Triton MoE runner below. - import logging - logging.getLogger(__name__).warning_once( + logger.warning_once( f"AITER CK fused_moe failed ({e}), " "falling back to Triton MoE runner." 
) From 5bff462f3f91e5da927a595acd6ae274bea37936 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Wed, 1 Apr 2026 19:42:57 +0000 Subject: [PATCH 091/112] fix: Gemma4 fused kernel correctness, detector robustness, and dead code cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix critical bug in Triton fused RMSNorm kernel: add missing Gemma +1 weight shift (w -> w+1.0) that caused incorrect normalization on the fused path - Replace regex-based tool call parsing with brace-balanced _extract_tool_calls and shared _find_matching_brace helper to handle nested args correctly - Narrow exception handling in gemma4_detector from bare Exception to (ValueError, IndexError, TypeError, KeyError) with full state reset on error - Cache _tool_indices with lazy init in streaming path (hasattr guard) - Fix O(n²) string slicing in parser (text[i:].startswith -> fixed-length slice) - Add .index() validation with clear error message for KV sharing layer config - Add getattr fallback for expert_intermediate_size/moe_intermediate_size compat - Remove dead code: Gemma4PerLayerEmbedding class, gemma_rmsnorm_residual function, unused manual_padding parameter, stale VocabParallelEmbedding import --- .../srt/function_call/gemma4_detector.py | 278 +- python/sglang/srt/layers/gemma4_fused_ops.py | 34 +- python/sglang/srt/models/gemma4_audio.py | 1 - python/sglang/srt/models/gemma4_causal.py | 91 +- .../test_function_call_parser.py | 3937 +++++++++++++++++ 5 files changed, 4103 insertions(+), 238 deletions(-) create mode 100644 test/registered/function_call/test_function_call_parser.py diff --git a/python/sglang/srt/function_call/gemma4_detector.py b/python/sglang/srt/function_call/gemma4_detector.py index e4d3a1916dec..2b4b9e05a16b 100644 --- a/python/sglang/srt/function_call/gemma4_detector.py +++ b/python/sglang/srt/function_call/gemma4_detector.py @@ -1,6 +1,5 @@ import json import logging -import re from typing import List, Optional from 
sglang.srt.entrypoints.openai.protocol import Tool @@ -56,7 +55,7 @@ def _parse_gemma4_array(arr_str: str) -> list: break # String element - if arr_str[i:].startswith(STRING_DELIM): + if arr_str[i : i + len(STRING_DELIM)] == STRING_DELIM: i += len(STRING_DELIM) end_pos = arr_str.find(STRING_DELIM, i) if end_pos == -1: @@ -71,7 +70,7 @@ def _parse_gemma4_array(arr_str: str) -> list: obj_start = i + 1 i += 1 while i < n and depth > 0: - if arr_str[i:].startswith(STRING_DELIM): + if arr_str[i : i + len(STRING_DELIM)] == STRING_DELIM: i += len(STRING_DELIM) next_delim = arr_str.find(STRING_DELIM, i) i = next_delim + len(STRING_DELIM) if next_delim != -1 else n @@ -144,7 +143,7 @@ def _parse_gemma4_args(args_str: str) -> dict: break # String value: <|"|>...<|"|> - if args_str[i:].startswith(STRING_DELIM): + if args_str[i : i + len(STRING_DELIM)] == STRING_DELIM: i += len(STRING_DELIM) val_start = i end_pos = args_str.find(STRING_DELIM, i) @@ -161,7 +160,7 @@ def _parse_gemma4_args(args_str: str) -> dict: obj_start = i + 1 i += 1 while i < n and depth > 0: - if args_str[i:].startswith(STRING_DELIM): + if args_str[i : i + len(STRING_DELIM)] == STRING_DELIM: # Skip over string contents i += len(STRING_DELIM) next_delim = args_str.find(STRING_DELIM, i) @@ -183,7 +182,7 @@ def _parse_gemma4_args(args_str: str) -> dict: arr_start = i + 1 i += 1 while i < n and depth > 0: - if args_str[i:].startswith(STRING_DELIM): + if args_str[i : i + len(STRING_DELIM)] == STRING_DELIM: i += len(STRING_DELIM) next_delim = args_str.find(STRING_DELIM, i) if next_delim == -1: @@ -209,21 +208,69 @@ def _parse_gemma4_args(args_str: str) -> dict: return result +def _find_matching_brace(text: str) -> int: + """Find index of matching '}' in text, respecting STRING_DELIM and nesting. + + Assumes text starts just after the opening '{'. + Returns index of closing brace, or -1 if not found (incomplete). 
+ """ + depth = 1 + i = 0 + n = len(text) + delim_len = len(STRING_DELIM) + while i < n and depth > 0: + if text[i : i + delim_len] == STRING_DELIM: + i += delim_len + next_delim = text.find(STRING_DELIM, i) + if next_delim == -1: + return -1 + i = next_delim + delim_len + continue + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + i += 1 + return (i - 1) if depth == 0 else -1 + + class Gemma4Detector(BaseFormatDetector): def __init__(self): super().__init__() self.tool_call_start_token = TOOL_CALL_START self.tool_call_end_token = TOOL_CALL_END - self.tool_call_regex = re.compile( - r"<\|tool_call>call:(\w+)\{(.*?)\}", - re.DOTALL, - ) # Streaming state self.parsed_pos: int = 0 self.is_inside_tool_call: bool = False self.current_func_name: Optional[str] = None - self.json_started: bool = False + self._tool_indices: Optional[dict] = None + + @staticmethod + def _extract_tool_calls(text: str) -> list: + """Extract (func_name, args_str) pairs using brace-balanced parsing.""" + results = [] + search_from = 0 + while True: + start = text.find(TOOL_CALL_START, search_from) + if start == -1: + break + end = text.find(TOOL_CALL_END, start) + if end == -1: + break + inner = text[start + len(TOOL_CALL_START) : end] + if inner.startswith("call:"): + brace = inner.find("{") + if brace != -1: + func_name = inner[5:brace] + args_content = inner[brace + 1 :] + match_idx = _find_matching_brace(args_content) + args_str = ( + args_content[:match_idx] if match_idx != -1 else args_content + ) + results.append((func_name, args_str)) + search_from = end + len(TOOL_CALL_END) + return results def has_tool_call(self, text: str) -> bool: return self.tool_call_start_token in text @@ -234,7 +281,7 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult calls = [] try: - matches = self.tool_call_regex.findall(text) + matches = self._extract_tool_calls(text) if not matches: return StreamingParseResult(normal_text=text) @@ -255,8 +302,8 @@ def 
detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult return StreamingParseResult(normal_text=normal_text, calls=calls) - except Exception as e: - logger.error(f"Error in detect_and_parse: {e}") + except (ValueError, IndexError, TypeError, KeyError) as e: + logger.error(f"Error in detect_and_parse: {e}", exc_info=True) return StreamingParseResult(normal_text=text) def parse_streaming_increment( @@ -269,131 +316,120 @@ def parse_streaming_increment( calls = [] normal_text_chunks = [] + if self._tool_indices is None: + self._tool_indices = self._get_tool_indices(tools) - while True: - current_slice = self._buffer[self.parsed_pos :] - if not current_slice: - break - - if not self.is_inside_tool_call: - # Step 4: Outside tool call block - next_start = current_slice.find(self.tool_call_start_token) - if next_start == -1: - # Check for partial match at the end - partial_len = self._ends_with_partial_token( - current_slice, self.tool_call_start_token - ) - if partial_len > 0: - text_to_append = current_slice[:-partial_len] - if text_to_append: - normal_text_chunks.append(text_to_append) - self.parsed_pos += len(text_to_append) - break + try: + while True: + current_slice = self._buffer[self.parsed_pos :] + if not current_slice: + break + + if not self.is_inside_tool_call: + # Outside tool call block + next_start = current_slice.find(self.tool_call_start_token) + if next_start == -1: + # Check for partial match at the end + partial_len = self._ends_with_partial_token( + current_slice, self.tool_call_start_token + ) + if partial_len > 0: + text_to_append = current_slice[:-partial_len] + if text_to_append: + normal_text_chunks.append(text_to_append) + self.parsed_pos += len(text_to_append) + break + else: + normal_text_chunks.append(current_slice) + self.parsed_pos += len(current_slice) + continue + elif next_start == 0: + self.parsed_pos += len(self.tool_call_start_token) + self.is_inside_tool_call = True + continue else: - 
normal_text_chunks.append(current_slice) - self.parsed_pos += len(current_slice) + normal_text_chunks.append(current_slice[:next_start]) + self.parsed_pos += next_start continue - elif next_start == 0: - self.parsed_pos += len(self.tool_call_start_token) - self.is_inside_tool_call = True - continue else: - normal_text_chunks.append(current_slice[:next_start]) - self.parsed_pos += next_start - continue - else: - # Inside tool call block - - # Check for TOOL_CALL_END first - if current_slice.startswith(self.tool_call_end_token): - self.parsed_pos += len(self.tool_call_end_token) - self.is_inside_tool_call = False - self.current_func_name = None - continue + # Inside tool call block - if not self.current_func_name: - # Skip leading whitespace - if current_slice[0] in (" ", "\n", "\t"): - self.parsed_pos += 1 + # Check for TOOL_CALL_END first + if current_slice.startswith(self.tool_call_end_token): + self.parsed_pos += len(self.tool_call_end_token) + self.is_inside_tool_call = False + self.current_func_name = None continue - if current_slice.startswith("call:"): - brace_pos = current_slice.find("{") - if brace_pos != -1: - func_name = current_slice[5:brace_pos] - self.current_tool_id += 1 - self.current_func_name = func_name - self.current_tool_name_sent = True - - tool_indices = self._get_tool_indices(tools) - calls.append( - ToolCallItem( - tool_index=tool_indices.get(func_name, -1), - name=func_name, - parameters="", - ) - ) - self.parsed_pos += brace_pos + 1 - continue - else: - # Incomplete call:name{ - break - else: - # Check for partial matches - if "call:".startswith( - current_slice - ) or self.tool_call_end_token.startswith(current_slice): - break - - # Unexpected content, skip - self.parsed_pos += 1 - continue - else: - # Parsing arguments (looking for balancing }) - depth = 1 - i = 0 - n = len(current_slice) - found = False - while i < n: - if current_slice[i : i + len(STRING_DELIM)] == STRING_DELIM: - i += len(STRING_DELIM) - next_delim = 
current_slice.find(STRING_DELIM, i) - if next_delim == -1: - i = n # Force wait - break - i = next_delim + len(STRING_DELIM) + if not self.current_func_name: + # Skip leading whitespace + if current_slice[0] in (" ", "\n", "\t"): + self.parsed_pos += 1 continue - if current_slice[i] == "{": - depth += 1 - elif current_slice[i] == "}": - depth -= 1 - if depth == 0: - args_str = current_slice[:i] - arguments = _parse_gemma4_args(args_str) + if current_slice.startswith("call:"): + brace_pos = current_slice.find("{") + if brace_pos != -1: + func_name = current_slice[5:brace_pos] + self.current_tool_id += 1 + self.current_func_name = func_name + self.current_tool_name_sent = True - tool_indices = self._get_tool_indices(tools) calls.append( ToolCallItem( - tool_index=tool_indices.get( - self.current_func_name, -1 - ), - parameters=json.dumps( - arguments, ensure_ascii=False + tool_index=self._tool_indices.get( + func_name, -1 ), + name=func_name, + parameters="", ) ) - self.parsed_pos += i + 1 - self.current_func_name = None # Reset for next call: - found = True + self.parsed_pos += brace_pos + 1 + continue + else: + # Incomplete call:name{ + break + else: + # Check for partial matches + if "call:".startswith( + current_slice + ) or self.tool_call_end_token.startswith(current_slice): break - i += 1 - if found: - continue + # Unexpected content, skip + self.parsed_pos += 1 + continue else: - # Incomplete arguments block - break + # Parsing arguments (looking for balancing }) + match_idx = _find_matching_brace(current_slice) + if match_idx != -1: + args_str = current_slice[:match_idx] + arguments = _parse_gemma4_args(args_str) + + calls.append( + ToolCallItem( + tool_index=self._tool_indices.get( + self.current_func_name, -1 + ), + parameters=json.dumps( + arguments, ensure_ascii=False + ), + ) + ) + self.parsed_pos += match_idx + 1 + self.current_func_name = None + continue + else: + # Incomplete arguments block + break + + except (ValueError, IndexError, TypeError, 
KeyError) as e: + logger.error(f"Error in parse_streaming_increment: {e}", exc_info=True) + # Reset parser state to prevent corruption + self.is_inside_tool_call = False + self.current_func_name = None + self._buffer = "" + self.parsed_pos = 0 if self.parsed_pos > 0: self._buffer = self._buffer[self.parsed_pos :] diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py index 14092aa16f3d..3e4bd28ae314 100644 --- a/python/sglang/srt/layers/gemma4_fused_ops.py +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -40,7 +40,7 @@ def _gemma_rmsnorm_residual_kernel( var = tl.sum(x * x, axis=0) / N rrms = tl.rsqrt(var + eps) - out = x * rrms * w + r + out = x * rrms * (w + 1.0) + r if HAS_SCALAR: scalar = tl.load(Scalar_ptr).to(tl.float32) @@ -49,38 +49,6 @@ def _gemma_rmsnorm_residual_kernel( tl.store(Out_ptr + row * stride_o + cols, out.to(x.dtype), mask=mask) -def gemma_rmsnorm_residual( - x: torch.Tensor, - weight: torch.Tensor, - residual: torch.Tensor, - eps: float = 1e-6, -) -> tuple[torch.Tensor, torch.Tensor]: - """Fused rmsnorm(x) + residual. - - Returns (output, new_residual) where new_residual = output. 
- """ - assert x.dim() == 2 and x.stride(-1) == 1, "Expected contiguous 2D input" - M, N = x.shape - BLOCK_SIZE = triton.next_power_of_2(N) - out = torch.empty_like(x) - - _gemma_rmsnorm_residual_kernel[(M,)]( - x, - weight, - residual, - None, # no scalar - out, - x.stride(0), - residual.stride(0), - out.stride(0), - N, - eps, - HAS_SCALAR=False, - BLOCK_SIZE=BLOCK_SIZE, - ) - return out, out.clone() - - def gemma_rmsnorm_residual_scalar( x: torch.Tensor, weight: torch.Tensor, diff --git a/python/sglang/srt/models/gemma4_audio.py b/python/sglang/srt/models/gemma4_audio.py index ee11312bc1c0..91dc52ae7f75 100644 --- a/python/sglang/srt/models/gemma4_audio.py +++ b/python/sglang/srt/models/gemma4_audio.py @@ -395,7 +395,6 @@ def __init__( config: Gemma4AudioConfig, idx: int, input_freq_dim: int, - manual_padding: Tuple[int, int, int, int] = (0, 0, 0, 0), ): super().__init__() self.config = config diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 5705b38f5b53..495e6642fd51 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -40,10 +40,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope -from sglang.srt.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import ( default_weight_loader, @@ -66,84 +63,6 @@ def get_attention_sliding_window_size(config): Gemma4TextScaledWordEmbedding = Gemma3TextScaledWordEmbedding -class Gemma4PerLayerEmbedding(nn.Module): - """Per-Layer Embedding (PLE) system for Gemma 4. - - Gemma 4 uses a secondary embedding stream that provides layer-specific - token embeddings. 
These are combined with the main hidden states via - a gating mechanism in each decoder layer. - - The PLE embedding stores embeddings for all layers packed together: - (vocab_size, hidden_size_per_layer_input * num_hidden_layers) - """ - - def __init__( - self, - vocab_size_per_layer_input: int, - hidden_size_per_layer_input: int, - hidden_size: int, - num_hidden_layers: int, - rms_norm_eps: float, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ) -> None: - super().__init__() - self.vocab_size = vocab_size_per_layer_input - self.hidden_size_per_layer = hidden_size_per_layer_input - self.hidden_size = hidden_size - self.num_layers = num_hidden_layers - - # Packed embedding: (vocab_size, hidden_size_per_layer * num_layers) - # We store embeddings for ALL layers together - total_embed_dim = hidden_size_per_layer_input * num_hidden_layers - self.embed_tokens_per_layer = VocabParallelEmbedding( - vocab_size_per_layer_input, - total_embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens_per_layer", - ) - - # Projection from PLE space to hidden space - # (hidden_size_per_layer * num_layers, hidden_size) - self.per_layer_model_projection = nn.Linear( - total_embed_dim, - hidden_size, - bias=False, - ) - - # Normalization for PLE output - # JAX uses scale_plus_one=False for this norm (x * scale, not x * (1+scale)) - self.per_layer_projection_norm = RMSNorm( - self.hidden_size_per_layer, - eps=rms_norm_eps, - ) - - def forward(self, input_ids: torch.Tensor) -> torch.Tensor: - """Compute per-layer embeddings and project to hidden size. 
- - Args: - input_ids: Token IDs (batch_size, seq_len) - - Returns: - Per-layer input tensor (batch_size, seq_len, hidden_size) - """ - # Get packed per-layer embeddings - per_layer_embeds = self.embed_tokens_per_layer(input_ids) - - # Apply normalization (reshape to apply per-layer, then reshape back) - # Original shape: (batch, seq, hidden_size_per_layer * num_layers) - batch_size, seq_len, _ = per_layer_embeds.shape - per_layer_embeds = per_layer_embeds.view( - batch_size, seq_len, self.num_layers, self.hidden_size_per_layer - ) - per_layer_embeds = self.per_layer_projection_norm(per_layer_embeds) - per_layer_embeds = per_layer_embeds.view(batch_size, seq_len, -1) - - # Project to hidden size - per_layer_input = self.per_layer_model_projection(per_layer_embeds) - return per_layer_input - - class Gemma4Router(nn.Module): """Router for Gemma4 MoE that preprocesses input before projection. @@ -261,7 +180,7 @@ def routing_function( num_experts=config.num_experts + get_global_server_args().ep_num_redundant_experts, hidden_size=config.hidden_size, - intermediate_size=config.expert_intermediate_size, + intermediate_size=getattr(config, "expert_intermediate_size", config.moe_intermediate_size), layer_id=layer_id, top_k=config.top_k_experts, quant_config=quant_config, @@ -375,6 +294,12 @@ def __init__( if num_kv_shared_layers > 0 and self.layer_id >= first_kv_shared_layer_idx: prev_layers = config.layer_types[:first_kv_shared_layer_idx] current_layer_type = config.layer_types[self.layer_id] + if current_layer_type not in prev_layers: + raise ValueError( + f"KV sharing layer {self.layer_id} has type '{current_layer_type}' " + f"but no matching type found in layers 0..{first_kv_shared_layer_idx - 1}. 
" + f"Available types: {set(prev_layers)}" + ) self.kv_shared_layer_index = ( len(prev_layers) - 1 - prev_layers[::-1].index(current_layer_type) ) diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py new file mode 100644 index 000000000000..172e71610bae --- /dev/null +++ b/test/registered/function_call/test_function_call_parser.py @@ -0,0 +1,3937 @@ +import json +import unittest + +from sglang.srt.entrypoints.openai.protocol import Function, Tool +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import StreamingParseResult +from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector +from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector +from sglang.srt.function_call.gemma4_detector import ( + Gemma4Detector, + _parse_gemma4_args, + _parse_gemma4_array, + _parse_gemma4_value, +) +from sglang.srt.function_call.gigachat3_detector import GigaChat3Detector +from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector +from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector +from sglang.srt.function_call.json_array_parser import JsonArrayParser +from sglang.srt.function_call.kimik2_detector import KimiK2Detector +from sglang.srt.function_call.lfm2_detector import Lfm2Detector +from sglang.srt.function_call.llama32_detector import Llama32Detector +from sglang.srt.function_call.mistral_detector import MistralDetector +from sglang.srt.function_call.pythonic_detector import PythonicDetector +from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector +from sglang.test.ci.ci_register import register_cpu_ci + +register_cpu_ci(1.0, "default") + + +class TestPythonicDetector(unittest.TestCase): + def setUp(self): + # Create sample tools for testing + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + 
description="Get weather information", + parameters={ + "properties": { + "location": { + "type": "string", + "description": "Location to get weather for", + }, + "unit": { + "type": "string", + "description": "Temperature unit", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="search", + description="Search for information", + parameters={ + "properties": { + "query": { + "type": "string", + "description": "Search query", + }, + }, + "required": ["query"], + }, + ), + ), + ] + self.detector = PythonicDetector() + + def test_parse_streaming_no_brackets(self): + """Test parsing text with no brackets (no tool calls).""" + text = "This is just normal text without any tool calls." + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, text) + self.assertEqual(result.calls, []) + self.assertEqual(self.detector._buffer, "") # Buffer should be cleared + + def test_parse_streaming_complete_tool_call(self): + """Test parsing a complete tool call.""" + text = "Here's a tool call: [get_weather(location='New York', unit='celsius')]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "Here's a tool call: ") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + self.detector._buffer, "" + ) # Buffer should be cleared after processing + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "New York") + self.assertEqual(params["unit"], "celsius") + + def test_parse_streaming_text_before_tool_call(self): + """Test parsing text that appears before a tool call.""" + text = "This is some text before [get_weather(location='London')]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "This is some 
text before ") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "London") + + def test_parse_streaming_partial_tool_call(self): + """Test parsing a partial tool call that spans multiple chunks.""" + # First chunk with opening bracket but no closing bracket + text1 = "Let me check the weather: [get_weather(location=" + result1 = self.detector.parse_streaming_increment(text1, self.tools) + + self.assertEqual(result1.normal_text, "Let me check the weather: ") + self.assertEqual(result1.calls, []) + self.assertEqual( + self.detector._buffer, "[get_weather(location=" + ) # Partial tool call remains in buffer + + # Second chunk completing the tool call + text2 = "'Paris')]" + result2 = self.detector.parse_streaming_increment(text2, self.tools) + + self.assertEqual(result2.normal_text, "") + self.assertEqual(len(result2.calls), 1) + self.assertEqual(result2.calls[0].name, "get_weather") + + # Check the parameters + params = json.loads(result2.calls[0].parameters) + self.assertEqual(params["location"], "Paris") + self.assertEqual( + self.detector._buffer, "" + ) # Buffer should be cleared after processing + + def test_parse_streaming_bracket_without_text_before(self): + """Test parsing a tool call that starts at the beginning of the text.""" + text = "[search(query='python programming')]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "search") + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["query"], "python programming") + + def test_parse_streaming_text_after_tool_call(self): + """Test parsing text that appears after a tool call.""" + # First chunk with complete tool call and some text after + text = 
"[get_weather(location='Tokyo')] Here's the forecast:" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + self.detector._buffer, " Here's the forecast:" + ) # Text after tool call remains in buffer + + # Process the remaining text in buffer + result2 = self.detector.parse_streaming_increment("", self.tools) + self.assertEqual(result2.normal_text, " Here's the forecast:") + self.assertEqual(result2.calls, []) + self.assertEqual(self.detector._buffer, "") # Buffer should be cleared + + def test_parse_streaming_multiple_tool_calls(self): + """Test parsing multiple tool calls in sequence.""" + text = "[get_weather(location='Berlin')] and [search(query='restaurants')]" + + # First tool call + result1 = self.detector.parse_streaming_increment(text, self.tools) + self.assertEqual(len(result1.calls), 1) + self.assertEqual(result1.calls[0].name, "get_weather") + self.assertEqual(self.detector._buffer, " and [search(query='restaurants')]") + + # Second tool call + result2 = self.detector.parse_streaming_increment("", self.tools) + self.assertEqual(result2.normal_text, " and ") + self.assertEqual(len(result2.calls), 1) + self.assertEqual(result2.calls[0].name, "search") + self.assertEqual(self.detector._buffer, "") + + def test_parse_streaming_opening_bracket_only(self): + """Test parsing text with only an opening bracket but no closing bracket.""" + text = "Let's try this: [" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "Let's try this: ") + self.assertEqual(result.calls, []) + self.assertEqual( + self.detector._buffer, "[" + ) # Opening bracket remains in buffer + + def test_parse_streaming_nested_brackets(self): + """Test parsing tool calls with nested brackets in arguments.""" + # Test with list argument containing nested 
brackets + text = "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(self.detector._buffer, "") + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "New York") + self.assertEqual(params["unit"], "celsius") + self.assertEqual(params["data"], [1, 2, 3]) + + def test_parse_streaming_nested_brackets_dict(self): + """Test parsing tool calls with nested dictionaries and lists.""" + # Test with nested dict and list arguments + text = "[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "search") + self.assertEqual(self.detector._buffer, "") + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["query"], "test") + self.assertEqual(params["config"]["options"], [1, 2]) + self.assertEqual(params["config"]["nested"]["key"], "value") + + def test_parse_streaming_multiple_tools_with_nested_brackets(self): + """Test parsing multiple tool calls with nested brackets.""" + text = "[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 2) + self.assertEqual(self.detector._buffer, "") + + # Check first tool call + params1 = json.loads(result.calls[0].parameters) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(params1["location"], "Paris") + self.assertEqual(params1["data"], [10, 20]) + + # Check 
second tool call + params2 = json.loads(result.calls[1].parameters) + self.assertEqual(result.calls[1].name, "search") + self.assertEqual(params2["query"], "test") + self.assertEqual(params2["filters"], ["a", "b"]) + + def test_parse_streaming_partial_nested_brackets(self): + """Test parsing partial tool calls with nested brackets across chunks.""" + # First chunk with nested brackets but incomplete + text1 = "Here's a call: [get_weather(location='Tokyo', data=[1, 2" + result1 = self.detector.parse_streaming_increment(text1, self.tools) + + self.assertEqual(result1.normal_text, "Here's a call: ") + self.assertEqual(result1.calls, []) + self.assertEqual( + self.detector._buffer, "[get_weather(location='Tokyo', data=[1, 2" + ) + + # Second chunk completing the nested brackets + text2 = ", 3])]" + result2 = self.detector.parse_streaming_increment(text2, self.tools) + + self.assertEqual(result2.normal_text, "") + self.assertEqual(len(result2.calls), 1) + self.assertEqual(result2.calls[0].name, "get_weather") + self.assertEqual(self.detector._buffer, "") + + # Check the parameters + params = json.loads(result2.calls[0].parameters) + self.assertEqual(params["location"], "Tokyo") + self.assertEqual(params["data"], [1, 2, 3]) + + def test_parse_streaming_with_python_start_and_end_token(self): + """Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks.""" + chunks = [ + "Here's a call: ", + "<|python_", + "start|>[get_weather(location=", + "'Tokyo', data=[1, 2", + ", 3])]<|python_end|>", + ] + + normal_text = "" + call_name = "" + parameters = "" + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.normal_text: + normal_text += result.normal_text + if result.calls: + call_name += result.calls[0].name + parameters += result.calls[0].parameters + + self.assertEqual(normal_text, "Here's a call: ") + self.assertEqual(call_name, "get_weather") + self.assertEqual(self.detector._buffer, "") + 
self.assertEqual( + result.normal_text, "", "Final result should have no normal text" + ) + + # Check the parameters + params = json.loads(parameters) + self.assertEqual(params["location"], "Tokyo") + self.assertEqual(params["data"], [1, 2, 3]) + + chunks = [ + "Here's a call: <|python_start|>[get_weather(location='Tokyo', data=[1, 2, 3])]<|python_end|>" + ] + + normal_text = "" + call_name = "" + parameters = "" + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.normal_text: + normal_text += result.normal_text + if result.calls: + call_name += result.calls[0].name + parameters += result.calls[0].parameters + + self.assertEqual(normal_text, "Here's a call: ") + self.assertEqual(call_name, "get_weather") + self.assertEqual(self.detector._buffer, "") + + # Check the parameters + params = json.loads(parameters) + self.assertEqual(params["location"], "Tokyo") + self.assertEqual(params["data"], [1, 2, 3]) + + def test_detect_and_parse_with_python_start_and_end_token(self): + """Test parsing a message that starts with <|python_start|> and contains a valid tool call.""" + text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars." + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual( + result.normal_text, + "User wants to get the weather in Mars. 
In this way we will get the weather in Mars.", + ) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(self.detector._buffer, "") + + # Check the parameters + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "Mars") + self.assertEqual(params["unit"], "celsius") + + +class TestMistralDetector(unittest.TestCase): + def setUp(self): + """Set up test tools and detector for Mistral format testing.""" + self.tools = [ + Tool( + type="function", + function=Function( + name="make_next_step_decision", + description="Test function for decision making", + parameters={ + "type": "object", + "properties": { + "decision": { + "type": "string", + "description": "The next step to take", + }, + "content": { + "type": "string", + "description": "The content of the next step", + }, + }, + "required": ["decision", "content"], + }, + ), + ), + ] + self.detector = MistralDetector() + + def test_detect_and_parse_with_nested_brackets_in_content(self): + """Test parsing Mistral format with nested brackets in JSON content. + + This test case specifically addresses the issue where the regex pattern + was incorrectly truncating JSON when it contained nested brackets like [City Name]. 
+ """ + # This is the exact problematic text from the original test failure + test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"","content":"```\\nTOOL: Access a weather API or service\\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\\n```"}}]' + + result = self.detector.detect_and_parse(test_text, self.tools) + + # Verify that the parsing was successful + self.assertEqual(len(result.calls), 1, "Should detect exactly one tool call") + + call = result.calls[0] + self.assertEqual( + call.name, + "make_next_step_decision", + "Should detect the correct function name", + ) + + # Verify that the parameters are valid JSON and contain the expected content + params = json.loads(call.parameters) + self.assertEqual( + params["decision"], "", "Decision parameter should be empty string" + ) + + # The content should contain the full text including the nested brackets [City Name] + expected_content = "```\nTOOL: Access a weather API or service\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\n```" + self.assertEqual( + params["content"], + expected_content, + "Content should include nested brackets without truncation", + ) + + # Verify that normal text is empty (since the entire input is a tool call) + self.assertEqual( + result.normal_text, "", "Normal text should be empty for pure tool call" + ) + + def test_detect_and_parse_simple_case(self): + """Test parsing a simple Mistral format tool call without nested brackets.""" + test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"TOOL", "content":"Use weather API"}}]' + + result = self.detector.detect_and_parse(test_text, 
self.tools) + + self.assertEqual(len(result.calls), 1) + call = result.calls[0] + self.assertEqual(call.name, "make_next_step_decision") + + params = json.loads(call.parameters) + self.assertEqual(params["decision"], "TOOL") + self.assertEqual(params["content"], "Use weather API") + + def test_detect_and_parse_no_tool_calls(self): + """Test parsing text without any tool calls.""" + test_text = "This is just normal text without any tool calls." + + result = self.detector.detect_and_parse(test_text, self.tools) + + self.assertEqual(len(result.calls), 0, "Should detect no tool calls") + self.assertEqual( + result.normal_text, + test_text, + "Should return the original text as normal text", + ) + + def test_detect_and_parse_with_text_before_tool_call(self): + """Test parsing text that has content before the tool call.""" + test_text = 'Here is some text before the tool call: [TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"ANSWER", "content":"The answer is 42"}}]' + + result = self.detector.detect_and_parse(test_text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.normal_text, "Here is some text before the tool call:") + + call = result.calls[0] + self.assertEqual(call.name, "make_next_step_decision") + + params = json.loads(call.parameters) + self.assertEqual(params["decision"], "ANSWER") + self.assertEqual(params["content"], "The answer is 42") + + def test_detect_and_parse_compact_args_format(self): + """Test parsing compact format: [TOOL_CALLS]name[ARGS]{...}.""" + test_text = '[TOOL_CALLS]make_next_step_decision[ARGS]{"decision":"TOOL", "content":"Use weather API"}' + + result = self.detector.detect_and_parse(test_text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "make_next_step_decision") + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["decision"], "TOOL") + self.assertEqual(params["content"], "Use weather API") + + def 
test_streaming_compact_args_format_emits_tool_calls(self): + """Test streaming chunks for compact format produce tool_calls items.""" + chunks = [ + "[TOOL_CALLS]make_next_step_decision[ARGS]", + '{"decision":"TOOL", ', + '"content":"Use weather API"}', + ] + + emitted = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.calls: + emitted.extend(result.calls) + + # Expect two items: name chunk + full args chunk + self.assertEqual(len(emitted), 2) + self.assertEqual(emitted[0].name, "make_next_step_decision") + self.assertEqual(emitted[0].parameters, "") + self.assertIsNone(emitted[1].name) + params = json.loads(emitted[1].parameters) + self.assertEqual(params["decision"], "TOOL") + self.assertEqual(params["content"], "Use weather API") + + +class TestBaseFormatDetector(unittest.TestCase): + """Test buffer management and sequential tool index assignment in BaseFormatDetector.""" + + def setUp(self): + """Set up test detector and tools.""" + + # Create a concrete implementation of BaseFormatDetector for testing + class TestFormatDetector(BaseFormatDetector): + def __init__(self): + super().__init__() + self.bot_token = "" + self.eot_token = "" + + def detect_and_parse(self, text, tools): + # Not used in streaming tests + pass + + def has_tool_call(self, text): + return "" in text + + def structure_info(self): + # Not used in streaming tests + pass + + self.detector = TestFormatDetector() + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_tourist_attractions", + description="Get tourist attractions", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + 
}, + "required": ["city"], + }, + ), + ), + ] + + def test_sequential_tool_index_assignment(self): + """Test that multiple tool calls get sequential tool_index values (0, 1, 2, ...).""" + # Simulate streaming chunks for two consecutive tool calls + chunks = [ + "", + '{"name": "get_weather", ', + '"arguments": {"city": "Paris"}}', + ", ", + '{"name": "get_tourist_attractions", ', + '"arguments": {"city": "London"}}', + "", + ] + + tool_indices_seen = [] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + if result.calls: + for call in result.calls: + if call.tool_index is not None: + tool_indices_seen.append(call.tool_index) + + # Verify we got sequential tool indices + unique_indices = sorted(set(tool_indices_seen)) + self.assertEqual( + unique_indices, + [0, 1], + f"Expected sequential tool indices [0, 1], got {unique_indices}", + ) + + def test_buffer_content_preservation(self): + """Test that buffer correctly preserves unprocessed content when tool completes.""" + # Test simpler scenario: tool completion followed by new tool start + chunks = [ + "", + '{"name": "get_weather", ', + '"arguments": {"city": "Paris"}}', + ", ", + '{"name": "get_tourist_attractions", ', + '"arguments": {"city": "London"}} ', + ] + + tool_calls_seen = [] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.calls: + for call in result.calls: + if ( + call.name + ): # Only count calls with names (not just parameter updates) + tool_calls_seen.append(call.name) + + # Should see both tool names + self.assertIn("get_weather", tool_calls_seen, "Should process first tool") + self.assertIn( + "get_tourist_attractions", tool_calls_seen, "Should process second tool" + ) + + def test_current_tool_id_increment_on_completion(self): + """Test that current_tool_id increments when a tool completes.""" + # Initial state + self.assertEqual( + self.detector.current_tool_id, -1, "Should start with 
current_tool_id=-1" + ) + + # Process first tool completely + chunks = [ + "", + '{"name": "get_weather", ', + ] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + self.assertEqual( + self.detector.current_tool_id, 0, "current_tool_id should be 0" + ) + self.assertEqual( + result.calls[0].name, "get_weather", "The first tool should be get_weather" + ) + self.assertEqual( + result.calls[0].tool_index, 0, "The first tool index should be 0" + ) + + # Complete second tool name - this should show that current_tool_id is now 1 + result = self.detector.parse_streaming_increment( + '"arguments": {"city": "Paris"}}, {"name": "get_', self.tools + ) + self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}') + + self.assertEqual( + self.detector.current_tool_id, + 1, + "current_tool_id should be 1 after first tool completes and second tool starts", + ) + + result = self.detector.parse_streaming_increment( + 'tourist_attractions", ', self.tools + ) + + # Second tool should have tool_index=1 + tourist_calls = [ + call for call in result.calls if call.name == "get_tourist_attractions" + ] + self.assertEqual( + tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1" + ) + + def test_tool_name_streaming_with_correct_index(self): + """Test that tool names are streamed with correct tool_index values.""" + # Process first tool + self.detector.parse_streaming_increment("", self.tools) + result1 = self.detector.parse_streaming_increment( + '{"name": "get_weather", ', self.tools + ) + + # First tool name should have tool_index=0 + weather_calls = [call for call in result1.calls if call.name == "get_weather"] + self.assertEqual(len(weather_calls), 1, "Should have one weather call") + self.assertEqual( + weather_calls[0].tool_index, 0, "First tool should have tool_index=0" + ) + + # Complete first tool + self.detector.parse_streaming_increment( + '"arguments": {"city": "Paris"}}', self.tools + ) + + # Start second tool 
+ self.detector.parse_streaming_increment(", ", self.tools) + result2 = self.detector.parse_streaming_increment( + '{"name": "get_tourist_attractions", ', self.tools + ) + + # Second tool name should have tool_index=1 + tourist_calls = [ + call for call in result2.calls if call.name == "get_tourist_attractions" + ] + self.assertEqual( + len(tourist_calls), 1, "Should have one tourist attractions call" + ) + self.assertEqual( + tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1" + ) + + def test_buffer_reset_on_invalid_tool(self): + """Test that buffer and state are reset when an invalid tool name is encountered.""" + # Start fresh with an invalid tool name from the beginning + result = self.detector.parse_streaming_increment( + '{"name": "invalid_tool", ', self.tools + ) + + # Should return empty result and reset state + self.assertEqual(result.calls, [], "Should return no calls for invalid tool") + self.assertEqual( + self.detector.current_tool_id, + -1, + "current_tool_id should remain -1 for invalid tool", + ) + self.assertEqual( + self.detector._buffer, "", "Buffer should be cleared for invalid tool" + ) + + def test_chinese_characters_not_double_escaped(self): + """Test that Chinese characters in tool call parameters are not double-escaped.""" + # Test with Chinese city name "杭州" (Hangzhou) + chunks = [ + "", + '{"name": "get_weather", ', + '"arguments": {"city": "杭州"}}', + "", + ] + + accumulated_parameters = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.calls: + for call in result.calls: + if call.parameters: + tool_idx = call.tool_index if call.tool_index is not None else 0 + if tool_idx not in accumulated_parameters: + accumulated_parameters[tool_idx] = "" + accumulated_parameters[tool_idx] += call.parameters + + # Verify that Chinese characters are preserved (not escaped as \uXXXX) + self.assertGreater( + len(accumulated_parameters), 0, "Should have parsed parameters" + ) + 
final_params_str = accumulated_parameters[0] + + # The parameters string should contain the actual Chinese characters, not escaped Unicode + self.assertIn( + "杭州", final_params_str, "Should contain actual Chinese characters" + ) + self.assertNotIn( + "\\u676d", final_params_str, "Should not contain escaped Unicode sequences" + ) + self.assertNotIn( + "\\u5dde", final_params_str, "Should not contain escaped Unicode sequences" + ) + + # Verify the JSON can be parsed and contains the correct value + params = json.loads(final_params_str) + self.assertEqual( + params["city"], "杭州", "Should correctly parse Chinese city name" + ) + + def test_chinese_characters_incremental_streaming(self): + """Test that Chinese characters work correctly with incremental streaming.""" + # Test incremental streaming with Chinese characters + chunks = [ + "", + '{"name": "get_weather", ', + '"arguments": {"city": "', + "杭州", + '"}}', + "", + ] + + accumulated_parameters = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.calls: + for call in result.calls: + if call.parameters: + tool_idx = call.tool_index if call.tool_index is not None else 0 + if tool_idx not in accumulated_parameters: + accumulated_parameters[tool_idx] = "" + accumulated_parameters[tool_idx] += call.parameters + + # Verify Chinese characters are preserved throughout streaming + self.assertGreater( + len(accumulated_parameters), 0, "Should have parsed parameters" + ) + final_params_str = accumulated_parameters[0] + + # Should contain actual Chinese characters, not escaped + self.assertIn( + "杭州", final_params_str, "Should contain actual Chinese characters" + ) + + # Parse and verify + params = json.loads(final_params_str) + self.assertEqual( + params["city"], "杭州", "Should correctly parse Chinese city name" + ) + + def test_multiple_chinese_parameters(self): + """Test multiple tool calls with Chinese parameters.""" + # Test with multiple tool calls containing 
Chinese characters + chunks = [ + "", + '{"name": "get_weather", "arguments": {"city": "北京"}}, ', + '{"name": "get_tourist_attractions", "arguments": {"city": "上海"}}', + "", + ] + + accumulated_parameters = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.calls: + for call in result.calls: + if call.parameters: + tool_idx = call.tool_index if call.tool_index is not None else 0 + if tool_idx not in accumulated_parameters: + accumulated_parameters[tool_idx] = "" + accumulated_parameters[tool_idx] += call.parameters + + # Verify both tool calls have correct Chinese characters + self.assertGreaterEqual( + len(accumulated_parameters), 1, "Should have parsed parameters" + ) + + # Check first tool call (北京 - Beijing) + if 0 in accumulated_parameters: + params0 = json.loads(accumulated_parameters[0]) + self.assertIn( + "北京", + accumulated_parameters[0], + "Should contain actual Chinese characters", + ) + self.assertEqual( + params0["city"], "北京", "Should correctly parse first Chinese city" + ) + + # Check second tool call (上海 - Shanghai) if present + if 1 in accumulated_parameters: + params1 = json.loads(accumulated_parameters[1]) + self.assertIn( + "上海", + accumulated_parameters[1], + "Should contain actual Chinese characters", + ) + self.assertEqual( + params1["city"], "上海", "Should correctly parse second Chinese city" + ) + + +class TestLlama32Detector(unittest.TestCase): + def setUp(self): + """Set up test tools and detector for Mistral format testing.""" + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_tourist_attractions", + description="Get tourist attractions", + parameters={ + "type": "object", + "properties": 
{ + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + ] + self.detector = Llama32Detector() + + def test_single_json(self): + text = '{"name": "get_weather", "parameters": {"city": "Paris"}}' + result = self.detector.detect_and_parse(text, self.tools) + assert len(result.calls) == 1 + assert result.calls[0].name == "get_weather" + assert result.normal_text == "" + + def test_multiple_json_with_separator(self): + text = ( + '<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};' + '{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}' + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[1].name, "get_tourist_attractions") + self.assertEqual(result.normal_text, "") + + def test_multiple_json_with_separator_customized(self): + text = ( + '<|python_tag|>{"name": "get_weather", "parameters": {}}' + '<|python_tag|>{"name": "get_tourist_attractions", "parameters": {}}' + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[1].name, "get_tourist_attractions") + self.assertEqual(result.normal_text, "") + + def test_json_with_trailing_text(self): + text = '{"name": "get_weather", "parameters": {}} Some follow-up text' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertIn("follow-up", result.normal_text) + + def test_invalid_then_valid_json(self): + text = ( + '{"name": "get_weather", "parameters": {' # malformed + '{"name": "get_weather", "parameters": {}}' + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + def test_plain_text_only(self): + text = "This is just plain explanation text." 
+ result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(result.calls, []) + self.assertEqual(result.normal_text, text) + + def test_with_python_tag_prefix(self): + text = 'Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertTrue(result.normal_text.strip().startswith("Some intro.")) + + +class TestKimiK2Detector(unittest.TestCase): + + def setUp(self): + """Set up test tools and detector.""" + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_tourist_attractions", + description="Get tourist attractions", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + ] + self.detector = KimiK2Detector() + + def test_single_tool_call(self): + """Test parsing a single tool call in a complete text.""" + text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}') + self.assertEqual(result.normal_text, "") + + def test_multiple_tool_calls(self): + """Test parsing multiple tool calls in a complete text.""" + text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": 
"Paris"}<|tool_call_end|><|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{"city": "London"}<|tool_call_end|><|tool_calls_section_end|>' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}') + self.assertEqual(result.calls[1].name, "get_tourist_attractions") + self.assertEqual(result.calls[1].parameters, '{"city": "London"}') + self.assertEqual(result.normal_text, "") + + def test_streaming_tool_call(self): + """Test streaming incremental parsing of a tool call.""" + chunks = [ + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", + '"city": "Paris"', + "}", + "<|tool_call_end|><|tool_calls_section_end|>", + ] + + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if tool_call_chunk.tool_index is not None: + + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + + tc = tool_calls[tool_call_chunk.tool_index] + + if tool_call_chunk.name: + tc["name"] += tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}') + + def test_streaming_multiple_tool_calls(self): + """Test streaming incremental parsing of multiple tool calls.""" + chunks = [ + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", + '"city": "Paris"', + "}<|tool_call_end|>", + "<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{", + '"city": "London"', + "}<|tool_call_end|>", + "<|tool_calls_section_end|>", + ] + + 
tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if tool_call_chunk.tool_index is not None: + + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + + tc = tool_calls[tool_call_chunk.tool_index] + + if tool_call_chunk.name: + tc["name"] += tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}') + self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions") + self.assertEqual(tool_calls[1]["parameters"], '{"city": "London"}') + + def test_tool_call_completion(self): + """Test that the buffer and state are reset after a tool call is completed.""" + chunks = [ + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", + '"city": "Paris"', + "}", + "<|tool_call_end|>", + "<|tool_calls_section_end|>", + ] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + # After processing all chunks, the buffer should be empty and current_tool_id should be reset + self.assertEqual(self.detector._buffer, "") + self.assertEqual(self.detector.current_tool_id, 1) + + def test_tool_name_streaming(self): + """Test that tool names are streamed correctly with the right index.""" + chunks = [ + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", + '"city": "Paris"', + "}", + "<|tool_call_end|>", + "<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{", + ] + + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if tool_call_chunk.tool_index is not None: + + 
while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + + tc = tool_calls[tool_call_chunk.tool_index] + + if tool_call_chunk.name: + tc["name"] += tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}') + self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions") + + def test_invalid_tool_call(self): + """Test that invalid tool calls are handled correctly.""" + text = 'invalid_tool:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 0) + self.assertEqual(result.normal_text, text) + + def test_partial_tool_call(self): + """Test that partial tool calls are handled correctly in streaming mode.""" + chunks = [ + "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", + '"city": "Paris"', + ] + + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if tool_call_chunk.tool_index is not None: + + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + + tc = tool_calls[tool_call_chunk.tool_index] + + if tool_call_chunk.name: + tc["name"] += tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"') + + +class TestDeepSeekV3Detector(unittest.TestCase): + def setUp(self): + """Set up test tools and detector for DeepSeekV3 format testing.""" + self.tools = [ + Tool( + 
type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_tourist_attractions", + description="Get tourist attractions", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + ] + self.detector = DeepSeekV3Detector() + + def test_parse_streaming_multiple_tool_calls_with_multi_token_chunk(self): + """Test parsing multiple tool calls when streaming chunks contains multi-tokens (e.g. DeepSeekV3 enable MTP)""" + # Simulate streaming chunks with multi-tokens for two consecutive tool calls + chunks = [ + "<|tool▁calls▁begin|>", + "<|tool▁call▁begin|>function", + "<|tool▁sep|>get", + "_weather\n", + "```json\n", + '{"city":', + '"Shanghai', + '"}\n```<|tool▁call▁end|>', + "\n<|tool▁call▁begin|>", + "function<|tool▁sep|>", + "get_tour", + "ist_att", + "ractions\n```" 'json\n{"', + 'city": "', + 'Beijing"}\n', + "```<|tool▁call▁end|>", + "<|tool▁calls▁end|>", + ] + + tool_calls_seen = [] + tool_calls_parameters = [] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.calls: + for call in result.calls: + if call.name: + tool_calls_seen.append(call.name) + if call.parameters: + tool_calls_parameters.append(call.parameters) + + # Should see both tool names + self.assertIn("get_weather", tool_calls_seen, "Should process first tool") + self.assertIn( + "get_tourist_attractions", tool_calls_seen, "Should process second tool" + ) + + # Verify that the parameters are valid JSON and contain the expected content + params1 = json.loads(tool_calls_parameters[0]) + params2 = json.loads(tool_calls_parameters[1]) + self.assertEqual(params1["city"], "Shanghai") + 
self.assertEqual(params2["city"], "Beijing") + + +class TestDeepSeekV32Detector(unittest.TestCase): + def setUp(self): + """Set up test tools and detector for DeepSeekV32 format testing.""" + self.tools = [ + Tool( + type="function", + function=Function( + name="search", + description="Searches for information related to query and displays topn results.", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query string", + }, + "topn": { + "type": "integer", + "description": "Number of top results to display", + "default": 10, + }, + "source": { + "type": "string", + "description": "Source to search within", + "enum": ["web", "news"], + "default": "web", + }, + }, + "required": ["query"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_favorite_tourist_spot", + description="Return the favorite tourist spot for a given city.", + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ), + ), + ] + self.detector = DeepSeekV32Detector() + from transformers import AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.2") + self.interval = 1 + + def test_detect_and_parse_xml_format(self): + """Test parsing standard XML format (DSML)""" + text = """I'll help you with information about San Francisco and get its favorite tourist spot for you.\n\n + <|DSML|function_calls>\n + <|DSML|invoke name="get_favorite_tourist_spot">\n + <|DSML|parameter name="city" string="true">San Francisco\n + \n + <|DSML|invoke name="search"> + <|DSML|parameter name="query" string="true">WebNav benchmark + <|DSML|parameter name="topn" string="false">10 + <|DSML|parameter name="source" string="true">web + + + """ + result = self.detector.detect_and_parse(text, self.tools) + + self.assertIn("I'll help you with information", result.normal_text) + self.assertEqual(len(result.calls), 2) + + # Check first call + call1 = 
result.calls[0] + self.assertEqual(call1.name, "get_favorite_tourist_spot") + params1 = json.loads(call1.parameters) + self.assertEqual(params1["city"], "San Francisco") + + # Check second call + call2 = result.calls[1] + self.assertEqual(call2.name, "search") + params2 = json.loads(call2.parameters) + self.assertEqual(params2["query"], "WebNav benchmark") + self.assertEqual(params2["topn"], 10) + self.assertEqual(params2["source"], "web") + + def test_detect_and_parse_json_format(self): + """Test parsing JSON format inside invoke tags""" + text = """I'll help you with information about San Francisco and get its favorite tourist spot for you. + + <|DSML|function_calls> + <|DSML|invoke name="get_favorite_tourist_spot"> + { + "city": "San Francisco" + } + + <|DSML|invoke name="search"> + { + "query": "WebNav benchmark", + "topn": 10, + "source": "web" + } + + + """ + result = self.detector.detect_and_parse(text, self.tools) + + self.assertIn("I'll help you with information", result.normal_text) + self.assertEqual(len(result.calls), 2) + + # Check first call + call1 = result.calls[0] + self.assertEqual(call1.name, "get_favorite_tourist_spot") + params1 = json.loads(call1.parameters) + self.assertEqual(params1["city"], "San Francisco") + + # Check second call + call2 = result.calls[1] + self.assertEqual(call2.name, "search") + params2 = json.loads(call2.parameters) + self.assertEqual(params2["query"], "WebNav benchmark") + self.assertEqual(params2["topn"], 10) + self.assertEqual(params2["source"], "web") + + def test_streaming_xml_format(self): + """Test streaming parsing of XML format""" + text = """<|DSML|function_calls> + <|DSML|invoke name="get_favorite_tourist_spot"> + <|DSML|parameter name="city" string="true">San Francisco + <|DSML|parameter name="another_city" string="true">London + <|DSML|parameter name="topn" string="false">10 + <|DSML|parameter name="obj" string="false">{"name": "John", "age": 30} + + """ + + input_ids = self.tokenizer.encode(text, 
add_special_tokens=False) + chunk_ids = [ + input_ids[i : i + self.interval] + for i in range(0, len(input_ids), self.interval) + ] + chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] + + tool_calls_by_index = {} + + num_tool_call_chunks = 0 + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for call in result.calls: + num_tool_call_chunks += 1 + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertGreater(num_tool_call_chunks, 8) + + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "get_favorite_tourist_spot") + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params["city"], "San Francisco") + self.assertEqual(params["another_city"], "London") + self.assertEqual(params["topn"], 10) + self.assertEqual(params["obj"]["name"], "John") + self.assertEqual(params["obj"]["age"], 30) + + def test_streaming_json_format(self): + """Test streaming parsing of JSON format""" + text = """<|DSML|function_calls> + <|DSML|invoke name="get_favorite_tourist_spot"> + { + "city": "San Francisco", + "another_city": "London", + "topn": 10, + "obj": { + "name": "John", + "age": 30 + } + } + + """ + + input_ids = self.tokenizer.encode(text, add_special_tokens=False) + chunk_ids = [ + input_ids[i : i + self.interval] + for i in range(0, len(input_ids), self.interval) + ] + chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] + + tool_calls_by_index = {} + + num_tool_call_chunks = 0 + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for call in result.calls: + num_tool_call_chunks += 1 + if 
call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertGreater(num_tool_call_chunks, 8) + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "get_favorite_tourist_spot") + + # Clean up parameters string if needed (trim whitespace) + params_str = tool_calls_by_index[0]["parameters"].strip() + params = json.loads(params_str) + self.assertEqual(params["city"], "San Francisco") + + def test_detect_and_parse_no_parameters(self): + """Test parsing function calls with no parameters (non-streaming)""" + # Add a no-parameter tool + tools_with_no_param = self.tools + [ + Tool( + type="function", + function=Function( + name="get_date", + description="Get the current date.", + parameters={"type": "object", "properties": {}}, + ), + ), + ] + + text = """Let me get the current date for you. + +<|DSML|function_calls> +<|DSML|invoke name="get_date"> + +""" + + result = self.detector.detect_and_parse(text, tools_with_no_param) + + self.assertIn("Let me get the current date", result.normal_text) + self.assertEqual(len(result.calls), 1) + + call = result.calls[0] + self.assertEqual(call.name, "get_date") + params = json.loads(call.parameters) + self.assertEqual(params, {}) + + def test_streaming_no_parameters(self): + """Test streaming parsing of function calls with no parameters. + + This test verifies the fix for the bug where functions with no parameters + were being silently skipped in streaming mode. 
+ """ + # Add a no-parameter tool + tools_with_no_param = self.tools + [ + Tool( + type="function", + function=Function( + name="get_date", + description="Get the current date.", + parameters={"type": "object", "properties": {}}, + ), + ), + ] + + text = """<|DSML|function_calls> +<|DSML|invoke name="get_date"> + +""" + + # Reset detector state + self.detector = DeepSeekV32Detector() + + # Simulate streaming by splitting into small chunks + input_ids = self.tokenizer.encode(text, add_special_tokens=False) + chunk_ids = [ + input_ids[i : i + self.interval] + for i in range(0, len(input_ids), self.interval) + ] + chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] + + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, tools_with_no_param) + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + # Verify that the no-parameter function was correctly parsed + self.assertEqual( + len(tool_calls_by_index), 1, "Should have exactly one tool call" + ) + self.assertEqual(tool_calls_by_index[0]["name"], "get_date") + + # Parameters should be empty JSON object + params_str = tool_calls_by_index[0]["parameters"].strip() + params = json.loads(params_str) + self.assertEqual(params, {}) + + def test_streaming_no_parameters_with_whitespace(self): + """Test streaming parsing when invoke content has only whitespace (newlines).""" + tools_with_no_param = self.tools + [ + Tool( + type="function", + function=Function( + name="get_date", + description="Get the current date.", + parameters={"type": "object", "properties": {}}, + ), + ), + ] + + # This format has newlines inside the invoke tag (common model 
output) + text = """<|DSML|function_calls> +<|DSML|invoke name="get_date"> + + +""" + + # Reset detector state + self.detector = DeepSeekV32Detector() + + input_ids = self.tokenizer.encode(text, add_special_tokens=False) + chunk_ids = [ + input_ids[i : i + self.interval] + for i in range(0, len(input_ids), self.interval) + ] + chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] + + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, tools_with_no_param) + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + # Should still parse correctly even with whitespace-only content + self.assertEqual( + len(tool_calls_by_index), 1, "Should have exactly one tool call" + ) + self.assertEqual(tool_calls_by_index[0]["name"], "get_date") + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params, {}) + + +class TestQwen3CoderDetector(unittest.TestCase): + """Test suite for Qwen3CoderDetector.""" + + def setUp(self): + """Initialize test fixtures before each test method.""" + self.tools = [ + Tool( + type="function", + function=Function( + name="get_current_weather", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + "days": {"type": "integer"}, + }, + "required": ["location"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="sql_interpreter", + parameters={ + "type": "object", + "properties": { + "query": {"type": "string"}, + "dry_run": {"type": "boolean"}, + }, + }, + ), + ), + Tool( + type="function", + function=Function( + 
name="TodoWrite", + parameters={ + "type": "object", + "properties": { + "todos": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": {"type": "string"}, + "status": {"type": "string"}, + }, + "required": ["content", "status"], + }, + }, + }, + }, + ), + ), + ] + self.detector = Qwen3CoderDetector() + + # ==================== Basic Functionality Tests ==================== + + def test_plain_text_only(self): + """ + Test parsing of plain text without any tool calls. + + Scenario: Input contains only plain text, no tool call markers. + Purpose: Verify that plain text is correctly identified and no false tool calls are detected. + """ + text = "This is plain text without any tool calls." + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, text) + self.assertEqual(len(result.calls), 0) + + def test_single_tool_call(self): + """ + Test parsing of a single tool call. + + Scenario: Input contains one complete tool call with parameters. + Purpose: Verify correct extraction of tool name and parameters. + """ + text = """ + +Boston +celsius +3 + +""" + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_current_weather") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "Boston") + self.assertEqual(params["unit"], "celsius") + self.assertEqual(params["days"], 3) + + def test_single_tool_call_with_text_prefix(self): + """ + Test parsing of tool call with preceding text. + + Scenario: Input has plain text followed by a tool call. + Purpose: Verify correct separation of text and tool call. + """ + text = """Let me check the weather for you. 
+ + + +New York + +""" + result = self.detector.detect_and_parse(text, self.tools) + + self.assertTrue(result.normal_text.startswith("Let me check")) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_current_weather") + + def test_multiple_tool_calls(self): + """ + Test parsing of multiple consecutive tool calls. + + Scenario: Input contains two tool calls one after another. + Purpose: Verify that multiple tool calls are correctly identified and parsed. + """ + text = """ + +New York + + + + +SELECT * FROM users +True + +""" + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_current_weather") + self.assertEqual(result.calls[1].name, "sql_interpreter") + + params1 = json.loads(result.calls[0].parameters) + self.assertEqual(params1["location"], "New York") + + params2 = json.loads(result.calls[1].parameters) + self.assertEqual(params2["query"], "SELECT * FROM users") + self.assertEqual(params2["dry_run"], True) + + # ==================== Streaming Tests ==================== + + def test_streaming_single_tool_call(self): + """ + Test streaming parsing of a single tool call. + + Scenario: Tool call is fed incrementally in chunks. + Purpose: Verify streaming parser correctly assembles tool call from chunks. 
+ """ + chunks = [ + "", + "", + "", + "Boston", + "", + "celsius", + "", + "", + ] + + detector = Qwen3CoderDetector() + all_calls = [] + collected_params = "" + + for chunk in chunks: + result = detector.parse_streaming_increment(chunk, self.tools) + all_calls.extend(result.calls) + for call in result.calls: + if call.parameters: + collected_params += call.parameters + + # Verify we got the tool call + self.assertGreater(len(all_calls), 0) + + # Verify parameters were collected + if collected_params: + params = json.loads(collected_params) + self.assertEqual(params["location"], "Boston") + self.assertEqual(params["unit"], "celsius") + + def test_streaming_with_text_and_tool(self): + """ + Test streaming parsing with mixed text and tool call. + + Scenario: Stream contains plain text followed by a tool call. + Purpose: Verify correct separation in streaming mode. + """ + chunks = [ + "Let me ", + "help you.\n\n", + "", + "", + "Paris", + "", + "", + ] + + detector = Qwen3CoderDetector() + full_text = "" + all_calls = [] + + for chunk in chunks: + result = detector.parse_streaming_increment(chunk, self.tools) + if result.normal_text: + full_text += result.normal_text + all_calls.extend(result.calls) + + self.assertTrue(full_text.startswith("Let me")) + self.assertGreater(len(all_calls), 0) + + # ==================== Parameter Type Tests ==================== + + def test_integer_parameter_conversion(self): + """ + Test correct type conversion for integer parameters. + + Scenario: Tool call with integer parameter. + Purpose: Verify integer values are correctly parsed and typed. + """ + text = """ + +Tokyo +5 + +""" + result = self.detector.detect_and_parse(text, self.tools) + + params = json.loads(result.calls[0].parameters) + self.assertIsInstance(params["days"], int) + self.assertEqual(params["days"], 5) + + def test_boolean_parameter_conversion(self): + """ + Test correct type conversion for boolean parameters. + + Scenario: Tool call with boolean parameter. 
+ Purpose: Verify boolean values are correctly parsed. + """ + text = """ + +SELECT 1 +True + +""" + result = self.detector.detect_and_parse(text, self.tools) + + params = json.loads(result.calls[0].parameters) + self.assertIsInstance(params["dry_run"], bool) + self.assertEqual(params["dry_run"], True) + + def test_complex_array_parameter(self): + """ + Test parsing of complex array parameters. + + Scenario: Tool call with array of objects as parameter. + Purpose: Verify complex nested structures are correctly parsed. + """ + text = """ + + +[ + {"content": "Buy groceries", "status": "pending"}, + {"content": "Finish report", "status": "completed"} +] + + +""" + result = self.detector.detect_and_parse(text, self.tools) + + params = json.loads(result.calls[0].parameters) + self.assertIsInstance(params["todos"], list) + self.assertEqual(len(params["todos"]), 2) + self.assertEqual(params["todos"][0]["content"], "Buy groceries") + self.assertEqual(params["todos"][1]["status"], "completed") + + # ==================== Edge Cases ==================== + + def test_empty_parameter_value(self): + """ + Test handling of empty parameter values. + + Scenario: Tool call with empty parameter value. + Purpose: Verify empty values are handled gracefully. + """ + text = """ + + + +""" + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "") + + def test_parameter_with_special_characters(self): + """ + Test handling of parameters with special characters. + + Scenario: Parameter value contains special characters like quotes, newlines. + Purpose: Verify special characters are correctly preserved. 
+ """ + text = """ + +SELECT * FROM users WHERE name = 'John "Doe"' + +""" + result = self.detector.detect_and_parse(text, self.tools) + + params = json.loads(result.calls[0].parameters) + self.assertIn("John", params["query"]) + self.assertIn("Doe", params["query"]) + + def test_incomplete_tool_call(self): + """ + Test handling of incomplete tool call at end of stream. + + Scenario: Stream ends with an incomplete tool call (missing closing tag). + Purpose: Verify detector handles incomplete input gracefully without crashing. + """ + text = """ + +London""" + + # Should not crash + result = self.detector.detect_and_parse(text, self.tools) + self.assertIsInstance(result, StreamingParseResult) + + def test_has_tool_call_detection(self): + """ + Test the has_tool_call method for detecting tool call markers. + + Scenario: Various inputs with and without tool call markers. + Purpose: Verify correct detection of tool call presence. + """ + self.assertTrue(self.detector.has_tool_call("")) + self.assertTrue(self.detector.has_tool_call("text more")) + self.assertFalse(self.detector.has_tool_call("plain text only")) + self.assertFalse(self.detector.has_tool_call("")) + + +class TestGlm4MoeDetector(unittest.TestCase): + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date"}, + }, + "required": ["city", "date"], + }, + ), + ), + ] + self.detector = Glm4MoeDetector() + + def test_single_tool_call(self): + text = ( + "get_weather\n" + "city\nBeijing\n" + "date\n2024-06-27\n" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' 
+ ) + self.assertEqual(result.normal_text, "") + + def test_multiple_tool_calls(self): + text = ( + "get_weather\n" + "city\nBeijing\n" + "date\n2024-06-27\n" + "" + "get_weather\n" + "city\nShanghai\n" + "date\n2024-06-28\n" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(result.calls[1].name, "get_weather") + self.assertEqual( + result.calls[1].parameters, '{"city": "Shanghai", "date": "2024-06-28"}' + ) + self.assertEqual(result.normal_text, "") + + def test_streaming_tool_call(self): + """Test streaming incremental parsing of a tool call.""" + chunks = [ + "get_weather\n", + "city\nBeijing\n", + "date\n2024-06-27\n", + "", + ] + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + + def test_streaming_multiple_tool_calls(self): + """Test streaming incremental parsing of multiple tool calls.""" + chunks = [ + "get_weather\n", + "city\nBeijing\n", + "date\n2024-06-27\n", + "get_weather\n", + "city\nShanghai\n", + "date\n2024-06-28\n", + "", + ] + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + 
for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(tool_calls[1]["name"], "get_weather") + self.assertEqual( + tool_calls[1]["parameters"], '{"city": "Shanghai", "date": "2024-06-28"}' + ) + + def test_tool_call_id(self): + """Test that the buffer and state are reset after a tool call is completed.""" + chunks = [ + "get_weather\n", + "city\nBeijing\n", + "date\n2024-06-27\n", + "", + ] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + self.assertEqual(self.detector.current_tool_id, 1) + + def test_invalid_tool_call(self): + """Test that invalid tool calls are handled correctly.""" + text = "invalid_func\ncity\nBeijing\n" + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 0) + + def test_partial_tool_call(self): + """Test parsing a partial tool call that spans multiple chunks.""" + chunks = [ + "get_weather\n", + "city\nBeijing\n", + "date\n2024-06-27\n", + ] + + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + 
tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + + def test_array_argument_with_escaped_json(self): + """Test that array arguments with escaped JSON are properly handled without double-escaping.""" + # Add a tool with array parameter + tools_with_array = [ + Tool( + type="function", + function=Function( + name="todo_write", + description="Write todos", + parameters={ + "type": "object", + "properties": { + "todos": { + "type": "array", + "description": "The updated todo list", + } + }, + "required": ["todos"], + }, + ), + ), + ] + + def check_params(result): + self.assertEqual(1, len(result.calls)) + self.assertEqual("todo_write", result.calls[0].name) + params = json.loads(result.calls[0].parameters) + self.assertIsInstance(params["todos"], list) + self.assertEqual(4, len(params["todos"])) + self.assertEqual("1", params["todos"][0]["id"]) + self.assertEqual( + "Check for hard-coded issues in the backend code", + params["todos"][0]["task"], + ) + self.assertEqual("in_progress", params["todos"][0]["status"]) + self.assertEqual("2", params["todos"][1]["id"]) + self.assertEqual( + "Check for hard-coded issues in the frontend code", + params["todos"][1]["task"], + ) + self.assertEqual("pending", params["todos"][1]["status"]) + self.assertEqual("3", params["todos"][2]["id"]) + self.assertEqual( + "Check for code violating the Single Responsibility Principle", + params["todos"][2]["task"], + ) + self.assertEqual("pending", params["todos"][2]["status"]) + self.assertEqual("4", params["todos"][3]["id"]) + self.assertEqual( + "Generate a rectification proposal report", params["todos"][3]["task"] + ) + self.assertEqual("pending", params["todos"][3]["status"]) + + # Simulate the raw response from GLM-4.6 model with 
normal and escaped JSON in XML + result = self.detector.detect_and_parse( + """todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] +""", + tools_with_array, + ) + check_params(result) + result = self.detector.detect_and_parse( + r"""todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] +""", + tools_with_array, + ) + check_params(result) + + def check_single_todos(tool_result, expected): + self.assertEqual(1, len(tool_result.calls)) + self.assertEqual("todo_write", tool_result.calls[0].name) + params = json.loads(tool_result.calls[0].parameters) + self.assertIsInstance(params["todos"], list) + self.assertEqual(1, len(params["todos"])) + self.assertEqual("1", params["todos"][0]["id"]) + self.assertEqual(expected, params["todos"][0]["task"]) + self.assertEqual("pending", params["todos"][0]["status"]) + + # Test with escaped JSON containing backslashes in content (e.g., Windows paths) + expected_path = r"Check file at C:\Users\test.txt" + result = self.detector.detect_and_parse( + """todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_path) + result = 
self.detector.detect_and_parse( + r"""todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_path) + + # Should contain literal \n, not actual newline + expected_output = r"Print \n to see newline" + result = self.detector.detect_and_parse( + """todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_output) + result = self.detector.detect_and_parse( + r"""todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_output) + + def test_empty_function_name_handling(self): + """Test that empty function name is handled gracefully without assertion error.""" + # This test simulates the issue where the model outputs only the start token without a function name + chunks = [ + "", # Start token only, no function name yet + "\n", # More content without function name + ] + + for chunk in chunks: + # Should not raise AssertionError: func_name should not be empty + result = self.detector.parse_streaming_increment(chunk, self.tools) + # Should return empty calls without error + self.assertIsInstance(result, StreamingParseResult) + self.assertEqual(result.calls, []) + + +class TestGlm47MoeDetector(unittest.TestCase): + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date"}, + }, + "required": ["city", "date"], + }, + ), + ), + ] + self.detector = Glm47MoeDetector() + + def test_single_tool_call(self): + text = ( + "get_weather" + "cityBeijing" + "date2024-06-27" + "" + ) + result = 
self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(result.normal_text, "") + + def test_multiple_tool_calls(self): + text = ( + "get_weather" + "cityBeijing" + "date2024-06-27" + "" + "get_weather" + "cityShanghai" + "date2024-06-28" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(result.calls[1].name, "get_weather") + self.assertEqual( + result.calls[1].parameters, '{"city": "Shanghai", "date": "2024-06-28"}' + ) + self.assertEqual(result.normal_text, "") + + def test_streaming_tool_call(self): + """Test streaming incremental parsing of a tool call.""" + chunks = [ + "get_weather", + "cityBeijing", + "date2024-06-27", + "", + ] + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + + def test_streaming_multiple_tool_calls(self): + """Test streaming incremental parsing of multiple tool calls.""" + chunks = [ + "get_weather", + "cityBeijing", + 
"date2024-06-27", + "get_weather", + "cityShanghai", + "date2024-06-28", + "", + ] + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(tool_calls[1]["name"], "get_weather") + self.assertEqual( + tool_calls[1]["parameters"], '{"city": "Shanghai", "date": "2024-06-28"}' + ) + + def test_tool_call_id(self): + """Test that the buffer and state are reset after a tool call is completed.""" + chunks = [ + "get_weather", + "cityBeijing", + "date2024-06-27", + "", + ] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + self.assertEqual(self.detector.current_tool_id, 1) + + def test_invalid_tool_call(self): + """Test that invalid tool calls are handled correctly.""" + text = "invalid_funccityBeijing" + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 0) + + def test_partial_tool_call(self): + """Test parsing a partial tool call that spans multiple chunks.""" + chunks = [ + "get_weather", + "cityBeijing", + "date2024-06-27", + ] + + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while 
len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + + def test_array_argument_with_escaped_json(self): + """Test that array arguments with escaped JSON are properly handled without double-escaping.""" + # Add a tool with array parameter + tools_with_array = [ + Tool( + type="function", + function=Function( + name="todo_write", + description="Write todos", + parameters={ + "type": "object", + "properties": { + "todos": { + "type": "array", + "description": "The updated todo list", + } + }, + "required": ["todos"], + }, + ), + ), + ] + + def check_params(result): + self.assertEqual(1, len(result.calls)) + self.assertEqual("todo_write", result.calls[0].name) + params = json.loads(result.calls[0].parameters) + self.assertIsInstance(params["todos"], list) + self.assertEqual(4, len(params["todos"])) + self.assertEqual("1", params["todos"][0]["id"]) + self.assertEqual( + "Check for hard-coded issues in the backend code", + params["todos"][0]["task"], + ) + self.assertEqual("in_progress", params["todos"][0]["status"]) + self.assertEqual("2", params["todos"][1]["id"]) + self.assertEqual( + "Check for hard-coded issues in the frontend code", + params["todos"][1]["task"], + ) + self.assertEqual("pending", params["todos"][1]["status"]) + self.assertEqual("3", params["todos"][2]["id"]) + self.assertEqual( + "Check for code violating the Single Responsibility Principle", + params["todos"][2]["task"], + ) + self.assertEqual("pending", params["todos"][2]["status"]) + self.assertEqual("4", params["todos"][3]["id"]) + self.assertEqual( + "Generate a 
rectification proposal report", params["todos"][3]["task"] + ) + self.assertEqual("pending", params["todos"][3]["status"]) + + # Simulate the raw response from GLM-4.6 model with normal and escaped JSON in XML + result = self.detector.detect_and_parse( + """todo_writetodos[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] +""", + tools_with_array, + ) + check_params(result) + result = self.detector.detect_and_parse( + r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] +""", + tools_with_array, + ) + check_params(result) + + def check_single_todos(tool_result, expected): + self.assertEqual(1, len(tool_result.calls)) + self.assertEqual("todo_write", tool_result.calls[0].name) + params = json.loads(tool_result.calls[0].parameters) + self.assertIsInstance(params["todos"], list) + self.assertEqual(1, len(params["todos"])) + self.assertEqual("1", params["todos"][0]["id"]) + self.assertEqual(expected, params["todos"][0]["task"]) + self.assertEqual("pending", params["todos"][0]["status"]) + + # Test with escaped JSON containing backslashes in content (e.g., Windows paths) + expected_path = r"Check file at C:\Users\test.txt" + result = self.detector.detect_and_parse( + """todo_writetodos[{\"id\": \"1\", 
\"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_path) + result = self.detector.detect_and_parse( + r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_path) + + # Should contain literal \n, not actual newline + expected_output = r"Print \n to see newline" + result = self.detector.detect_and_parse( + """todo_writetodos[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_output) + result = self.detector.detect_and_parse( + r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_output) + + +class TestJsonArrayParser(unittest.TestCase): + def setUp(self): + # Create sample tools for testing + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "properties": { + "location": { + "type": "string", + "description": "Location to get weather for", + }, + "unit": { + "type": "string", + "description": "Temperature unit", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="search", + description="Search for information", + parameters={ + "properties": { + "query": { + "type": "string", + "description": "Search query", + }, + }, + "required": ["query"], + }, + ), + ), + ] + self.detector = JsonArrayParser() + + def test_json_detector_has_no_ebnf(self): + """JsonArrayParser no longer exposes EBNF generation helpers.""" + self.assertFalse( + hasattr(self.detector, "build_ebnf"), + "JsonArrayParser should not expose EBNF helpers after cleanup", + ) + + def 
test_parse_streaming_increment_malformed_json(self): + """Test parsing with malformed JSON""" + # Test with malformed JSON + text = '[{"name": "get_weather", "parameters": {"location": "Tokyo"' + result = self.detector.parse_streaming_increment(text, self.tools) + + # Should not crash and return a valid result + self.assertIsInstance(result, StreamingParseResult) + + text = "[{}}}]" + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertIsInstance(result, StreamingParseResult) + + def test_parse_streaming_increment_empty_input(self): + """Test parsing with empty input""" + result = self.detector.parse_streaming_increment("", self.tools) + self.assertEqual(len(result.calls), 0) + self.assertEqual(result.normal_text, "") + + def test_parse_streaming_increment_whitespace_handling(self): + """Test parsing with various whitespace scenarios""" + # Test with leading/trailing whitespace split across chunks + chunk1 = ' [{"name": "get_weather", "parameters": ' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + chunk2 = '{"location": "Tokyo"}}] ' + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + + # The base class should handle this + self.assertIsInstance(result2, StreamingParseResult) + + def test_parse_streaming_increment_nested_objects(self): + """Test parsing with nested JSON objects""" + chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo", ' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + chunk2 = '"nested": {"key": "value"}}}]' + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + + # The base class should handle this + self.assertIsInstance(result2, StreamingParseResult) + + def test_json_parsing_with_commas(self): + """Test that JSON parsing works correctly with comma separators""" + # Stream two complete objects, at 
least 2 chunks per tool call + chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + chunk2 = 'yo"}},' + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + self.assertIsInstance(result2, StreamingParseResult) + + chunk3 = '{"name": "get_weather", "parameters": {"location": "Par' + result3 = self.detector.parse_streaming_increment(chunk3, self.tools) + self.assertIsInstance(result3, StreamingParseResult) + chunk4 = 'is"}}]' + result4 = self.detector.parse_streaming_increment(chunk4, self.tools) + self.assertIsInstance(result4, StreamingParseResult) + self.assertGreater( + len(result4.calls), 0, "Should parse tool calls from text with separators" + ) + + def test_braces_in_strings(self): + """Test that JSON with } characters inside strings works correctly""" + # Test case: JSON array with } inside string values - streamed across chunks + chunk1 = '[{"name": "get_weather", "parameters": {"location": "has } inside"' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + chunk2 = "}}" + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + self.assertIsInstance(result2, StreamingParseResult) + self.assertGreater( + len(result2.calls), 0, "Should parse tool call with } in string" + ) + + # Test with separator (streaming in progress) + chunk3 = '[{"name": "get_weather", "parameters": {"location": "has } inside"}' + result3 = self.detector.parse_streaming_increment(chunk3, self.tools) + self.assertIsInstance(result3, StreamingParseResult) + chunk4 = "}," + result4 = self.detector.parse_streaming_increment(chunk4, self.tools) + self.assertIsInstance(result4, StreamingParseResult) + chunk5 = '{"name": "get_weather"' + result5 = self.detector.parse_streaming_increment(chunk5, self.tools) + self.assertIsInstance(result5, 
StreamingParseResult) + self.assertGreater( + len(result5.calls), + 0, + "Should parse tool calls with separator and } in string", + ) + + def test_separator_in_same_chunk(self): + """Test that separator already present in chunk works correctly""" + # Test case: separator already in the chunk (streaming in progress) with 2+ chunks per tool call + chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + chunk2 = '}},{"name": "get_weather"' + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + self.assertIsInstance(result2, StreamingParseResult) + self.assertGreater( + len(result2.calls), + 0, + "Should parse tool calls with separator in same chunk", + ) + + def test_separator_in_separate_chunk(self): + """Test that separator in separate chunk works correctly""" + # Test case: separator in separate chunk - this tests streaming behavior + chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"}}' + chunk2 = "," + chunk3 = '{"name": "get_weather", "parameters": {"location": "Paris"}}' + + # Process first chunk + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + + # Process separator chunk + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + self.assertIsInstance(result2, StreamingParseResult) + + # Process second chunk (streaming in progress) + result3 = self.detector.parse_streaming_increment(chunk3, self.tools) + self.assertIsInstance(result3, StreamingParseResult) + + def test_incomplete_json_across_chunks(self): + """Test that incomplete JSON across chunks works correctly""" + # Test case: incomplete JSON across chunks - this tests streaming behavior + chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"' + chunk2 = '}},{"name": "get_weather"' + + # Process first chunk 
(incomplete) + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + + # Process second chunk (completes first object and starts second, streaming in progress) + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + self.assertIsInstance(result2, StreamingParseResult) + + def test_malformed_json_recovery(self): + """Test that malformed JSON recovers gracefully""" + # Test with malformed JSON - should handle gracefully + malformed_text = ( + '[{"name": "get_weather", "parameters": {"location": "unclosed string' + ) + + result1 = self.detector.parse_streaming_increment(malformed_text, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + + # Test valid JSON after malformed - streamed across 2 chunks (streaming in progress) + valid_chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok' + result2 = self.detector.parse_streaming_increment(valid_chunk1, self.tools) + self.assertIsInstance(result2, StreamingParseResult) + valid_chunk2 = 'yo"}}' + result3 = self.detector.parse_streaming_increment(valid_chunk2, self.tools) + self.assertIsInstance(result3, StreamingParseResult) + + def test_nested_objects_with_commas(self): + """Test that nested objects with commas inside work correctly""" + # Test with nested objects that have commas - should work with json.loads() + chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + chunk2 = 'yo", "unit": "celsius"}}' + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + self.assertIsInstance(result2, StreamingParseResult) + self.assertGreater( + len(result2.calls), 0, "Should parse tool call with nested objects" + ) + + def test_empty_objects(self): + """Test that empty objects work correctly""" + # Test with empty objects - should work with json.loads() + 
chunk1 = '[{"name": "get_weather", "parameters": ' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + chunk2 = "{}}" + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + self.assertIsInstance(result2, StreamingParseResult) + + def test_whitespace_handling(self): + """Test that various whitespace scenarios work correctly""" + # Test with various whitespace patterns - should work with json.loads() + chunk1 = ' \n\n [{"name": "get_weather", "parameters": ' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + chunk2 = '{"location": "Tokyo"}}' + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + self.assertIsInstance(result2, StreamingParseResult) + + def test_multiple_commas_in_chunk(self): + """Test that multiple commas in a single chunk work correctly""" + # Stream multiple tool calls ensuring at least 2 chunks per complete tool call + chunk1 = '[{"name": "get_weather", "parameters": {"location": "To' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + chunk2 = 'kyo"}},' + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + self.assertIsInstance(result2, StreamingParseResult) + + chunk3 = '{"name": "get_weather", "parameters": {"location": "Pa' + result3 = self.detector.parse_streaming_increment(chunk3, self.tools) + self.assertIsInstance(result3, StreamingParseResult) + chunk4 = 'ris"}},' + result4 = self.detector.parse_streaming_increment(chunk4, self.tools) + self.assertIsInstance(result4, StreamingParseResult) + + chunk5 = '{"name": "get_weather"' + result5 = self.detector.parse_streaming_increment(chunk5, self.tools) + self.assertIsInstance(result5, StreamingParseResult) + self.assertGreater( + len(result5.calls), 0, "Should parse tool calls with multiple commas" + ) + + def 
test_complete_tool_call_with_trailing_comma(self): + """Test that complete tool call with trailing comma parses correctly""" + # Test case: complete tool call followed by comma at end of chunk (split across 2 chunks) + chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"}' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + self.assertIsInstance(result1, StreamingParseResult) + chunk2 = "}, " + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + self.assertIsInstance(result2, StreamingParseResult) + self.assertGreater(len(result2.calls), 0, "Should parse complete tool call") + + # Test that next chunk with opening brace gets the separator prepended + next_chunk = '{"name": "get_weather", "parameters": {"location": "Paris"}}' + result_next = self.detector.parse_streaming_increment(next_chunk, self.tools) + self.assertIsInstance(result_next, StreamingParseResult) + self.assertGreater( + len(result_next.calls), 0, "Should parse subsequent tool call" + ) + + def test_three_tool_calls_separate_chunks_with_commas(self): + """Test parsing 3 tool calls in separate chunks with commas at the end""" + # First tool call: 2 chunks + chunk1_1 = '[{"name": "get_weather", "parameters": ' + result1_1 = self.detector.parse_streaming_increment(chunk1_1, self.tools) + chunk1_2 = '{"location": "Tokyo"}},' + result1_2 = self.detector.parse_streaming_increment(chunk1_2, self.tools) + self.assertIsInstance(result1_2, StreamingParseResult) + self.assertGreater(len(result1_2.calls), 0, "Should parse first tool call") + + # Second tool call: 2 chunks + chunk2_1 = '{"name": "search", "parameters": ' + result2_1 = self.detector.parse_streaming_increment(chunk2_1, self.tools) + chunk2_2 = '{"query": "restaurants"}},' + result2_2 = self.detector.parse_streaming_increment(chunk2_2, self.tools) + self.assertIsInstance(result2_2, StreamingParseResult) + self.assertGreater(len(result2_2.calls), 0, "Should parse second tool call") + + # Third 
tool call: 2 chunks + chunk3_1 = '{"name": "get_weather", "parameters": ' + result3_1 = self.detector.parse_streaming_increment(chunk3_1, self.tools) + chunk3_2 = '{"location": "Paris"}}]' + result3_2 = self.detector.parse_streaming_increment(chunk3_2, self.tools) + self.assertIsInstance(result3_2, StreamingParseResult) + self.assertGreater(len(result3_2.calls), 0, "Should parse third tool call") + # Verify all tool calls were parsed correctly + total_calls = len(result1_2.calls) + len(result2_2.calls) + len(result3_2.calls) + self.assertEqual(total_calls, 3, "Should have parsed exactly 3 tool calls") + + +class TestLfm2Detector(unittest.TestCase): + """Tests for LFM2 (Liquid Foundation Model 2) function call detector.""" + + def setUp(self): + """Set up test tools and detector.""" + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + }, + "unit": { + "type": "string", + "description": "Temperature unit", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="search", + description="Search for information", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query", + }, + }, + "required": ["query"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="calculator", + description="Perform calculations", + parameters={ + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Math expression", + }, + }, + "required": ["expression"], + }, + ), + ), + ] + self.detector = Lfm2Detector() + + # ==================== has_tool_call tests ==================== + + def test_has_tool_call_true(self): + """Test detection of tool call markers.""" + text = 
'<|tool_call_start|>[get_weather(city="Paris")]<|tool_call_end|>' + self.assertTrue(self.detector.has_tool_call(text)) + + def test_has_tool_call_false(self): + """Test no false positives for regular text.""" + text = "The weather in Paris is nice today." + self.assertFalse(self.detector.has_tool_call(text)) + + def test_has_tool_call_partial_marker(self): + """Test that partial markers are detected (start token present).""" + text = '<|tool_call_start|>[get_weather(city="Paris")' + self.assertTrue(self.detector.has_tool_call(text)) + + # ==================== detect_and_parse tests (Pythonic format) ==================== + + def test_detect_and_parse_pythonic_simple(self): + """Test parsing a simple Pythonic format tool call.""" + text = '<|tool_call_start|>[get_weather(city="Paris")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[0].tool_index, 0) + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "Paris") + + def test_detect_and_parse_pythonic_multiple_args(self): + """Test parsing with multiple arguments.""" + text = '<|tool_call_start|>[get_weather(city="London", unit="celsius")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "London") + self.assertEqual(params["unit"], "celsius") + + def test_detect_and_parse_pythonic_no_args(self): + """Test parsing function with no arguments.""" + # Add a no-arg tool for this test + tools_with_noarg = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={"type": "object", "properties": {}}, + ), + ), + ] + text = 
"<|tool_call_start|>[get_time()]<|tool_call_end|>" + result = self.detector.detect_and_parse(text, tools_with_noarg) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_time") + + def test_detect_and_parse_pythonic_multiple_calls(self): + """Test parsing multiple tool calls in one block.""" + text = '<|tool_call_start|>[get_weather(city="Paris"), search(query="restaurants")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "search") + + params1 = json.loads(result.calls[0].parameters) + params2 = json.loads(result.calls[1].parameters) + self.assertEqual(params1["city"], "Paris") + self.assertEqual(params2["query"], "restaurants") + + def test_detect_and_parse_with_normal_text_before(self): + """Test parsing with normal text before the tool call.""" + text = 'Let me check the weather for you. 
<|tool_call_start|>[get_weather(city="Tokyo")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, "Let me check the weather for you.") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + def test_detect_and_parse_special_characters_in_value(self): + """Test parsing with special characters in argument values.""" + text = ( + '<|tool_call_start|>[search(query="what\'s the weather?")]<|tool_call_end|>' + ) + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + params = json.loads(result.calls[0].parameters) + self.assertIn("weather", params["query"]) + + def test_detect_and_parse_numeric_values(self): + """Test parsing with numeric argument values.""" + text = '<|tool_call_start|>[calculator(expression="5 * 7")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "calculator") + + # ==================== detect_and_parse tests (JSON format) ==================== + + def test_detect_and_parse_json_simple(self): + """Test parsing JSON format tool call.""" + text = '<|tool_call_start|>[{"name": "get_weather", "arguments": {"city": "Berlin"}}]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "Berlin") + + def test_detect_and_parse_json_multiple_calls(self): + """Test parsing multiple JSON format tool calls.""" + text = '<|tool_call_start|>[{"name": "get_weather", "arguments": {"city": "Paris"}}, {"name": "search", "arguments": {"query": "hotels"}}]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 2) + 
self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "search") + + def test_detect_and_parse_json_with_parameters_key(self): + """Test parsing JSON format with 'parameters' key instead of 'arguments'.""" + text = '<|tool_call_start|>[{"name": "get_weather", "parameters": {"city": "Madrid"}}]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "Madrid") + + # ==================== Edge cases ==================== + + def test_detect_and_parse_no_tool_call(self): + """Test parsing text with no tool calls.""" + text = "This is just regular text without any tool calls." + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, text) + self.assertEqual(result.calls, []) + + def test_detect_and_parse_unknown_function(self): + """Test parsing with unknown function name - skipped by default (SGLANG_FORWARD_UNKNOWN_TOOLS=false).""" + text = '<|tool_call_start|>[unknown_function(arg="value")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + # By default, unknown functions are skipped (consistent with other detectors) + self.assertEqual(len(result.calls), 0) + + def test_detect_and_parse_empty_content(self): + """Test parsing with empty content between markers.""" + text = "<|tool_call_start|><|tool_call_end|>" + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.calls, []) + + def test_detect_and_parse_multiple_blocks(self): + """Test parsing multiple separate tool call blocks.""" + text = '<|tool_call_start|>[get_weather(city="Paris")]<|tool_call_end|> Some text <|tool_call_start|>[search(query="food")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, 
"get_weather") + self.assertEqual(result.calls[1].name, "search") + + # ==================== Streaming tests ==================== + # The LFM2 detector buffers until it sees complete <|tool_call_start|>...<|tool_call_end|> + # blocks, then parses the complete block. This allows proper handling of both + # JSON and Pythonic formats. + + def test_streaming_json_complete_in_one_chunk(self): + """Test streaming with complete JSON tool call in one chunk.""" + text = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": "Rome"}}<|tool_call_end|>' + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + def test_streaming_json_split_across_chunks(self): + """Test streaming with JSON tool call split across multiple chunks - waits for complete block.""" + # Reset detector state + self.detector = Lfm2Detector() + + # First chunk: start marker and partial JSON (no end token) + chunk1 = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": ' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + + # Should buffer and not emit calls yet (waiting for complete block) + self.assertEqual(len(result1.calls), 0) + self.assertEqual(result1.normal_text, "") + + # Second chunk: complete the JSON and end token + chunk2 = '"Vienna"}}<|tool_call_end|>' + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + + # Now should have the complete tool call + self.assertEqual(len(result2.calls), 1) + self.assertEqual(result2.calls[0].name, "get_weather") + + def test_streaming_json_normal_text_before_tool_call(self): + """Test streaming with normal text before JSON tool call.""" + # Reset detector state + self.detector = Lfm2Detector() + + chunk1 = "I'll check the weather. 
" + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + + # Normal text should be returned + self.assertIn("check the weather", result1.normal_text) + + chunk2 = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": "Amsterdam"}}<|tool_call_end|>' + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + + self.assertEqual(len(result2.calls), 1) + + def test_streaming_eot_token_filtering(self): + """Test that end-of-turn token is filtered from normal text.""" + # Reset detector state + self.detector = Lfm2Detector() + + # Send text that ends with tool call end token (JSON format) + text = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": "Oslo"}}<|tool_call_end|>' + result = self.detector.parse_streaming_increment(text, self.tools) + + # The normal_text should not contain the eot_token + self.assertNotIn("<|tool_call_end|>", result.normal_text) + + # ==================== Pythonic streaming tests ==================== + + def test_streaming_pythonic_complete_in_one_chunk(self): + """Test streaming with complete Pythonic tool call in one chunk.""" + self.detector = Lfm2Detector() + text = '<|tool_call_start|>[get_weather(city="Berlin")]<|tool_call_end|>' + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(json.loads(result.calls[0].parameters), {"city": "Berlin"}) + + def test_streaming_pythonic_split_across_chunks(self): + """Test streaming with Pythonic tool call split across multiple chunks.""" + self.detector = Lfm2Detector() + + # First chunk: start marker and partial call + chunk1 = '<|tool_call_start|>[get_weather(city="' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + + # Should buffer and not emit calls yet + self.assertEqual(len(result1.calls), 0) + + # Second chunk: complete the call + chunk2 = 'Munich")]<|tool_call_end|>' + result2 = 
self.detector.parse_streaming_increment(chunk2, self.tools) + + # Now should have the complete tool call + self.assertEqual(len(result2.calls), 1) + self.assertEqual(result2.calls[0].name, "get_weather") + self.assertEqual(json.loads(result2.calls[0].parameters), {"city": "Munich"}) + + def test_streaming_pythonic_multiple_calls(self): + """Test streaming with multiple Pythonic tool calls.""" + self.detector = Lfm2Detector() + + text = '<|tool_call_start|>[get_weather(city="Paris"), search(query="hotels")]<|tool_call_end|>' + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "search") + + # ==================== structure_info tests ==================== + + def test_supports_structural_tag(self): + """Test that LFM2 does not support structural tags (Pythonic format).""" + # LFM2 uses Pythonic format which is not JSON-compatible, + # so structural_tag constrained generation cannot be used + self.assertFalse(self.detector.supports_structural_tag()) + + def test_structure_info(self): + """Test structure info for constrained generation.""" + info_func = self.detector.structure_info() + info = info_func("get_weather") + + self.assertEqual(info.begin, "<|tool_call_start|>[get_weather(") + self.assertEqual(info.end, ")]<|tool_call_end|>") + self.assertEqual(info.trigger, "<|tool_call_start|>") + + +class TestGigaChat3Detector(unittest.TestCase): + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="manage_user_memory", + description="Create, update, or delete a user memory entry.", + parameters={ + "type": "object", + "properties": { + "content": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + }, + "action": { + "type": "string", + "enum": ["create", "update", "delete"], + "default": "create", + }, + "id": { + "anyOf": [ + {"type": "string", "format": 
"uuid"}, + {"type": "null"}, + ], + "default": None, + }, + }, + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city"], + }, + ), + ), + ] + self.detector = GigaChat3Detector() + + def test_has_tool_call(self): + """Test detection of tool call markers.""" + self.assertTrue(self.detector.has_tool_call("function call<|role_sep|>\n{}")) + self.assertFalse(self.detector.has_tool_call("No tool call here")) + + def test_detect_and_parse_no_tool_call(self): + """Test parsing text without tool calls.""" + text = "How can I help you today?" + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, text) + self.assertEqual(len(result.calls), 0) + + def test_detect_and_parse_simple_tool_call(self): + """Test parsing a simple tool call without content.""" + text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences"}}' + result = self.detector.detect_and_parse(text, self.tools) + + # No content before tool call + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "manage_user_memory") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["action"], "create") + self.assertEqual(params["id"], "preferences") + + def test_detect_and_parse_parameterless_tool_call(self): + """Test parsing a tool call with empty arguments.""" + text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {}}' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + 
self.assertEqual(result.calls[0].name, "manage_user_memory") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params, {}) + + def test_detect_and_parse_complex_tool_call(self): + """Test parsing a tool call with nested objects.""" + text = """<|message_sep|> + +function call<|role_sep|> +{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences", "content": {"short_answers": true, "hate_emojis": true, "english_ui": false, "russian_math_explanations": true}}}""" + + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "manage_user_memory") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["action"], "create") + self.assertEqual(params["id"], "preferences") + self.assertIsInstance(params["content"], dict) + self.assertEqual(params["content"]["short_answers"], True) + self.assertEqual(params["content"]["hate_emojis"], True) + + def test_detect_and_parse_with_content_before(self): + """Test parsing tool call with text content before it.""" + text = 'I\'ll check that for you.<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences"}}' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, "I'll check that for you.") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "manage_user_memory") + + def test_detect_and_parse_with_eos_token(self): + """Test parsing tool call with EOS token at the end.""" + text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences"}}' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, "") + self.assertEqual(len(result.calls), 1) + 
self.assertEqual(result.calls[0].name, "manage_user_memory") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["action"], "create") + self.assertEqual(params["id"], "preferences") + + def test_detect_and_parse_with_content_and_eos(self): + """Test parsing tool call with content and EOS token.""" + text = 'I\'ll remember that.<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {"action": "create", "id": "test"}}' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, "I'll remember that.") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "manage_user_memory") + + def test_detect_and_parse_invalid_json(self): + """Test parsing with invalid JSON in function call.""" + text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {invalid json}}' + result = self.detector.detect_and_parse(text, self.tools) + + # Should return the full text as content when JSON parsing fails + self.assertIn("function call", result.normal_text) + self.assertEqual(len(result.calls), 0) + + def test_detect_and_parse_missing_name(self): + """Test parsing with missing function name.""" + text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"arguments": {"action": "create"}}' + result = self.detector.detect_and_parse(text, self.tools) + + # Should not extract tool call if name is missing + self.assertEqual(len(result.calls), 0) + + def test_detect_and_parse_missing_arguments(self): + """Test parsing with missing arguments field.""" + text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory"}' + result = self.detector.detect_and_parse(text, self.tools) + + # Should not extract tool call if arguments is missing + self.assertEqual(len(result.calls), 0) + + def test_detect_and_parse_arguments_not_dict(self): + """Test parsing with arguments that is not a dict.""" + text = 
'<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": "string_args"}' + result = self.detector.detect_and_parse(text, self.tools) + + # Should not extract tool call if arguments is not a dict + self.assertEqual(len(result.calls), 0) + + def test_streaming_no_tool_call(self): + """Test streaming text without tool calls.""" + chunks = ["How ", "can ", "I ", "help ", "you?"] + + accumulated_text = "" + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + accumulated_text += result.normal_text + + self.assertEqual(accumulated_text, "How can I help you?") + self.assertEqual(len(result.calls), 0) + + def test_streaming_simple_tool_call(self): + """Test streaming a simple tool call.""" + chunks = [ + "<|message_sep|>\n\n", + "function call", + "<|role_sep|>\n", + '{"name": "manage_user_memory", ', + '"arguments": {"action": "create"', + ', "id": "preferences"}}', + ] + + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "manage_user_memory") + + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params["action"], "create") + self.assertEqual(params["id"], "preferences") + + def test_streaming_with_content_before(self): + """Test streaming with content before tool call.""" + chunks = [ + "I'll ", + "help ", + "you.", + "<|message_sep|>\n\n", + "function call", + "<|role_sep|>\n", + '{"name": "get_weather", ', + '"arguments": {"city": 
"Tokyo"}}', + ] + + accumulated_text = "" + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + accumulated_text += result.normal_text + + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual(accumulated_text, "I'll help you.") + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") + + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params["city"], "Tokyo") + + def test_streaming_complex_arguments(self): + """Test streaming with complex nested arguments.""" + chunks = [ + "<|message_sep|>\n\n", + "functi", + "on call<|role_sep|>\n", + '{"name": "manage_user_memory", "arguments": ', + '{"action": "create", "id": "prefs", ', + '"content": {"likes": ["short", "clear"], ', + '"dislikes": ["emojis", "verbose"]}', + "}}", + ] + + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "manage_user_memory") + + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params["action"], "create") + self.assertEqual(params["content"]["likes"], 
["short", "clear"]) + self.assertEqual(params["content"]["dislikes"], ["emojis", "verbose"]) + + def test_streaming_with_eos_token(self): + """Test streaming with EOS token at the end.""" + chunks = [ + "<|message_sep|>\n\n", + "function c", + "all<|role_sep|>\n", + '{"name": "get_weather", ', + '"arguments": {"city": "Paris"}}', + "", + ] + + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") + + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params["city"], "Paris") + + def test_streaming_incomplete_json(self): + """Test streaming with incomplete JSON (no closing brace).""" + chunks = [ + "<|message_sep|>\n\n", + "fun", + "ction call<|role_sep|>\n", + '{"name": "get_weather", ', + '"arguments": {"city": "London"', + # Missing closing braces + ] + + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + # Should have name but incomplete parameters + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], 
"get_weather") + self.assertTrue(tool_calls_by_index[0]["parameters"].startswith('{"city":')) + + def test_streaming_large_steps(self): + """Test streaming with large chunks that complete in fewer steps.""" + chunks = [ + "I'll remember that.", + "<|message_sep|>\n\nfuncti", + "on call<|role_sep|>\n", + '{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences", "content": {"short_answers": true, "hate_emojis": true, ', + '"english_ui": false, "russian_math_explanations": true}', + "}}", + ] + + accumulated_text = "" + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + accumulated_text += result.normal_text + + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual(accumulated_text, "I'll remember that.") + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "manage_user_memory") + + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params["action"], "create") + self.assertEqual(params["content"]["short_answers"], True) + self.assertEqual(params["content"]["russian_math_explanations"], True) + + def test_streaming_very_small_chunks(self): + """Test streaming with very small chunks (character by character).""" + text = '{"name": "get_weather", "arguments": {"city": "NYC"}}' + + # Split into very small chunks (every 5 characters) + chunk_size = 5 + chunked_text = [ + text[i : i + chunk_size] for i in range(0, len(text), chunk_size) + ] + chunks = [ + "<|message_sep|>\n\n", + "func", + "tion call", + "<|role_sep|>\n", + *chunked_text, + ] + tool_calls_by_index = {} + + for 
chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") + + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params["city"], "NYC") + + def test_streaming_json_split_at_quotes(self): + """Test streaming when JSON is split at quote boundaries.""" + chunks = [ + "<|message_sep|>\n\nfunction call<|role_sep|>\n", + '{"name', + '": "', + "get_weather", + '", "arguments', + '": {"city', + '": "', + "Rome", + '"}}', + ] + + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") + + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params["city"], "Rome") + + +class TestGemma4Detector(unittest.TestCase): + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": { + "type": "string", + 
"enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + ), + ) + ] + self.detector = Gemma4Detector() + + def test_detect_and_parse(self): + text = 'Some text before <|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, "Some text before ") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "Tokyo") + + def test_parse_streaming_increment(self): + chunks = [ + "Some text ", + "before <|tool", + "_call>call:get_we", + "ather{location:<|", + '"|>Tokyo<|"|>} after", + ] + + all_results = [] + for chunk in chunks: + res = self.detector.parse_streaming_increment(chunk, self.tools) + all_results.append(res) + + combined_normal_text = "".join(r.normal_text for r in all_results) + self.assertEqual(combined_normal_text, "Some text before after") + + found_name = False + found_params = False + for res in all_results: + for call in res.calls: + if call.name == "get_weather": + found_name = True + if call.parameters: + params = json.loads(call.parameters) + if params == {"location": "Tokyo"}: + found_params = True + + self.assertTrue(found_name) + self.assertTrue(found_params) + + def test_nested_array_streaming(self): + # Additional coverage for complex structure + chunks = [ + '<|tool_call>call:get_weather{location:<|"', + '|>New York<|"|>,nested:[1, 2, {inner:<|"|>', + 'val<|"|>}]}', + ] + + all_results = [] + for chunk in chunks: + res = self.detector.parse_streaming_increment(chunk, self.tools) + all_results.append(res) + + found_params = False + for res in all_results: + for call in res.calls: + if call.parameters: + params = json.loads(call.parameters) + if "location" in params and params["location"] == "New York": + if "nested" in params and params["nested"] == [ + 1, + 2, + {"inner": "val"}, + ]: + 
found_params = True + + self.assertTrue(found_params) + + def test_has_tool_call(self): + self.assertTrue( + self.detector.has_tool_call( + '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + ) + ) + self.assertFalse(self.detector.has_tool_call("no tool call here")) + + def test_detect_and_parse_no_tool_call(self): + text = "This is plain text without any tool calls." + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(result.normal_text, text) + self.assertEqual(len(result.calls), 0) + + def test_detect_and_parse_tool_index(self): + text = '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].tool_index, 0) + self.assertEqual(result.calls[0].name, "get_weather") + + def test_detect_and_parse_unknown_tool_index(self): + text = '<|tool_call>call:unknown_func{arg:<|"|>val<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].tool_index, -1) + + def test_detect_and_parse_nested_object(self): + text = '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,details:{temp:25,unit:<|"|>celsius<|"|>}}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "Tokyo") + self.assertIsInstance(params["details"], dict) + self.assertEqual(params["details"]["temp"], 25) + self.assertEqual(params["details"]["unit"], "celsius") + + def test_detect_and_parse_multiple_calls(self): + extra_tools = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={ + "type": "object", + "properties": {"timezone": {"type": "string"}}, + }, + ), + ) + ] + text = ( + 'Some text <|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + ' 
more text <|tool_call>call:get_time{timezone:<|"|>UTC<|"|>}' + ) + result = self.detector.detect_and_parse(text, extra_tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "get_time") + self.assertEqual(result.normal_text, "Some text ") + + def test_parse_gemma4_args_empty(self): + self.assertEqual(_parse_gemma4_args(""), {}) + self.assertEqual(_parse_gemma4_args(" "), {}) + + def test_parse_gemma4_args_booleans(self): + result = _parse_gemma4_args("flag:true,other:false") + self.assertIs(result["flag"], True) + self.assertIs(result["other"], False) + + def test_parse_gemma4_args_numbers(self): + result = _parse_gemma4_args("count:42,ratio:3.14") + self.assertEqual(result["count"], 42) + self.assertAlmostEqual(result["ratio"], 3.14) + + def test_parse_gemma4_args_string_with_colon(self): + result = _parse_gemma4_args( + 'url:<|"|>http://example.com<|"|>' + ) + self.assertEqual(result["url"], "http://example.com") + + def test_parse_gemma4_args_nested_object(self): + result = _parse_gemma4_args( + 'outer:{inner:<|"|>val<|"|>,num:5}' + ) + self.assertIsInstance(result["outer"], dict) + self.assertEqual(result["outer"]["inner"], "val") + self.assertEqual(result["outer"]["num"], 5) + + def test_parse_gemma4_array_mixed_types(self): + result = _parse_gemma4_array( + '<|"|>hello<|"|>, 42, true, {key:<|"|>val<|"|>}' + ) + self.assertEqual(result[0], "hello") + self.assertEqual(result[1], 42) + self.assertIs(result[2], True) + self.assertIsInstance(result[3], dict) + self.assertEqual(result[3]["key"], "val") + + def test_parse_gemma4_value_types(self): + self.assertIs(_parse_gemma4_value("true"), True) + self.assertIs(_parse_gemma4_value("false"), False) + self.assertEqual(_parse_gemma4_value("42"), 42) + self.assertAlmostEqual(_parse_gemma4_value("3.14"), 3.14) + self.assertEqual(_parse_gemma4_value("hello"), "hello") + self.assertEqual(_parse_gemma4_value(""), "") + + +if __name__ 
== "__main__": + unittest.main() From f0cc8b2c28f4ad80128d4eba97ca9b9c466d618c Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Wed, 1 Apr 2026 19:46:05 +0000 Subject: [PATCH 092/112] default Gemma4 attention backend to triton Gemma4 requires triton attention backend (flashinfer trtllm_mha fails on cuda graph compilation for this architecture). Auto-select triton when no attention backend is explicitly set. Also add Gemma4 to the Gemma family block that disables hybrid SWA memory. --- python/sglang/srt/server_args.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f2dfa1c456c6..9c68003affe7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1857,6 +1857,7 @@ def _handle_model_specific_adjustments(self): "Gemma3ForConditionalGeneration", "Gemma3nForCausalLM", "Gemma3nForConditionalGeneration", + "Gemma4ForConditionalGeneration", ]: # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model. # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736 @@ -1864,6 +1865,10 @@ def _handle_model_specific_adjustments(self): f"Disable hybrid SWA memory for {model_arch} as it is not yet supported." 
) self.disable_hybrid_swa_memory = True + if model_arch == "Gemma4ForConditionalGeneration": + if self.is_attention_backend_not_set(): + self.attention_backend = "triton" + logger.info("Use triton as default attention backend for Gemma4") elif model_arch in ["Exaone4ForCausalLM", "ExaoneMoEForCausalLM"]: if hf_config.sliding_window_pattern is not None: logger.warning( From cf1ee553b3604d15725ccb73441dbc6aed56300e Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Wed, 1 Apr 2026 19:48:35 +0000 Subject: [PATCH 093/112] revert: fused RMSNorm kernel does not need +1 shift Gemma4 uses Gemma4RMSNorm (weight=ones, computes norm*w) not GemmaRMSNorm (weight=zeros, computes norm*(1+w)). The fused kernel replaces a plain RMSNorm path, so the original `x * rrms * w + r` was correct. Revert the erroneous +1 addition. --- python/sglang/srt/layers/gemma4_fused_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py index 3e4bd28ae314..5f227db82853 100644 --- a/python/sglang/srt/layers/gemma4_fused_ops.py +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -40,7 +40,7 @@ def _gemma_rmsnorm_residual_kernel( var = tl.sum(x * x, axis=0) / N rrms = tl.rsqrt(var + eps) - out = x * rrms * (w + 1.0) + r + out = x * rrms * w + r if HAS_SCALAR: scalar = tl.load(Scalar_ptr).to(tl.float32) From 776cd8f9e8671bf2b9716ad1948b03dfb93fe3d4 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Wed, 1 Apr 2026 21:18:42 +0000 Subject: [PATCH 094/112] MoE weight and config name change --- .../kernels/fused_moe_triton/common_utils.py | 2 +- python/sglang/srt/models/gemma4_causal.py | 22 +++++++------------ python/sglang/srt/models/gemma4_mm.py | 12 +++++----- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/common_utils.py b/benchmark/kernels/fused_moe_triton/common_utils.py index 01b75fd5e0f2..d08d2bb75d83 100644 --- 
a/benchmark/kernels/fused_moe_triton/common_utils.py +++ b/benchmark/kernels/fused_moe_triton/common_utils.py @@ -137,7 +137,7 @@ def get_model_config( elif architecture == "Gemma4ForConditionalGeneration": E = config.num_experts // ep_size topk = config.top_k_experts - intermediate_size = config.expert_intermediate_size + intermediate_size = config.moe_intermediate_size else: # Default: Mixtral E = config.num_local_experts // ep_size diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 495e6642fd51..18a1b1336853 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -180,7 +180,7 @@ def routing_function( num_experts=config.num_experts + get_global_server_args().ep_num_redundant_experts, hidden_size=config.hidden_size, - intermediate_size=getattr(config, "expert_intermediate_size", config.moe_intermediate_size), + intermediate_size=config.moe_intermediate_size, layer_id=layer_id, top_k=config.top_k_experts, quant_config=quant_config, @@ -274,13 +274,10 @@ def __init__( if layer_type in config.rope_parameters: rope_parameters = dict(config.rope_parameters[layer_type]) - if layer_type == "full_attention": - global_prf = getattr(config, "global_partial_rotary_factor", 0.25) - rope_parameters["partial_rotary_factor"] = global_prf else: rope_parameters = dict( rope_type="default", - rope_theta=getattr(config, "rope_theta", 10000.0), + rope_theta=10000.0, ) # KV sharing logic @@ -470,9 +467,7 @@ def __init__( self.post_per_layer_input_norm = None # Parallel MoE - self.enable_moe_block = getattr(config, "enable_moe_block", False) or getattr( - config, "use_second_mlp_block", False - ) + self.enable_moe_block = getattr(config, "enable_moe_block", False) if self.enable_moe_block: self.router = Gemma4Router( config, @@ -923,12 +918,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: name = 
name.replace("model.language_model.", "model.") - if ( - ".moe." in name - and "experts" not in name - and "per_expert_scale" not in name - ): - name = name.replace(".moe.", ".moe.experts.") + # HF has router.per_expert_scale and experts.* on the decoder layer; + # remap into our moe.* subtree since Gemma4MoE owns both. + name = name.replace(".router.per_expert_scale", ".moe.per_expert_scale") + if ".experts." in name and ".moe.experts." not in name: + name = name.replace(".experts.", ".moe.experts.") # attention_k_eq_v: full-attention layers have no v_proj in the # checkpoint (K and V share weights). When we see a k_proj weight diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 85a801069fb8..4618129fab7a 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -741,13 +741,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = re.sub(r"^model\.", "", name) - # moe experts - if ( - ".moe." in name - and "experts" not in name - and "per_expert_scale" not in name - ): - name = name.replace(".moe.", ".moe.experts.") + # HF has router.per_expert_scale and experts.* on the decoder layer; + # remap into our moe.* subtree since Gemma4MoE owns both. + name = name.replace(".router.per_expert_scale", ".moe.per_expert_scale") + if ".experts." in name and ".moe.experts." not in name: + name = name.replace(".experts.", ".moe.experts.") # Remap audio tower checkpoint names to our module tree if "audio_tower." 
in name: From 683827c6483bc35b0b32cf6da0e5558cc3f03bc6 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Wed, 1 Apr 2026 21:47:52 +0000 Subject: [PATCH 095/112] register gemma4 as hybrid SWA model --- python/sglang/srt/server_args.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9c68003affe7..145d455c6e47 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1857,7 +1857,6 @@ def _handle_model_specific_adjustments(self): "Gemma3ForConditionalGeneration", "Gemma3nForCausalLM", "Gemma3nForConditionalGeneration", - "Gemma4ForConditionalGeneration", ]: # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model. # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736 @@ -1865,10 +1864,10 @@ def _handle_model_specific_adjustments(self): f"Disable hybrid SWA memory for {model_arch} as it is not yet supported." 
) self.disable_hybrid_swa_memory = True - if model_arch == "Gemma4ForConditionalGeneration": - if self.is_attention_backend_not_set(): - self.attention_backend = "triton" - logger.info("Use triton as default attention backend for Gemma4") + elif model_arch == "Gemma4ForConditionalGeneration": + if self.is_attention_backend_not_set(): + self.attention_backend = "triton" + logger.info("Use triton as default attention backend for Gemma4") elif model_arch in ["Exaone4ForCausalLM", "ExaoneMoEForCausalLM"]: if hf_config.sliding_window_pattern is not None: logger.warning( From aba771b6380adc7444750d69dc1f8d1894ee6681 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Wed, 1 Apr 2026 22:45:19 +0000 Subject: [PATCH 096/112] lint --- .../function_call/test_function_call_parser.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index 172e71610bae..056d9fcad42d 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -3901,23 +3901,17 @@ def test_parse_gemma4_args_numbers(self): self.assertAlmostEqual(result["ratio"], 3.14) def test_parse_gemma4_args_string_with_colon(self): - result = _parse_gemma4_args( - 'url:<|"|>http://example.com<|"|>' - ) + result = _parse_gemma4_args('url:<|"|>http://example.com<|"|>') self.assertEqual(result["url"], "http://example.com") def test_parse_gemma4_args_nested_object(self): - result = _parse_gemma4_args( - 'outer:{inner:<|"|>val<|"|>,num:5}' - ) + result = _parse_gemma4_args('outer:{inner:<|"|>val<|"|>,num:5}') self.assertIsInstance(result["outer"], dict) self.assertEqual(result["outer"]["inner"], "val") self.assertEqual(result["outer"]["num"], 5) def test_parse_gemma4_array_mixed_types(self): - result = _parse_gemma4_array( - '<|"|>hello<|"|>, 42, true, {key:<|"|>val<|"|>}' - ) + result = 
_parse_gemma4_array('<|"|>hello<|"|>, 42, true, {key:<|"|>val<|"|>}') self.assertEqual(result[0], "hello") self.assertEqual(result[1], 42) self.assertIs(result[2], True) From 4e061dc58fadf2a68dacaa879657ebddb2376241 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Thu, 2 Apr 2026 01:10:30 +0000 Subject: [PATCH 097/112] init remove swa kv pool hack --- python/sglang/srt/mem_cache/swa_memory_pool.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index 367caf30f02c..ab2cda35ee6f 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -105,11 +105,6 @@ def get_kv_size_bytes(self): k_size_swa, v_size_swa = self.swa_kv_pool.get_kv_size_bytes() return k_size + k_size_swa, v_size + v_size_swa - def get_v_head_dim(self): - swa_v_dim = self.swa_kv_pool.get_value_buffer(0).shape[-1] - full_v_dim = self.full_kv_pool.get_value_buffer(0).shape[-1] - return max(swa_v_dim, full_v_dim) - def get_contiguous_buf_infos(self): full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = ( self.full_kv_pool.get_contiguous_buf_infos() From 595b7680e6daf5f612c8e3f2cfdda12866722b37 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:15:04 +0800 Subject: [PATCH 098/112] Update python/sglang/srt/configs/model_config.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/sglang/srt/configs/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 98f81426dd55..922667924586 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -1493,7 +1493,7 @@ def get_hybrid_layer_ids( "Gemma4ForCausalLM" in model_architectures or 
"Gemma4ForConditionalGeneration" in model_architectures ): - layer_types = getattr(hf_text_config, "layer_types", None) + layer_types = getattr(hf_text_config, "layer_types", []) swa_attention_layer_ids = [ i for i, x in enumerate(layer_types) if x == "sliding_attention" ] From 2deb3de226d6da03280bfd9379cd7e70e44d4d5e Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 2 Apr 2026 10:22:35 +0000 Subject: [PATCH 099/112] Add 'ather' to codespell ignore list for chunked get_weather test strings --- .codespellrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.codespellrc b/.codespellrc index 808a344b4e6f..5b14597698f4 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] -ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS +ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, ather skip = *.json,*.jsonl,*.patch,*.txt From ad7d0d2cdbb0f2914d4858c70997cfe94a4e47ea Mon Sep 17 00:00:00 2001 From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:23:03 +0800 Subject: [PATCH 100/112] Update python/sglang/bench_one_batch.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/sglang/bench_one_batch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index df719b66d667..31244c8851c4 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -319,8 +319,8 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts): if custom_prompts else [ "The capital of France is", - # "The capital of the United Kindom is", - # "Today is a sunny day and I like", + "The capital of the United Kindom is", + "Today is a sunny day and I like", ] ) input_ids = [tokenizer.encode(p) for p in prompts] From 
9b56b56f4c3c2eaa8d1a441cdd32f33e15c711f4 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 2 Apr 2026 10:36:14 +0000 Subject: [PATCH 101/112] Address PR review comments: default layer_types, precompute causal mask, avoid hasattr --- python/sglang/srt/configs/model_config.py | 2 +- python/sglang/srt/models/gemma4_audio.py | 50 ++++++++++++----------- python/sglang/srt/models/gemma4_causal.py | 3 +- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 922667924586..f7512518f33a 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -1456,7 +1456,7 @@ def get_hybrid_layer_ids( i for i in range(num_hidden_layers) if (i + 1) % 4 == 0 ] elif "GptOssForCausalLM" in model_architectures: - layer_types = getattr(hf_text_config, "layer_types", None) + layer_types = getattr(hf_text_config, "layer_types", []) swa_attention_layer_ids = [ i for i, x in enumerate(layer_types) if x == "sliding_attention" ] diff --git a/python/sglang/srt/models/gemma4_audio.py b/python/sglang/srt/models/gemma4_audio.py index 91dc52ae7f75..db825165fe29 100644 --- a/python/sglang/srt/models/gemma4_audio.py +++ b/python/sglang/srt/models/gemma4_audio.py @@ -807,6 +807,30 @@ def __init__( else: self.output_proj = None + # Precompute causal_valid_mask — depends only on static config values. 
+ chunk_size = config.attention_chunk_size + max_future_horizon = config.attention_context_right + max_past_horizon = max(0, config.attention_context_left - 1) + upper_diagonal = max_past_horizon + max_future_horizon + context_size = chunk_size + max_past_horizon + max_future_horizon + + lower_causal_mask = torch.tril( + torch.ones((context_size, chunk_size), dtype=torch.bool), + diagonal=0, + ).T + upper_causal_mask = torch.tril( + torch.ones((chunk_size, context_size), dtype=torch.bool), + diagonal=upper_diagonal, + ) + local_causal_valid_mask = torch.ones( + (chunk_size, context_size), dtype=torch.bool + ) + self.register_buffer( + "causal_valid_mask", + local_causal_valid_mask * lower_causal_mask * upper_causal_mask, + persistent=False, + ) + @property def device(self): return next(self.parameters()).device @@ -828,30 +852,10 @@ def forward( audio_mel, audio_mel_mask ) - with torch.no_grad(): - chunk_size = self.config.attention_chunk_size - max_future_horizon = self.config.attention_context_right - max_past_horizon = max(0, self.config.attention_context_left - 1) - upper_diagonal = max_past_horizon + max_future_horizon - context_size = chunk_size + max_past_horizon + max_future_horizon - - lower_causal_mask = torch.tril( - torch.ones((context_size, chunk_size), dtype=torch.bool), - diagonal=0, - ).T - upper_causal_mask = torch.tril( - torch.ones((chunk_size, context_size), dtype=torch.bool), - diagonal=upper_diagonal, - ) - local_causal_valid_mask = torch.ones( - (chunk_size, context_size), dtype=torch.bool - ) - causal_valid_mask = ( - local_causal_valid_mask * lower_causal_mask * upper_causal_mask - ) - for block in self.conformer: - audio_encodings = block(audio_encodings, current_mask, causal_valid_mask) + audio_encodings = block( + audio_encodings, current_mask, self.causal_valid_mask + ) if self.output_proj is not None: audio_encodings, _ = self.output_proj(audio_encodings) diff --git a/python/sglang/srt/models/gemma4_causal.py 
b/python/sglang/srt/models/gemma4_causal.py index 18a1b1336853..544406119243 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -103,6 +103,7 @@ def __init__( quant_config=None, prefix=add_prefix("proj", prefix), ) + self._fused_scale: Optional[torch.Tensor] = None def fuse_scale(self): """Pre-compute scale * root_size. Call after weights are loaded.""" @@ -111,7 +112,7 @@ def fuse_scale(self): def forward(self, x: torch.Tensor) -> torch.Tensor: """Returns raw router logits [T, E].""" x = self.norm(x) - if not hasattr(self, "_fused_scale"): + if self._fused_scale is None: self.fuse_scale() x = x * self._fused_scale.to(x.dtype) router_logits, _ = self.proj(x) From 862be179f8232d95d146c930a6191a2ffde51fa6 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 2 Apr 2026 11:45:27 +0000 Subject: [PATCH 102/112] Revert test/manual/test_vlm_accuracy.py to main (test model not ready) --- test/manual/test_vlm_accuracy.py | 397 +------------------------------ 1 file changed, 2 insertions(+), 395 deletions(-) diff --git a/test/manual/test_vlm_accuracy.py b/test/manual/test_vlm_accuracy.py index e35cb17233b6..6e26c012a7eb 100644 --- a/test/manual/test_vlm_accuracy.py +++ b/test/manual/test_vlm_accuracy.py @@ -1,16 +1,12 @@ -"""Multimodal encoder accuracy tests: compare HF vs SGLang encoder outputs.""" +""" """ -import os -import socket -import tempfile import unittest from typing import List, Optional import numpy as np import torch -import torch.multiprocessing as mp import torch.nn.functional as F -from transformers import AutoConfig, AutoModel, AutoProcessor, AutoTokenizer +from transformers import AutoModel, AutoProcessor, AutoTokenizer from sglang.srt.configs.model_config import ModelConfig from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest @@ -322,392 +318,3 @@ async def test_vlm_embedding_output(self): ) self.compare_outputs(sglang_output, hf_output) - - -# 
--------------------------------------------------------------------------- -# Gemma 4 encoder accuracy: vision tower + audio tower vs HF reference -# --------------------------------------------------------------------------- - - -def _make_patchified_vision_inputs( - device: torch.device, - dtype: torch.dtype = torch.bfloat16, - side_patches: int = 48, - patch_size: int = 16, -) -> tuple: - """Create synthetic patchified vision inputs matching the HF image-processor format. - - Returns (pixel_values, pixel_position_ids) with no padding. - """ - num_patches = side_patches * side_patches - patch_pixels = 3 * patch_size**2 - pixel_values = torch.randn(1, num_patches, patch_pixels, device=device, dtype=dtype) - ys, xs = torch.meshgrid( - torch.arange(side_patches), torch.arange(side_patches), indexing="ij" - ) - pixel_position_ids = ( - torch.stack([xs.flatten(), ys.flatten()], dim=-1).unsqueeze(0).to(device) - ) - return pixel_values, pixel_position_ids - - -class TestGemma4EncoderAccuracy(unittest.TestCase): - """Compare Gemma 4 vision and audio encoder outputs between HF and SGLang. - - For each encoder we compare: - 1. Raw tower output (before the multimodal embedder projection). - 2. Projected output (tower + ``embed_vision`` / ``embed_audio``). - - Inputs are random tensors so that the test is self-contained and does not - depend on image / audio files. 
- """ - - MODEL_PATH = "gg-hf-gg/gemma-4-e4b-it" - COSINE_THRESHOLD = 0.98 - - @classmethod - def setUpClass(cls): - cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # -- HF model: extract encoder components, discard the rest ----------- - from transformers import ( - Gemma4ForConditionalGeneration as HFGemma4ForConditionalGeneration, - ) - - hf_full = HFGemma4ForConditionalGeneration.from_pretrained( - cls.MODEL_PATH, torch_dtype=torch.bfloat16 - ) - - cls.hf_vision_tower = hf_full.model.vision_tower.eval().to(cls.device) - cls.hf_embed_vision = hf_full.model.embed_vision.eval().to(cls.device) - - cls.hf_audio_tower = None - cls.hf_embed_audio = None - cls.mel_bins = None - if hf_full.model.audio_tower is not None: - cls.hf_audio_tower = hf_full.model.audio_tower.eval().to(cls.device) - cls.hf_embed_audio = hf_full.model.embed_audio.eval().to(cls.device) - config = AutoConfig.from_pretrained(cls.MODEL_PATH) - cls.mel_bins = 128 - - del hf_full - torch.cuda.empty_cache() - - # -- SGLang model via ModelRunner ------------------------------------- - cls.model_runner = ModelRunner( - model_config=ModelConfig(cls.MODEL_PATH, model_override_args="{}"), - mem_fraction_static=0.8, - gpu_id=0, - tp_rank=0, - tp_size=1, - moe_ep_rank=0, - moe_ep_size=1, - pp_rank=0, - pp_size=1, - nccl_port=12435, - server_args=ServerArgs( - model_path=cls.MODEL_PATH, - disable_cuda_graph=True, - mm_attention_backend="sdpa", - ), - ) - cls.sg_model = cls.model_runner.model - - # -- helpers -------------------------------------------------------------- - - @staticmethod - def _cosine_stats(a: torch.Tensor, b: torch.Tensor): - cos = F.cosine_similarity(a.float(), b.float()) - return cos.mean().item(), cos.min().item() - - def _assert_cosine_close(self, hf: torch.Tensor, sg: torch.Tensor, label: str): - mean_cos, min_cos = self._cosine_stats(hf, sg) - print(f" {label}: mean_cos={mean_cos:.6f} min_cos={min_cos:.6f}") - self.assertGreater( - min_cos, - 
self.COSINE_THRESHOLD, - f"{label} min cosine {min_cos:.6f} < {self.COSINE_THRESHOLD}", - ) - - # -- vision --------------------------------------------------------------- - - def test_vision_encoder(self): - """Vision tower + embed_vision should match HF on patchified pixels.""" - pixel_values, pixel_position_ids = _make_patchified_vision_inputs(self.device) - - with torch.no_grad(): - # HF: last_hidden_state contains only valid (non-padding) tokens - hf_out = self.hf_vision_tower(pixel_values, pixel_position_ids) - hf_tokens = hf_out.last_hidden_state - hf_projected = self.hf_embed_vision(hf_tokens.unsqueeze(0)).squeeze(0) - - # SGLang: returns (pooled, pooler_mask) with mask True = valid - sg_pooled, sg_mask = self.sg_model.vision_tower( - pixel_values, pixel_position_ids - ) - sg_tokens = torch.cat([hs[m] for hs, m in zip(sg_pooled, sg_mask)]) - sg_projected = self.sg_model.embed_vision(sg_tokens.unsqueeze(0)).squeeze(0) - - self.assertEqual(hf_tokens.shape, sg_tokens.shape) - print() - self._assert_cosine_close(hf_tokens, sg_tokens, "vision tower") - self._assert_cosine_close(hf_projected, sg_projected, "vision projected") - - # -- audio ---------------------------------------------------------------- - - def test_audio_encoder(self): - """Audio tower + embed_audio should match HF on random mel input.""" - if self.hf_audio_tower is None: - self.skipTest("Model does not have an audio tower") - - num_frames = 200 - audio_mel = torch.randn( - 1, num_frames, self.mel_bins, device=self.device, dtype=torch.bfloat16 - ) - audio_mel_mask = torch.zeros( - 1, num_frames, device=self.device, dtype=torch.bool - ) - - with torch.no_grad(): - # HF: attention_mask convention is True=valid. - # SGLang: audio_mel_mask convention is True=padding. 
- hf_attention_mask = ~audio_mel_mask - hf_out = self.hf_audio_tower(audio_mel, hf_attention_mask) - hf_enc = hf_out.last_hidden_state - hf_output_mask = hf_out.attention_mask # True=valid - hf_valid = hf_enc[hf_output_mask.unsqueeze(-1).expand_as(hf_enc)].reshape( - -1, hf_enc.shape[-1] - ) - hf_projected = self.hf_embed_audio(hf_valid.unsqueeze(0)).squeeze(0) - - # SGLang: returns (encodings, mask) where mask True=padding - sg_enc, sg_mask = self.sg_model.audio_tower(audio_mel, audio_mel_mask) - sg_valid_mask = ~sg_mask - sg_valid = sg_enc[sg_valid_mask.unsqueeze(-1).expand_as(sg_enc)].reshape( - -1, sg_enc.shape[-1] - ) - sg_projected = self.sg_model.embed_audio(sg_valid.unsqueeze(0)).squeeze(0) - - self.assertEqual(hf_valid.shape, sg_valid.shape) - print() - self._assert_cosine_close(hf_valid, sg_valid, "audio tower") - self._assert_cosine_close(hf_projected, sg_projected, "audio projected") - - -# --------------------------------------------------------------------------- -# Gemma 4 encoder accuracy at TP=2: compare SGLang (TP=2) vs HF reference -# --------------------------------------------------------------------------- - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("localhost", 0)) - s.listen(1) - return s.getsockname()[1] - - -def _tp2_encoder_worker( - local_rank: int, - world_size: int, - nccl_port: int, - model_path: str, - mel_bins: int, - num_frames: int, - result_file: str, -): - """Worker spawned by mp.spawn — loads SGLang model with TP and runs encoders.""" - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size)) - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) - - model_runner = ModelRunner( - model_config=ModelConfig(model_path, model_override_args="{}"), - mem_fraction_static=0.5, - gpu_id=local_rank, - tp_rank=local_rank, - tp_size=world_size, - moe_ep_rank=0, - moe_ep_size=1, - pp_rank=0, - pp_size=1, - nccl_port=nccl_port, - 
server_args=ServerArgs( - model_path=model_path, - disable_cuda_graph=True, - mm_attention_backend="sdpa", - mem_fraction_static=0.5, - ), - ) - sg_model = model_runner.model - - # Deterministic input — identical on every rank. - torch.manual_seed(42) - audio_mel = torch.randn( - 1, num_frames, mel_bins, device=device, dtype=torch.bfloat16 - ) - audio_mel_mask = torch.zeros(1, num_frames, device=device, dtype=torch.bool) - pixel_values, pixel_position_ids = _make_patchified_vision_inputs(device) - - with torch.no_grad(): - # Audio - sg_audio_enc, sg_audio_mask = sg_model.audio_tower(audio_mel, audio_mel_mask) - sg_audio_valid_mask = ~sg_audio_mask - sg_audio_valid = sg_audio_enc[ - sg_audio_valid_mask.unsqueeze(-1).expand_as(sg_audio_enc) - ].reshape(-1, sg_audio_enc.shape[-1]) - sg_audio_proj = sg_model.embed_audio(sg_audio_valid.unsqueeze(0)).squeeze(0) - - # Vision - sg_vis_pooled, sg_vis_mask = sg_model.vision_tower( - pixel_values, pixel_position_ids - ) - sg_vis_tokens = torch.cat([hs[m] for hs, m in zip(sg_vis_pooled, sg_vis_mask)]) - sg_vis_proj = sg_model.embed_vision(sg_vis_tokens.unsqueeze(0)).squeeze(0) - - if local_rank == 0: - torch.save( - { - "audio_valid": sg_audio_valid.cpu(), - "audio_projected": sg_audio_proj.cpu(), - "vision_tokens": sg_vis_tokens.cpu(), - "vision_projected": sg_vis_proj.cpu(), - }, - result_file, - ) - - -class TestGemma4EncoderAccuracyTP2(unittest.TestCase): - """Compare Gemma 4 vision + audio encoder outputs at TP=2 against HF. - - Uses ``mp.spawn`` to create 2 workers that jointly load the SGLang model - with tensor parallelism, then compares rank-0 output with the HF reference - computed in the parent process. - """ - - MODEL_PATH = "gg-hf-gg/gemma-4-e4b-it" - # TP=2 all-reduce introduces small bf16 rounding that compounds across - # 12 conformer blocks; 0.98 is the practical floor. 
- COSINE_THRESHOLD = 0.98 - NUM_FRAMES = 200 - - @classmethod - def setUpClass(cls): - if not torch.cuda.is_available() or torch.cuda.device_count() < 2: - raise unittest.SkipTest("Need >= 2 GPUs for TP=2 test") - - cls.device = torch.device("cuda:0") - config = AutoConfig.from_pretrained(cls.MODEL_PATH) - cls.mel_bins = 128 - - # -- HF reference (run on GPU 0, then free) ---------------------------- - from transformers import ( - Gemma4ForConditionalGeneration as HFGemma4ForConditionalGeneration, - ) - - hf_full = HFGemma4ForConditionalGeneration.from_pretrained( - cls.MODEL_PATH, torch_dtype=torch.bfloat16 - ) - hf_audio_tower = hf_full.model.audio_tower.eval().to(cls.device) - hf_embed_audio = hf_full.model.embed_audio.eval().to(cls.device) - hf_vision_tower = hf_full.model.vision_tower.eval().to(cls.device) - hf_embed_vision = hf_full.model.embed_vision.eval().to(cls.device) - del hf_full - torch.cuda.empty_cache() - - torch.manual_seed(42) - audio_mel = torch.randn( - 1, cls.NUM_FRAMES, cls.mel_bins, device=cls.device, dtype=torch.bfloat16 - ) - audio_mel_mask = torch.zeros( - 1, cls.NUM_FRAMES, device=cls.device, dtype=torch.bool - ) - pixel_values, pixel_position_ids = _make_patchified_vision_inputs(cls.device) - - with torch.no_grad(): - # HF attention_mask: True=valid; SGLang audio_mel_mask: True=padding - hf_attention_mask = ~audio_mel_mask - hf_out = hf_audio_tower(audio_mel, hf_attention_mask) - hf_enc = hf_out.last_hidden_state - hf_output_mask = hf_out.attention_mask # True=valid - cls.hf_audio_valid = ( - hf_enc[hf_output_mask.unsqueeze(-1).expand_as(hf_enc)] - .reshape(-1, hf_enc.shape[-1]) - .cpu() - ) - cls.hf_audio_proj = ( - hf_embed_audio(cls.hf_audio_valid.unsqueeze(0).to(cls.device)) - .squeeze(0) - .cpu() - ) - - hf_vis_out = hf_vision_tower(pixel_values, pixel_position_ids) - cls.hf_vis_tokens = hf_vis_out.last_hidden_state.cpu() - cls.hf_vis_proj = ( - hf_embed_vision(cls.hf_vis_tokens.unsqueeze(0).to(cls.device)) - .squeeze(0) - .cpu() - 
) - - del hf_audio_tower, hf_embed_audio, hf_vision_tower, hf_embed_vision - import gc - - gc.collect() - torch.cuda.empty_cache() - - # -- Run SGLang at TP=2 via mp.spawn ----------------------------------- - nccl_port = _find_free_port() - cls._result_file = tempfile.NamedTemporaryFile(suffix=".pt", delete=False) - cls._result_file.close() - - mp.spawn( - _tp2_encoder_worker, - args=( - 2, - nccl_port, - cls.MODEL_PATH, - cls.mel_bins, - cls.NUM_FRAMES, - cls._result_file.name, - ), - nprocs=2, - join=True, - ) - - cls.sg_results = torch.load(cls._result_file.name, weights_only=True) - - @classmethod - def tearDownClass(cls): - if hasattr(cls, "_result_file"): - os.unlink(cls._result_file.name) - - @staticmethod - def _cosine_stats(a: torch.Tensor, b: torch.Tensor): - cos = F.cosine_similarity(a.float(), b.float()) - return cos.mean().item(), cos.min().item() - - def _assert_cosine_close(self, hf: torch.Tensor, sg: torch.Tensor, label: str): - mean_cos, min_cos = self._cosine_stats(hf, sg) - print(f" {label}: mean_cos={mean_cos:.6f} min_cos={min_cos:.6f}") - self.assertGreater( - min_cos, - self.COSINE_THRESHOLD, - f"{label} min cosine {min_cos:.6f} < {self.COSINE_THRESHOLD}", - ) - - def test_audio_encoder_tp2(self): - """Audio tower + embed_audio at TP=2 should match HF reference.""" - sg_valid = self.sg_results["audio_valid"] - sg_proj = self.sg_results["audio_projected"] - self.assertEqual(self.hf_audio_valid.shape, sg_valid.shape) - print() - self._assert_cosine_close(self.hf_audio_valid, sg_valid, "audio tower TP=2") - self._assert_cosine_close(self.hf_audio_proj, sg_proj, "audio projected TP=2") - - def test_vision_encoder_tp2(self): - """Vision tower + embed_vision at TP=2 should match HF reference.""" - sg_tokens = self.sg_results["vision_tokens"] - sg_proj = self.sg_results["vision_projected"] - self.assertEqual(self.hf_vis_tokens.shape, sg_tokens.shape) - print() - self._assert_cosine_close(self.hf_vis_tokens, sg_tokens, "vision tower TP=2") - 
self._assert_cosine_close(self.hf_vis_proj, sg_proj, "vision projected TP=2") From d88f7e83bf1d6415287c28f6988b632f01dfb8c0 Mon Sep 17 00:00:00 2001 From: adarshxs Date: Thu, 2 Apr 2026 13:56:11 +0000 Subject: [PATCH 103/112] fix reasoning parser --- python/sglang/srt/entrypoints/openai/serving_chat.py | 7 ++++++- python/sglang/srt/parser/reasoning_parser.py | 8 ++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 9d003599b365..23beb398b205 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -1275,12 +1275,17 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: if not self.reasoning_parser: return False - if self.reasoning_parser in ["deepseek-v3", "gemma4"]: + if self.reasoning_parser == "deepseek-v3": # Models that require explicit enable thinking (thinking=True) return ( request.chat_template_kwargs is not None and request.chat_template_kwargs.get("thinking") is True ) + if self.reasoning_parser == "gemma4": + return request.chat_template_kwargs is not None and ( + request.chat_template_kwargs.get("thinking") is True + or request.chat_template_kwargs.get("enable_thinking") is True + ) if self.reasoning_parser in ["kimi_k2"]: # Models that thinking by default, and can be disabled by setting thinking=False return ( diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index 8a856447f23c..8811c90b2ddc 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -114,8 +114,10 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult: self._buffer += new_text current_text = self._buffer + think_start_text = self.think_start_token + self.think_start_self_label + # If the current text is a prefix of the think 
token, keep buffering - tokens_to_check = [self.think_start_token, self.think_end_token] + tokens_to_check = [think_start_text, self.think_end_token] if self.tool_start_token: tokens_to_check.append(self.tool_start_token) if any( @@ -124,11 +126,9 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult: ): return StreamingParseResult() - think_start_text = self.think_start_token + self.think_start_self_label - # Strip `` token if present if not self.stripped_think_start and think_start_text in current_text: - current_text = current_text.replace(self.think_start_token, "") + current_text = current_text.replace(think_start_text, "", 1) self.stripped_think_start = True self._in_reasoning = True From f09b2fae44d5eb793b2576ac23f5e274a9e21d74 Mon Sep 17 00:00:00 2001 From: adarshxs Date: Thu, 2 Apr 2026 14:27:00 +0000 Subject: [PATCH 104/112] add HIP back to optimized store_cache gate --- python/sglang/srt/mem_cache/memory_pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 06e0801a98f1..df5d54223c22 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -96,7 +96,7 @@ def _set_kv_buffer_impl( same_kv_dim: bool = True, ) -> None: row_bytes = row_dim * store_dtype.itemsize - if _is_cuda and same_kv_dim and can_use_store_cache(row_bytes): + if (_is_cuda or _is_hip) and same_kv_dim and can_use_store_cache(row_bytes): return store_cache( k.view(-1, row_dim), v.view(-1, row_dim), From 39b37ca60a3be1a2eecf104f1d30eb3affde4c33 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 2 Apr 2026 15:23:28 +0000 Subject: [PATCH 105/112] fix: use only enable_thinking for gemma4 reasoning parser to match chat template --- python/sglang/srt/entrypoints/openai/serving_chat.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py 
b/python/sglang/srt/entrypoints/openai/serving_chat.py index 23beb398b205..08a9ecd9ac02 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -1282,9 +1282,9 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: and request.chat_template_kwargs.get("thinking") is True ) if self.reasoning_parser == "gemma4": - return request.chat_template_kwargs is not None and ( - request.chat_template_kwargs.get("thinking") is True - or request.chat_template_kwargs.get("enable_thinking") is True + return ( + request.chat_template_kwargs is not None + and request.chat_template_kwargs.get("enable_thinking") is True ) if self.reasoning_parser in ["kimi_k2"]: # Models that thinking by default, and can be disabled by setting thinking=False From 0df626a40f1ac892e06c84d10074193d131f2f4f Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Sat, 4 Apr 2026 03:32:34 +0000 Subject: [PATCH 106/112] adapt to MultimodalProcessorOutput --- .../sglang/srt/multimodal/processors/gemma4.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py index e97885002342..80bb37061358 100644 --- a/python/sglang/srt/multimodal/processors/gemma4.py +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -20,7 +20,7 @@ from sglang.srt.managers.multimodal_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, ) -from sglang.srt.managers.schedule_batch import Modality +from sglang.srt.managers.schedule_batch import Modality, MultimodalProcessorOutput from sglang.srt.models.gemma4_audio import _SSCP_CONV_STRIDE_SIZES from sglang.srt.models.gemma4_mm import Gemma4ForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens @@ -136,10 +136,10 @@ async def process_mm_data_async( base_output, self.mm_tokens ) - return { - "input_ids": 
input_ids.tolist(), - "mm_items": mm_items, - "im_token_id": self.mm_tokens.image_token_id, - "video_token_id": self.mm_tokens.video_token_id, - "audio_token_id": self.mm_tokens.audio_token_id, - } + return MultimodalProcessorOutput( + input_ids=input_ids.tolist(), + mm_items=mm_items, + im_token_id=self.mm_tokens.image_token_id, + video_token_id=self.mm_tokens.video_token_id, + audio_token_id=self.mm_tokens.audio_token_id, + ) From d9d34c6555b5e7a67ad4144bf57a96a08ec3c358 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Sat, 4 Apr 2026 04:19:04 +0000 Subject: [PATCH 107/112] address comments --- python/sglang/srt/parser/conversation.py | 1 - scripts/playground/reference_hf.py | 1 + .../unit/parser/test_reasoning_parser.py | 111 ++++++++++++++++++ 3 files changed, 112 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/parser/conversation.py b/python/sglang/srt/parser/conversation.py index 092b1bbd93cd..954cb168ba34 100644 --- a/python/sglang/srt/parser/conversation.py +++ b/python/sglang/srt/parser/conversation.py @@ -65,7 +65,6 @@ class SeparatorStyle(IntEnum): QWEN2_VL_EMBED = auto() QWEN2_AUDIO = auto() GEMMA3 = auto() - GEMMA4 = auto() MPT = auto() PADDLE_OCR = auto() diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py index 48b5762106c2..538c31f7713d 100644 --- a/scripts/playground/reference_hf.py +++ b/scripts/playground/reference_hf.py @@ -167,6 +167,7 @@ def synthetic_tokens(args): prefill_logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[ 0 ][-1] + if i == 0: print("prefill logits", prefill_logits) else: diff --git a/test/registered/unit/parser/test_reasoning_parser.py b/test/registered/unit/parser/test_reasoning_parser.py index 5b9d623d51b7..cdf7d7c9bd38 100644 --- a/test/registered/unit/parser/test_reasoning_parser.py +++ b/test/registered/unit/parser/test_reasoning_parser.py @@ -5,6 +5,7 @@ from sglang.srt.parser.reasoning_parser import ( BaseReasoningFormatDetector, DeepSeekR1Detector, + 
Gemma4Detector, Glm45Detector, KimiDetector, KimiK2Detector, @@ -586,6 +587,84 @@ def test_force_nonempty_content_no_thinking_tokens(self): self.assertEqual(result.reasoning_text, "") +class TestGemma4Detector(CustomTestCase): + def setUp(self): + self.detector = Gemma4Detector() + + def test_init(self): + """Test Gemma4Detector initialization.""" + self.assertEqual(self.detector.think_start_token, "<|channel>") + self.assertEqual(self.detector.think_end_token, "") + self.assertEqual(self.detector.think_start_self_label, "thought\n") + self.assertFalse(self.detector._in_reasoning) + self.assertTrue(self.detector.stream_reasoning) + + def test_detect_and_parse_complete_reasoning(self): + """Test parsing complete Gemma4 reasoning block (think_start_self_label is stripped).""" + text = "<|channel>thought\nLet me think about thisThe answer is 42." + result = self.detector.detect_and_parse(text) + self.assertEqual(result.reasoning_text, "Let me think about this") + self.assertEqual(result.normal_text, "The answer is 42.") + + def test_detect_and_parse_without_thinking(self): + """Test parsing without thinking (enable_thinking=False case).""" + text = "Direct answer without thinking." + result = self.detector.detect_and_parse(text) + self.assertEqual(result.normal_text, text) + self.assertEqual(result.reasoning_text, "") + + def test_detect_and_parse_reasoning_only(self): + """Test parsing when output is all reasoning (no end token yet).""" + text = "<|channel>thought\nStill thinking..." 
+ result = self.detector.detect_and_parse(text) + self.assertEqual(result.reasoning_text, "Still thinking...") + self.assertEqual(result.normal_text, "") + + def test_streaming_complete_flow(self): + """Test streaming parse of Gemma4 reasoning flow.""" + chunks = [ + "<|channel>", + "thought\nreasoning content", + "", + "final answer", + ] + all_reasoning = "" + all_normal = "" + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk) + all_reasoning += result.reasoning_text + all_normal += result.normal_text + self.assertIn("reasoning content", all_reasoning) + self.assertIn("final answer", all_normal) + + def test_streaming_full_start_sequence(self): + """Test streaming with the full start sequence (token + self_label).""" + # Gemma4 start sequence is "<|channel>thought\n", not just "<|channel>" + result = self.detector.parse_streaming_increment("<|channel>thought\n") + self.assertEqual(result.normal_text, "") + self.assertEqual(result.reasoning_text, "") + self.assertTrue(self.detector._in_reasoning) + + result = self.detector.parse_streaming_increment("reasoning content") + self.assertEqual(result.reasoning_text, "reasoning content") + self.assertEqual(result.normal_text, "") + + def test_streaming_partial_start_buffered(self): + """Test that partial start sequence is buffered.""" + # "<|channel>" alone is a prefix of "<|channel>thought\n", so it's buffered + result = self.detector.parse_streaming_increment("<|channel>") + self.assertEqual(result.normal_text, "") + self.assertEqual(result.reasoning_text, "") + + def test_force_reasoning(self): + """Test Gemma4Detector with force_reasoning=True.""" + detector = Gemma4Detector(force_reasoning=True) + text = "This should be reasoningThe answer." 
+ result = detector.detect_and_parse(text) + self.assertEqual(result.reasoning_text, "This should be reasoning") + self.assertEqual(result.normal_text, "The answer.") + + class TestReasoningParser(CustomTestCase): def test_init_valid_model(self): """Test initialization with valid model types.""" @@ -604,6 +683,9 @@ def test_init_valid_model(self): parser = ReasoningParser("glm45") self.assertIsInstance(parser.detector, Glm45Detector) + parser = ReasoningParser("gemma4") + self.assertIsInstance(parser.detector, Gemma4Detector) + def test_init_invalid_model(self): """Test initialization with invalid model type.""" with self.assertRaises(ValueError) as context: @@ -782,6 +864,35 @@ def test_kimi_streaming_scenario(self): self.assertIn("multiple factors", all_reasoning) self.assertIn("42", all_normal) + def test_gemma4_complete_response(self): + """Test complete Gemma4 response parsing (think_start_self_label stripped).""" + parser = ReasoningParser("gemma4") + text = "<|channel>thought\nI need to solve x + 2 = 5. Subtracting 2: x = 3.The answer is x = 3." 
+ reasoning, normal = parser.parse_non_stream(text) + self.assertIn("x = 3", reasoning) + self.assertNotIn("thought\n", reasoning) + self.assertEqual(normal, "The answer is x = 3.") + + def test_gemma4_streaming_scenario(self): + """Test Gemma4 streaming scenario.""" + parser = ReasoningParser("gemma4") + chunks = [ + "<|channel>", + "thought\nLet me analyze.", + " Multiple factors.", + "", + "The solution is 42.", + ] + all_reasoning = "" + all_normal = "" + for chunk in chunks: + reasoning, normal = parser.parse_stream_chunk(chunk) + all_reasoning += reasoning + all_normal += normal + self.assertIn("analyze", all_reasoning) + self.assertIn("Multiple factors", all_reasoning) + self.assertIn("42", all_normal) + def test_empty_reasoning_blocks(self): """Test handling of empty reasoning blocks.""" parser = ReasoningParser("qwen3") From 71105ab9d7bd41773ebf28883f3812009f3a267d Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Sat, 4 Apr 2026 04:28:01 +0000 Subject: [PATCH 108/112] bring Gemma 4 function call parser test to unit/function_call/ --- .../test_function_call_parser.py | 3931 ----------------- .../test_function_call_parser.py | 108 +- 2 files changed, 107 insertions(+), 3932 deletions(-) delete mode 100644 test/registered/function_call/test_function_call_parser.py diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py deleted file mode 100644 index 056d9fcad42d..000000000000 --- a/test/registered/function_call/test_function_call_parser.py +++ /dev/null @@ -1,3931 +0,0 @@ -import json -import unittest - -from sglang.srt.entrypoints.openai.protocol import Function, Tool -from sglang.srt.function_call.base_format_detector import BaseFormatDetector -from sglang.srt.function_call.core_types import StreamingParseResult -from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector -from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector -from 
sglang.srt.function_call.gemma4_detector import ( - Gemma4Detector, - _parse_gemma4_args, - _parse_gemma4_array, - _parse_gemma4_value, -) -from sglang.srt.function_call.gigachat3_detector import GigaChat3Detector -from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector -from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector -from sglang.srt.function_call.json_array_parser import JsonArrayParser -from sglang.srt.function_call.kimik2_detector import KimiK2Detector -from sglang.srt.function_call.lfm2_detector import Lfm2Detector -from sglang.srt.function_call.llama32_detector import Llama32Detector -from sglang.srt.function_call.mistral_detector import MistralDetector -from sglang.srt.function_call.pythonic_detector import PythonicDetector -from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector -from sglang.test.ci.ci_register import register_cpu_ci - -register_cpu_ci(1.0, "default") - - -class TestPythonicDetector(unittest.TestCase): - def setUp(self): - # Create sample tools for testing - self.tools = [ - Tool( - type="function", - function=Function( - name="get_weather", - description="Get weather information", - parameters={ - "properties": { - "location": { - "type": "string", - "description": "Location to get weather for", - }, - "unit": { - "type": "string", - "description": "Temperature unit", - "enum": ["celsius", "fahrenheit"], - }, - }, - "required": ["location"], - }, - ), - ), - Tool( - type="function", - function=Function( - name="search", - description="Search for information", - parameters={ - "properties": { - "query": { - "type": "string", - "description": "Search query", - }, - }, - "required": ["query"], - }, - ), - ), - ] - self.detector = PythonicDetector() - - def test_parse_streaming_no_brackets(self): - """Test parsing text with no brackets (no tool calls).""" - text = "This is just normal text without any tool calls." 
- result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(result.normal_text, text) - self.assertEqual(result.calls, []) - self.assertEqual(self.detector._buffer, "") # Buffer should be cleared - - def test_parse_streaming_complete_tool_call(self): - """Test parsing a complete tool call.""" - text = "Here's a tool call: [get_weather(location='New York', unit='celsius')]" - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(result.normal_text, "Here's a tool call: ") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual( - self.detector._buffer, "" - ) # Buffer should be cleared after processing - - # Check the parameters - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["location"], "New York") - self.assertEqual(params["unit"], "celsius") - - def test_parse_streaming_text_before_tool_call(self): - """Test parsing text that appears before a tool call.""" - text = "This is some text before [get_weather(location='London')]" - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(result.normal_text, "This is some text before ") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - - # Check the parameters - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["location"], "London") - - def test_parse_streaming_partial_tool_call(self): - """Test parsing a partial tool call that spans multiple chunks.""" - # First chunk with opening bracket but no closing bracket - text1 = "Let me check the weather: [get_weather(location=" - result1 = self.detector.parse_streaming_increment(text1, self.tools) - - self.assertEqual(result1.normal_text, "Let me check the weather: ") - self.assertEqual(result1.calls, []) - self.assertEqual( - self.detector._buffer, "[get_weather(location=" - ) # Partial tool call remains in buffer - - # 
Second chunk completing the tool call - text2 = "'Paris')]" - result2 = self.detector.parse_streaming_increment(text2, self.tools) - - self.assertEqual(result2.normal_text, "") - self.assertEqual(len(result2.calls), 1) - self.assertEqual(result2.calls[0].name, "get_weather") - - # Check the parameters - params = json.loads(result2.calls[0].parameters) - self.assertEqual(params["location"], "Paris") - self.assertEqual( - self.detector._buffer, "" - ) # Buffer should be cleared after processing - - def test_parse_streaming_bracket_without_text_before(self): - """Test parsing a tool call that starts at the beginning of the text.""" - text = "[search(query='python programming')]" - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(result.normal_text, "") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "search") - - # Check the parameters - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["query"], "python programming") - - def test_parse_streaming_text_after_tool_call(self): - """Test parsing text that appears after a tool call.""" - # First chunk with complete tool call and some text after - text = "[get_weather(location='Tokyo')] Here's the forecast:" - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(result.normal_text, "") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual( - self.detector._buffer, " Here's the forecast:" - ) # Text after tool call remains in buffer - - # Process the remaining text in buffer - result2 = self.detector.parse_streaming_increment("", self.tools) - self.assertEqual(result2.normal_text, " Here's the forecast:") - self.assertEqual(result2.calls, []) - self.assertEqual(self.detector._buffer, "") # Buffer should be cleared - - def test_parse_streaming_multiple_tool_calls(self): - """Test parsing multiple tool calls in sequence.""" - text = 
"[get_weather(location='Berlin')] and [search(query='restaurants')]" - - # First tool call - result1 = self.detector.parse_streaming_increment(text, self.tools) - self.assertEqual(len(result1.calls), 1) - self.assertEqual(result1.calls[0].name, "get_weather") - self.assertEqual(self.detector._buffer, " and [search(query='restaurants')]") - - # Second tool call - result2 = self.detector.parse_streaming_increment("", self.tools) - self.assertEqual(result2.normal_text, " and ") - self.assertEqual(len(result2.calls), 1) - self.assertEqual(result2.calls[0].name, "search") - self.assertEqual(self.detector._buffer, "") - - def test_parse_streaming_opening_bracket_only(self): - """Test parsing text with only an opening bracket but no closing bracket.""" - text = "Let's try this: [" - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(result.normal_text, "Let's try this: ") - self.assertEqual(result.calls, []) - self.assertEqual( - self.detector._buffer, "[" - ) # Opening bracket remains in buffer - - def test_parse_streaming_nested_brackets(self): - """Test parsing tool calls with nested brackets in arguments.""" - # Test with list argument containing nested brackets - text = "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]" - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(result.normal_text, "") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual(self.detector._buffer, "") - - # Check the parameters - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["location"], "New York") - self.assertEqual(params["unit"], "celsius") - self.assertEqual(params["data"], [1, 2, 3]) - - def test_parse_streaming_nested_brackets_dict(self): - """Test parsing tool calls with nested dictionaries and lists.""" - # Test with nested dict and list arguments - text = "[search(query='test', config={'options': [1, 2], 
'nested': {'key': 'value'}})]" - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(result.normal_text, "") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "search") - self.assertEqual(self.detector._buffer, "") - - # Check the parameters - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["query"], "test") - self.assertEqual(params["config"]["options"], [1, 2]) - self.assertEqual(params["config"]["nested"]["key"], "value") - - def test_parse_streaming_multiple_tools_with_nested_brackets(self): - """Test parsing multiple tool calls with nested brackets.""" - text = "[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]" - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(result.normal_text, "") - self.assertEqual(len(result.calls), 2) - self.assertEqual(self.detector._buffer, "") - - # Check first tool call - params1 = json.loads(result.calls[0].parameters) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual(params1["location"], "Paris") - self.assertEqual(params1["data"], [10, 20]) - - # Check second tool call - params2 = json.loads(result.calls[1].parameters) - self.assertEqual(result.calls[1].name, "search") - self.assertEqual(params2["query"], "test") - self.assertEqual(params2["filters"], ["a", "b"]) - - def test_parse_streaming_partial_nested_brackets(self): - """Test parsing partial tool calls with nested brackets across chunks.""" - # First chunk with nested brackets but incomplete - text1 = "Here's a call: [get_weather(location='Tokyo', data=[1, 2" - result1 = self.detector.parse_streaming_increment(text1, self.tools) - - self.assertEqual(result1.normal_text, "Here's a call: ") - self.assertEqual(result1.calls, []) - self.assertEqual( - self.detector._buffer, "[get_weather(location='Tokyo', data=[1, 2" - ) - - # Second chunk completing the nested brackets - text2 = ", 
3])]" - result2 = self.detector.parse_streaming_increment(text2, self.tools) - - self.assertEqual(result2.normal_text, "") - self.assertEqual(len(result2.calls), 1) - self.assertEqual(result2.calls[0].name, "get_weather") - self.assertEqual(self.detector._buffer, "") - - # Check the parameters - params = json.loads(result2.calls[0].parameters) - self.assertEqual(params["location"], "Tokyo") - self.assertEqual(params["data"], [1, 2, 3]) - - def test_parse_streaming_with_python_start_and_end_token(self): - """Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks.""" - chunks = [ - "Here's a call: ", - "<|python_", - "start|>[get_weather(location=", - "'Tokyo', data=[1, 2", - ", 3])]<|python_end|>", - ] - - normal_text = "" - call_name = "" - parameters = "" - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - if result.normal_text: - normal_text += result.normal_text - if result.calls: - call_name += result.calls[0].name - parameters += result.calls[0].parameters - - self.assertEqual(normal_text, "Here's a call: ") - self.assertEqual(call_name, "get_weather") - self.assertEqual(self.detector._buffer, "") - self.assertEqual( - result.normal_text, "", "Final result should have no normal text" - ) - - # Check the parameters - params = json.loads(parameters) - self.assertEqual(params["location"], "Tokyo") - self.assertEqual(params["data"], [1, 2, 3]) - - chunks = [ - "Here's a call: <|python_start|>[get_weather(location='Tokyo', data=[1, 2, 3])]<|python_end|>" - ] - - normal_text = "" - call_name = "" - parameters = "" - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - if result.normal_text: - normal_text += result.normal_text - if result.calls: - call_name += result.calls[0].name - parameters += result.calls[0].parameters - - self.assertEqual(normal_text, "Here's a call: ") - self.assertEqual(call_name, "get_weather") - 
self.assertEqual(self.detector._buffer, "") - - # Check the parameters - params = json.loads(parameters) - self.assertEqual(params["location"], "Tokyo") - self.assertEqual(params["data"], [1, 2, 3]) - - def test_detect_and_parse_with_python_start_and_end_token(self): - """Test parsing a message that starts with <|python_start|> and contains a valid tool call.""" - text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars." - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual( - result.normal_text, - "User wants to get the weather in Mars. In this way we will get the weather in Mars.", - ) - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual(self.detector._buffer, "") - - # Check the parameters - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["location"], "Mars") - self.assertEqual(params["unit"], "celsius") - - -class TestMistralDetector(unittest.TestCase): - def setUp(self): - """Set up test tools and detector for Mistral format testing.""" - self.tools = [ - Tool( - type="function", - function=Function( - name="make_next_step_decision", - description="Test function for decision making", - parameters={ - "type": "object", - "properties": { - "decision": { - "type": "string", - "description": "The next step to take", - }, - "content": { - "type": "string", - "description": "The content of the next step", - }, - }, - "required": ["decision", "content"], - }, - ), - ), - ] - self.detector = MistralDetector() - - def test_detect_and_parse_with_nested_brackets_in_content(self): - """Test parsing Mistral format with nested brackets in JSON content. - - This test case specifically addresses the issue where the regex pattern - was incorrectly truncating JSON when it contained nested brackets like [City Name]. 
- """ - # This is the exact problematic text from the original test failure - test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"","content":"```\\nTOOL: Access a weather API or service\\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\\n```"}}]' - - result = self.detector.detect_and_parse(test_text, self.tools) - - # Verify that the parsing was successful - self.assertEqual(len(result.calls), 1, "Should detect exactly one tool call") - - call = result.calls[0] - self.assertEqual( - call.name, - "make_next_step_decision", - "Should detect the correct function name", - ) - - # Verify that the parameters are valid JSON and contain the expected content - params = json.loads(call.parameters) - self.assertEqual( - params["decision"], "", "Decision parameter should be empty string" - ) - - # The content should contain the full text including the nested brackets [City Name] - expected_content = "```\nTOOL: Access a weather API or service\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\n```" - self.assertEqual( - params["content"], - expected_content, - "Content should include nested brackets without truncation", - ) - - # Verify that normal text is empty (since the entire input is a tool call) - self.assertEqual( - result.normal_text, "", "Normal text should be empty for pure tool call" - ) - - def test_detect_and_parse_simple_case(self): - """Test parsing a simple Mistral format tool call without nested brackets.""" - test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"TOOL", "content":"Use weather API"}}]' - - result = self.detector.detect_and_parse(test_text, 
self.tools) - - self.assertEqual(len(result.calls), 1) - call = result.calls[0] - self.assertEqual(call.name, "make_next_step_decision") - - params = json.loads(call.parameters) - self.assertEqual(params["decision"], "TOOL") - self.assertEqual(params["content"], "Use weather API") - - def test_detect_and_parse_no_tool_calls(self): - """Test parsing text without any tool calls.""" - test_text = "This is just normal text without any tool calls." - - result = self.detector.detect_and_parse(test_text, self.tools) - - self.assertEqual(len(result.calls), 0, "Should detect no tool calls") - self.assertEqual( - result.normal_text, - test_text, - "Should return the original text as normal text", - ) - - def test_detect_and_parse_with_text_before_tool_call(self): - """Test parsing text that has content before the tool call.""" - test_text = 'Here is some text before the tool call: [TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"ANSWER", "content":"The answer is 42"}}]' - - result = self.detector.detect_and_parse(test_text, self.tools) - - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.normal_text, "Here is some text before the tool call:") - - call = result.calls[0] - self.assertEqual(call.name, "make_next_step_decision") - - params = json.loads(call.parameters) - self.assertEqual(params["decision"], "ANSWER") - self.assertEqual(params["content"], "The answer is 42") - - def test_detect_and_parse_compact_args_format(self): - """Test parsing compact format: [TOOL_CALLS]name[ARGS]{...}.""" - test_text = '[TOOL_CALLS]make_next_step_decision[ARGS]{"decision":"TOOL", "content":"Use weather API"}' - - result = self.detector.detect_and_parse(test_text, self.tools) - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "make_next_step_decision") - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["decision"], "TOOL") - self.assertEqual(params["content"], "Use weather API") - - def 
test_streaming_compact_args_format_emits_tool_calls(self): - """Test streaming chunks for compact format produce tool_calls items.""" - chunks = [ - "[TOOL_CALLS]make_next_step_decision[ARGS]", - '{"decision":"TOOL", ', - '"content":"Use weather API"}', - ] - - emitted = [] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - if result.calls: - emitted.extend(result.calls) - - # Expect two items: name chunk + full args chunk - self.assertEqual(len(emitted), 2) - self.assertEqual(emitted[0].name, "make_next_step_decision") - self.assertEqual(emitted[0].parameters, "") - self.assertIsNone(emitted[1].name) - params = json.loads(emitted[1].parameters) - self.assertEqual(params["decision"], "TOOL") - self.assertEqual(params["content"], "Use weather API") - - -class TestBaseFormatDetector(unittest.TestCase): - """Test buffer management and sequential tool index assignment in BaseFormatDetector.""" - - def setUp(self): - """Set up test detector and tools.""" - - # Create a concrete implementation of BaseFormatDetector for testing - class TestFormatDetector(BaseFormatDetector): - def __init__(self): - super().__init__() - self.bot_token = "" - self.eot_token = "" - - def detect_and_parse(self, text, tools): - # Not used in streaming tests - pass - - def has_tool_call(self, text): - return "" in text - - def structure_info(self): - # Not used in streaming tests - pass - - self.detector = TestFormatDetector() - self.tools = [ - Tool( - type="function", - function=Function( - name="get_weather", - description="Get weather information", - parameters={ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "City name", - } - }, - "required": ["city"], - }, - ), - ), - Tool( - type="function", - function=Function( - name="get_tourist_attractions", - description="Get tourist attractions", - parameters={ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "City name", - } - 
}, - "required": ["city"], - }, - ), - ), - ] - - def test_sequential_tool_index_assignment(self): - """Test that multiple tool calls get sequential tool_index values (0, 1, 2, ...).""" - # Simulate streaming chunks for two consecutive tool calls - chunks = [ - "", - '{"name": "get_weather", ', - '"arguments": {"city": "Paris"}}', - ", ", - '{"name": "get_tourist_attractions", ', - '"arguments": {"city": "London"}}', - "", - ] - - tool_indices_seen = [] - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - - if result.calls: - for call in result.calls: - if call.tool_index is not None: - tool_indices_seen.append(call.tool_index) - - # Verify we got sequential tool indices - unique_indices = sorted(set(tool_indices_seen)) - self.assertEqual( - unique_indices, - [0, 1], - f"Expected sequential tool indices [0, 1], got {unique_indices}", - ) - - def test_buffer_content_preservation(self): - """Test that buffer correctly preserves unprocessed content when tool completes.""" - # Test simpler scenario: tool completion followed by new tool start - chunks = [ - "", - '{"name": "get_weather", ', - '"arguments": {"city": "Paris"}}', - ", ", - '{"name": "get_tourist_attractions", ', - '"arguments": {"city": "London"}} ', - ] - - tool_calls_seen = [] - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - if result.calls: - for call in result.calls: - if ( - call.name - ): # Only count calls with names (not just parameter updates) - tool_calls_seen.append(call.name) - - # Should see both tool names - self.assertIn("get_weather", tool_calls_seen, "Should process first tool") - self.assertIn( - "get_tourist_attractions", tool_calls_seen, "Should process second tool" - ) - - def test_current_tool_id_increment_on_completion(self): - """Test that current_tool_id increments when a tool completes.""" - # Initial state - self.assertEqual( - self.detector.current_tool_id, -1, "Should start with 
current_tool_id=-1" - ) - - # Process first tool completely - chunks = [ - "", - '{"name": "get_weather", ', - ] - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - - self.assertEqual( - self.detector.current_tool_id, 0, "current_tool_id should be 0" - ) - self.assertEqual( - result.calls[0].name, "get_weather", "The first tool should be get_weather" - ) - self.assertEqual( - result.calls[0].tool_index, 0, "The first tool index should be 0" - ) - - # Complete second tool name - this should show that current_tool_id is now 1 - result = self.detector.parse_streaming_increment( - '"arguments": {"city": "Paris"}}, {"name": "get_', self.tools - ) - self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}') - - self.assertEqual( - self.detector.current_tool_id, - 1, - "current_tool_id should be 1 after first tool completes and second tool starts", - ) - - result = self.detector.parse_streaming_increment( - 'tourist_attractions", ', self.tools - ) - - # Second tool should have tool_index=1 - tourist_calls = [ - call for call in result.calls if call.name == "get_tourist_attractions" - ] - self.assertEqual( - tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1" - ) - - def test_tool_name_streaming_with_correct_index(self): - """Test that tool names are streamed with correct tool_index values.""" - # Process first tool - self.detector.parse_streaming_increment("", self.tools) - result1 = self.detector.parse_streaming_increment( - '{"name": "get_weather", ', self.tools - ) - - # First tool name should have tool_index=0 - weather_calls = [call for call in result1.calls if call.name == "get_weather"] - self.assertEqual(len(weather_calls), 1, "Should have one weather call") - self.assertEqual( - weather_calls[0].tool_index, 0, "First tool should have tool_index=0" - ) - - # Complete first tool - self.detector.parse_streaming_increment( - '"arguments": {"city": "Paris"}}', self.tools - ) - - # Start second tool 
- self.detector.parse_streaming_increment(", ", self.tools) - result2 = self.detector.parse_streaming_increment( - '{"name": "get_tourist_attractions", ', self.tools - ) - - # Second tool name should have tool_index=1 - tourist_calls = [ - call for call in result2.calls if call.name == "get_tourist_attractions" - ] - self.assertEqual( - len(tourist_calls), 1, "Should have one tourist attractions call" - ) - self.assertEqual( - tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1" - ) - - def test_buffer_reset_on_invalid_tool(self): - """Test that buffer and state are reset when an invalid tool name is encountered.""" - # Start fresh with an invalid tool name from the beginning - result = self.detector.parse_streaming_increment( - '{"name": "invalid_tool", ', self.tools - ) - - # Should return empty result and reset state - self.assertEqual(result.calls, [], "Should return no calls for invalid tool") - self.assertEqual( - self.detector.current_tool_id, - -1, - "current_tool_id should remain -1 for invalid tool", - ) - self.assertEqual( - self.detector._buffer, "", "Buffer should be cleared for invalid tool" - ) - - def test_chinese_characters_not_double_escaped(self): - """Test that Chinese characters in tool call parameters are not double-escaped.""" - # Test with Chinese city name "杭州" (Hangzhou) - chunks = [ - "", - '{"name": "get_weather", ', - '"arguments": {"city": "杭州"}}', - "", - ] - - accumulated_parameters = {} - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - if result.calls: - for call in result.calls: - if call.parameters: - tool_idx = call.tool_index if call.tool_index is not None else 0 - if tool_idx not in accumulated_parameters: - accumulated_parameters[tool_idx] = "" - accumulated_parameters[tool_idx] += call.parameters - - # Verify that Chinese characters are preserved (not escaped as \uXXXX) - self.assertGreater( - len(accumulated_parameters), 0, "Should have parsed parameters" - ) - 
final_params_str = accumulated_parameters[0] - - # The parameters string should contain the actual Chinese characters, not escaped Unicode - self.assertIn( - "杭州", final_params_str, "Should contain actual Chinese characters" - ) - self.assertNotIn( - "\\u676d", final_params_str, "Should not contain escaped Unicode sequences" - ) - self.assertNotIn( - "\\u5dde", final_params_str, "Should not contain escaped Unicode sequences" - ) - - # Verify the JSON can be parsed and contains the correct value - params = json.loads(final_params_str) - self.assertEqual( - params["city"], "杭州", "Should correctly parse Chinese city name" - ) - - def test_chinese_characters_incremental_streaming(self): - """Test that Chinese characters work correctly with incremental streaming.""" - # Test incremental streaming with Chinese characters - chunks = [ - "", - '{"name": "get_weather", ', - '"arguments": {"city": "', - "杭州", - '"}}', - "", - ] - - accumulated_parameters = {} - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - if result.calls: - for call in result.calls: - if call.parameters: - tool_idx = call.tool_index if call.tool_index is not None else 0 - if tool_idx not in accumulated_parameters: - accumulated_parameters[tool_idx] = "" - accumulated_parameters[tool_idx] += call.parameters - - # Verify Chinese characters are preserved throughout streaming - self.assertGreater( - len(accumulated_parameters), 0, "Should have parsed parameters" - ) - final_params_str = accumulated_parameters[0] - - # Should contain actual Chinese characters, not escaped - self.assertIn( - "杭州", final_params_str, "Should contain actual Chinese characters" - ) - - # Parse and verify - params = json.loads(final_params_str) - self.assertEqual( - params["city"], "杭州", "Should correctly parse Chinese city name" - ) - - def test_multiple_chinese_parameters(self): - """Test multiple tool calls with Chinese parameters.""" - # Test with multiple tool calls containing 
Chinese characters - chunks = [ - "", - '{"name": "get_weather", "arguments": {"city": "北京"}}, ', - '{"name": "get_tourist_attractions", "arguments": {"city": "上海"}}', - "", - ] - - accumulated_parameters = {} - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - if result.calls: - for call in result.calls: - if call.parameters: - tool_idx = call.tool_index if call.tool_index is not None else 0 - if tool_idx not in accumulated_parameters: - accumulated_parameters[tool_idx] = "" - accumulated_parameters[tool_idx] += call.parameters - - # Verify both tool calls have correct Chinese characters - self.assertGreaterEqual( - len(accumulated_parameters), 1, "Should have parsed parameters" - ) - - # Check first tool call (北京 - Beijing) - if 0 in accumulated_parameters: - params0 = json.loads(accumulated_parameters[0]) - self.assertIn( - "北京", - accumulated_parameters[0], - "Should contain actual Chinese characters", - ) - self.assertEqual( - params0["city"], "北京", "Should correctly parse first Chinese city" - ) - - # Check second tool call (上海 - Shanghai) if present - if 1 in accumulated_parameters: - params1 = json.loads(accumulated_parameters[1]) - self.assertIn( - "上海", - accumulated_parameters[1], - "Should contain actual Chinese characters", - ) - self.assertEqual( - params1["city"], "上海", "Should correctly parse second Chinese city" - ) - - -class TestLlama32Detector(unittest.TestCase): - def setUp(self): - """Set up test tools and detector for Mistral format testing.""" - self.tools = [ - Tool( - type="function", - function=Function( - name="get_weather", - description="Get weather information", - parameters={ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "City name", - } - }, - "required": ["city"], - }, - ), - ), - Tool( - type="function", - function=Function( - name="get_tourist_attractions", - description="Get tourist attractions", - parameters={ - "type": "object", - "properties": 
{ - "city": { - "type": "string", - "description": "City name", - } - }, - "required": ["city"], - }, - ), - ), - ] - self.detector = Llama32Detector() - - def test_single_json(self): - text = '{"name": "get_weather", "parameters": {"city": "Paris"}}' - result = self.detector.detect_and_parse(text, self.tools) - assert len(result.calls) == 1 - assert result.calls[0].name == "get_weather" - assert result.normal_text == "" - - def test_multiple_json_with_separator(self): - text = ( - '<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};' - '{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}' - ) - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 2) - self.assertEqual(result.calls[1].name, "get_tourist_attractions") - self.assertEqual(result.normal_text, "") - - def test_multiple_json_with_separator_customized(self): - text = ( - '<|python_tag|>{"name": "get_weather", "parameters": {}}' - '<|python_tag|>{"name": "get_tourist_attractions", "parameters": {}}' - ) - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 2) - self.assertEqual(result.calls[1].name, "get_tourist_attractions") - self.assertEqual(result.normal_text, "") - - def test_json_with_trailing_text(self): - text = '{"name": "get_weather", "parameters": {}} Some follow-up text' - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 1) - self.assertIn("follow-up", result.normal_text) - - def test_invalid_then_valid_json(self): - text = ( - '{"name": "get_weather", "parameters": {' # malformed - '{"name": "get_weather", "parameters": {}}' - ) - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - - def test_plain_text_only(self): - text = "This is just plain explanation text." 
- result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(result.calls, []) - self.assertEqual(result.normal_text, text) - - def test_with_python_tag_prefix(self): - text = 'Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}' - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 1) - self.assertTrue(result.normal_text.strip().startswith("Some intro.")) - - -class TestKimiK2Detector(unittest.TestCase): - - def setUp(self): - """Set up test tools and detector.""" - self.tools = [ - Tool( - type="function", - function=Function( - name="get_weather", - description="Get weather information", - parameters={ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "City name", - } - }, - "required": ["city"], - }, - ), - ), - Tool( - type="function", - function=Function( - name="get_tourist_attractions", - description="Get tourist attractions", - parameters={ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "City name", - } - }, - "required": ["city"], - }, - ), - ), - ] - self.detector = KimiK2Detector() - - def test_single_tool_call(self): - """Test parsing a single tool call in a complete text.""" - text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>' - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}') - self.assertEqual(result.normal_text, "") - - def test_multiple_tool_calls(self): - """Test parsing multiple tool calls in a complete text.""" - text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": 
"Paris"}<|tool_call_end|><|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{"city": "London"}<|tool_call_end|><|tool_calls_section_end|>' - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 2) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}') - self.assertEqual(result.calls[1].name, "get_tourist_attractions") - self.assertEqual(result.calls[1].parameters, '{"city": "London"}') - self.assertEqual(result.normal_text, "") - - def test_streaming_tool_call(self): - """Test streaming incremental parsing of a tool call.""" - chunks = [ - "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", - '"city": "Paris"', - "}", - "<|tool_call_end|><|tool_calls_section_end|>", - ] - - tool_calls = [] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - for tool_call_chunk in result.calls: - if tool_call_chunk.tool_index is not None: - - while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - - tc = tool_calls[tool_call_chunk.tool_index] - - if tool_call_chunk.name: - tc["name"] += tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters - - self.assertEqual(len(tool_calls), 1) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}') - - def test_streaming_multiple_tool_calls(self): - """Test streaming incremental parsing of multiple tool calls.""" - chunks = [ - "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", - '"city": "Paris"', - "}<|tool_call_end|>", - "<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{", - '"city": "London"', - "}<|tool_call_end|>", - "<|tool_calls_section_end|>", - ] - - 
tool_calls = [] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - for tool_call_chunk in result.calls: - if tool_call_chunk.tool_index is not None: - - while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - - tc = tool_calls[tool_call_chunk.tool_index] - - if tool_call_chunk.name: - tc["name"] += tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters - - self.assertEqual(len(tool_calls), 2) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}') - self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions") - self.assertEqual(tool_calls[1]["parameters"], '{"city": "London"}') - - def test_tool_call_completion(self): - """Test that the buffer and state are reset after a tool call is completed.""" - chunks = [ - "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", - '"city": "Paris"', - "}", - "<|tool_call_end|>", - "<|tool_calls_section_end|>", - ] - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - - # After processing all chunks, the buffer should be empty and current_tool_id should be reset - self.assertEqual(self.detector._buffer, "") - self.assertEqual(self.detector.current_tool_id, 1) - - def test_tool_name_streaming(self): - """Test that tool names are streamed correctly with the right index.""" - chunks = [ - "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", - '"city": "Paris"', - "}", - "<|tool_call_end|>", - "<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{", - ] - - tool_calls = [] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - for tool_call_chunk in result.calls: - if tool_call_chunk.tool_index is not None: - - 
while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - - tc = tool_calls[tool_call_chunk.tool_index] - - if tool_call_chunk.name: - tc["name"] += tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters - - self.assertEqual(len(tool_calls), 2) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}') - self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions") - - def test_invalid_tool_call(self): - """Test that invalid tool calls are handled correctly.""" - text = 'invalid_tool:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>' - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 0) - self.assertEqual(result.normal_text, text) - - def test_partial_tool_call(self): - """Test that partial tool calls are handled correctly in streaming mode.""" - chunks = [ - "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{", - '"city": "Paris"', - ] - - tool_calls = [] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - for tool_call_chunk in result.calls: - if tool_call_chunk.tool_index is not None: - - while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - - tc = tool_calls[tool_call_chunk.tool_index] - - if tool_call_chunk.name: - tc["name"] += tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters - - self.assertEqual(len(tool_calls), 1) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"') - - -class TestDeepSeekV3Detector(unittest.TestCase): - def setUp(self): - """Set up test tools and detector for DeepSeekV3 format testing.""" - self.tools = [ - Tool( - 
type="function", - function=Function( - name="get_weather", - description="Get weather information", - parameters={ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "City name", - } - }, - "required": ["city"], - }, - ), - ), - Tool( - type="function", - function=Function( - name="get_tourist_attractions", - description="Get tourist attractions", - parameters={ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "City name", - } - }, - "required": ["city"], - }, - ), - ), - ] - self.detector = DeepSeekV3Detector() - - def test_parse_streaming_multiple_tool_calls_with_multi_token_chunk(self): - """Test parsing multiple tool calls when streaming chunks contains multi-tokens (e.g. DeepSeekV3 enable MTP)""" - # Simulate streaming chunks with multi-tokens for two consecutive tool calls - chunks = [ - "<|tool▁calls▁begin|>", - "<|tool▁call▁begin|>function", - "<|tool▁sep|>get", - "_weather\n", - "```json\n", - '{"city":', - '"Shanghai', - '"}\n```<|tool▁call▁end|>', - "\n<|tool▁call▁begin|>", - "function<|tool▁sep|>", - "get_tour", - "ist_att", - "ractions\n```" 'json\n{"', - 'city": "', - 'Beijing"}\n', - "```<|tool▁call▁end|>", - "<|tool▁calls▁end|>", - ] - - tool_calls_seen = [] - tool_calls_parameters = [] - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - if result.calls: - for call in result.calls: - if call.name: - tool_calls_seen.append(call.name) - if call.parameters: - tool_calls_parameters.append(call.parameters) - - # Should see both tool names - self.assertIn("get_weather", tool_calls_seen, "Should process first tool") - self.assertIn( - "get_tourist_attractions", tool_calls_seen, "Should process second tool" - ) - - # Verify that the parameters are valid JSON and contain the expected content - params1 = json.loads(tool_calls_parameters[0]) - params2 = json.loads(tool_calls_parameters[1]) - self.assertEqual(params1["city"], "Shanghai") - 
self.assertEqual(params2["city"], "Beijing") - - -class TestDeepSeekV32Detector(unittest.TestCase): - def setUp(self): - """Set up test tools and detector for DeepSeekV32 format testing.""" - self.tools = [ - Tool( - type="function", - function=Function( - name="search", - description="Searches for information related to query and displays topn results.", - parameters={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The search query string", - }, - "topn": { - "type": "integer", - "description": "Number of top results to display", - "default": 10, - }, - "source": { - "type": "string", - "description": "Source to search within", - "enum": ["web", "news"], - "default": "web", - }, - }, - "required": ["query"], - }, - ), - ), - Tool( - type="function", - function=Function( - name="get_favorite_tourist_spot", - description="Return the favorite tourist spot for a given city.", - parameters={ - "type": "object", - "properties": {"city": {"type": "string"}}, - "required": ["city"], - }, - ), - ), - ] - self.detector = DeepSeekV32Detector() - from transformers import AutoTokenizer - - self.tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.2") - self.interval = 1 - - def test_detect_and_parse_xml_format(self): - """Test parsing standard XML format (DSML)""" - text = """I'll help you with information about San Francisco and get its favorite tourist spot for you.\n\n - <|DSML|function_calls>\n - <|DSML|invoke name="get_favorite_tourist_spot">\n - <|DSML|parameter name="city" string="true">San Francisco\n - \n - <|DSML|invoke name="search"> - <|DSML|parameter name="query" string="true">WebNav benchmark - <|DSML|parameter name="topn" string="false">10 - <|DSML|parameter name="source" string="true">web - - - """ - result = self.detector.detect_and_parse(text, self.tools) - - self.assertIn("I'll help you with information", result.normal_text) - self.assertEqual(len(result.calls), 2) - - # Check first call - call1 = 
result.calls[0] - self.assertEqual(call1.name, "get_favorite_tourist_spot") - params1 = json.loads(call1.parameters) - self.assertEqual(params1["city"], "San Francisco") - - # Check second call - call2 = result.calls[1] - self.assertEqual(call2.name, "search") - params2 = json.loads(call2.parameters) - self.assertEqual(params2["query"], "WebNav benchmark") - self.assertEqual(params2["topn"], 10) - self.assertEqual(params2["source"], "web") - - def test_detect_and_parse_json_format(self): - """Test parsing JSON format inside invoke tags""" - text = """I'll help you with information about San Francisco and get its favorite tourist spot for you. - - <|DSML|function_calls> - <|DSML|invoke name="get_favorite_tourist_spot"> - { - "city": "San Francisco" - } - - <|DSML|invoke name="search"> - { - "query": "WebNav benchmark", - "topn": 10, - "source": "web" - } - - - """ - result = self.detector.detect_and_parse(text, self.tools) - - self.assertIn("I'll help you with information", result.normal_text) - self.assertEqual(len(result.calls), 2) - - # Check first call - call1 = result.calls[0] - self.assertEqual(call1.name, "get_favorite_tourist_spot") - params1 = json.loads(call1.parameters) - self.assertEqual(params1["city"], "San Francisco") - - # Check second call - call2 = result.calls[1] - self.assertEqual(call2.name, "search") - params2 = json.loads(call2.parameters) - self.assertEqual(params2["query"], "WebNav benchmark") - self.assertEqual(params2["topn"], 10) - self.assertEqual(params2["source"], "web") - - def test_streaming_xml_format(self): - """Test streaming parsing of XML format""" - text = """<|DSML|function_calls> - <|DSML|invoke name="get_favorite_tourist_spot"> - <|DSML|parameter name="city" string="true">San Francisco - <|DSML|parameter name="another_city" string="true">London - <|DSML|parameter name="topn" string="false">10 - <|DSML|parameter name="obj" string="false">{"name": "John", "age": 30} - - """ - - input_ids = self.tokenizer.encode(text, 
add_special_tokens=False) - chunk_ids = [ - input_ids[i : i + self.interval] - for i in range(0, len(input_ids), self.interval) - ] - chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] - - tool_calls_by_index = {} - - num_tool_call_chunks = 0 - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - for call in result.calls: - num_tool_call_chunks += 1 - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - self.assertGreater(num_tool_call_chunks, 8) - - self.assertEqual(len(tool_calls_by_index), 1) - self.assertEqual(tool_calls_by_index[0]["name"], "get_favorite_tourist_spot") - params = json.loads(tool_calls_by_index[0]["parameters"]) - self.assertEqual(params["city"], "San Francisco") - self.assertEqual(params["another_city"], "London") - self.assertEqual(params["topn"], 10) - self.assertEqual(params["obj"]["name"], "John") - self.assertEqual(params["obj"]["age"], 30) - - def test_streaming_json_format(self): - """Test streaming parsing of JSON format""" - text = """<|DSML|function_calls> - <|DSML|invoke name="get_favorite_tourist_spot"> - { - "city": "San Francisco", - "another_city": "London", - "topn": 10, - "obj": { - "name": "John", - "age": 30 - } - } - - """ - - input_ids = self.tokenizer.encode(text, add_special_tokens=False) - chunk_ids = [ - input_ids[i : i + self.interval] - for i in range(0, len(input_ids), self.interval) - ] - chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] - - tool_calls_by_index = {} - - num_tool_call_chunks = 0 - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - for call in result.calls: - num_tool_call_chunks += 1 - if 
call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - self.assertGreater(num_tool_call_chunks, 8) - self.assertEqual(len(tool_calls_by_index), 1) - self.assertEqual(tool_calls_by_index[0]["name"], "get_favorite_tourist_spot") - - # Clean up parameters string if needed (trim whitespace) - params_str = tool_calls_by_index[0]["parameters"].strip() - params = json.loads(params_str) - self.assertEqual(params["city"], "San Francisco") - - def test_detect_and_parse_no_parameters(self): - """Test parsing function calls with no parameters (non-streaming)""" - # Add a no-parameter tool - tools_with_no_param = self.tools + [ - Tool( - type="function", - function=Function( - name="get_date", - description="Get the current date.", - parameters={"type": "object", "properties": {}}, - ), - ), - ] - - text = """Let me get the current date for you. - -<|DSML|function_calls> -<|DSML|invoke name="get_date"> - -""" - - result = self.detector.detect_and_parse(text, tools_with_no_param) - - self.assertIn("Let me get the current date", result.normal_text) - self.assertEqual(len(result.calls), 1) - - call = result.calls[0] - self.assertEqual(call.name, "get_date") - params = json.loads(call.parameters) - self.assertEqual(params, {}) - - def test_streaming_no_parameters(self): - """Test streaming parsing of function calls with no parameters. - - This test verifies the fix for the bug where functions with no parameters - were being silently skipped in streaming mode. 
- """ - # Add a no-parameter tool - tools_with_no_param = self.tools + [ - Tool( - type="function", - function=Function( - name="get_date", - description="Get the current date.", - parameters={"type": "object", "properties": {}}, - ), - ), - ] - - text = """<|DSML|function_calls> -<|DSML|invoke name="get_date"> - -""" - - # Reset detector state - self.detector = DeepSeekV32Detector() - - # Simulate streaming by splitting into small chunks - input_ids = self.tokenizer.encode(text, add_special_tokens=False) - chunk_ids = [ - input_ids[i : i + self.interval] - for i in range(0, len(input_ids), self.interval) - ] - chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] - - tool_calls_by_index = {} - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, tools_with_no_param) - for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - # Verify that the no-parameter function was correctly parsed - self.assertEqual( - len(tool_calls_by_index), 1, "Should have exactly one tool call" - ) - self.assertEqual(tool_calls_by_index[0]["name"], "get_date") - - # Parameters should be empty JSON object - params_str = tool_calls_by_index[0]["parameters"].strip() - params = json.loads(params_str) - self.assertEqual(params, {}) - - def test_streaming_no_parameters_with_whitespace(self): - """Test streaming parsing when invoke content has only whitespace (newlines).""" - tools_with_no_param = self.tools + [ - Tool( - type="function", - function=Function( - name="get_date", - description="Get the current date.", - parameters={"type": "object", "properties": {}}, - ), - ), - ] - - # This format has newlines inside the invoke tag (common model 
output) - text = """<|DSML|function_calls> -<|DSML|invoke name="get_date"> - - -""" - - # Reset detector state - self.detector = DeepSeekV32Detector() - - input_ids = self.tokenizer.encode(text, add_special_tokens=False) - chunk_ids = [ - input_ids[i : i + self.interval] - for i in range(0, len(input_ids), self.interval) - ] - chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] - - tool_calls_by_index = {} - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, tools_with_no_param) - for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - # Should still parse correctly even with whitespace-only content - self.assertEqual( - len(tool_calls_by_index), 1, "Should have exactly one tool call" - ) - self.assertEqual(tool_calls_by_index[0]["name"], "get_date") - params = json.loads(tool_calls_by_index[0]["parameters"]) - self.assertEqual(params, {}) - - -class TestQwen3CoderDetector(unittest.TestCase): - """Test suite for Qwen3CoderDetector.""" - - def setUp(self): - """Initialize test fixtures before each test method.""" - self.tools = [ - Tool( - type="function", - function=Function( - name="get_current_weather", - parameters={ - "type": "object", - "properties": { - "location": {"type": "string"}, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - }, - "days": {"type": "integer"}, - }, - "required": ["location"], - }, - ), - ), - Tool( - type="function", - function=Function( - name="sql_interpreter", - parameters={ - "type": "object", - "properties": { - "query": {"type": "string"}, - "dry_run": {"type": "boolean"}, - }, - }, - ), - ), - Tool( - type="function", - function=Function( - 
name="TodoWrite", - parameters={ - "type": "object", - "properties": { - "todos": { - "type": "array", - "items": { - "type": "object", - "properties": { - "content": {"type": "string"}, - "status": {"type": "string"}, - }, - "required": ["content", "status"], - }, - }, - }, - }, - ), - ), - ] - self.detector = Qwen3CoderDetector() - - # ==================== Basic Functionality Tests ==================== - - def test_plain_text_only(self): - """ - Test parsing of plain text without any tool calls. - - Scenario: Input contains only plain text, no tool call markers. - Purpose: Verify that plain text is correctly identified and no false tool calls are detected. - """ - text = "This is plain text without any tool calls." - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(result.normal_text, text) - self.assertEqual(len(result.calls), 0) - - def test_single_tool_call(self): - """ - Test parsing of a single tool call. - - Scenario: Input contains one complete tool call with parameters. - Purpose: Verify correct extraction of tool name and parameters. - """ - text = """ - -Boston -celsius -3 - -""" - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_current_weather") - - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["location"], "Boston") - self.assertEqual(params["unit"], "celsius") - self.assertEqual(params["days"], 3) - - def test_single_tool_call_with_text_prefix(self): - """ - Test parsing of tool call with preceding text. - - Scenario: Input has plain text followed by a tool call. - Purpose: Verify correct separation of text and tool call. - """ - text = """Let me check the weather for you. 
- - - -New York - -""" - result = self.detector.detect_and_parse(text, self.tools) - - self.assertTrue(result.normal_text.startswith("Let me check")) - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_current_weather") - - def test_multiple_tool_calls(self): - """ - Test parsing of multiple consecutive tool calls. - - Scenario: Input contains two tool calls one after another. - Purpose: Verify that multiple tool calls are correctly identified and parsed. - """ - text = """ - -New York - - - - -SELECT * FROM users -True - -""" - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 2) - self.assertEqual(result.calls[0].name, "get_current_weather") - self.assertEqual(result.calls[1].name, "sql_interpreter") - - params1 = json.loads(result.calls[0].parameters) - self.assertEqual(params1["location"], "New York") - - params2 = json.loads(result.calls[1].parameters) - self.assertEqual(params2["query"], "SELECT * FROM users") - self.assertEqual(params2["dry_run"], True) - - # ==================== Streaming Tests ==================== - - def test_streaming_single_tool_call(self): - """ - Test streaming parsing of a single tool call. - - Scenario: Tool call is fed incrementally in chunks. - Purpose: Verify streaming parser correctly assembles tool call from chunks. 
- """ - chunks = [ - "", - "", - "", - "Boston", - "", - "celsius", - "", - "", - ] - - detector = Qwen3CoderDetector() - all_calls = [] - collected_params = "" - - for chunk in chunks: - result = detector.parse_streaming_increment(chunk, self.tools) - all_calls.extend(result.calls) - for call in result.calls: - if call.parameters: - collected_params += call.parameters - - # Verify we got the tool call - self.assertGreater(len(all_calls), 0) - - # Verify parameters were collected - if collected_params: - params = json.loads(collected_params) - self.assertEqual(params["location"], "Boston") - self.assertEqual(params["unit"], "celsius") - - def test_streaming_with_text_and_tool(self): - """ - Test streaming parsing with mixed text and tool call. - - Scenario: Stream contains plain text followed by a tool call. - Purpose: Verify correct separation in streaming mode. - """ - chunks = [ - "Let me ", - "help you.\n\n", - "", - "", - "Paris", - "", - "", - ] - - detector = Qwen3CoderDetector() - full_text = "" - all_calls = [] - - for chunk in chunks: - result = detector.parse_streaming_increment(chunk, self.tools) - if result.normal_text: - full_text += result.normal_text - all_calls.extend(result.calls) - - self.assertTrue(full_text.startswith("Let me")) - self.assertGreater(len(all_calls), 0) - - # ==================== Parameter Type Tests ==================== - - def test_integer_parameter_conversion(self): - """ - Test correct type conversion for integer parameters. - - Scenario: Tool call with integer parameter. - Purpose: Verify integer values are correctly parsed and typed. - """ - text = """ - -Tokyo -5 - -""" - result = self.detector.detect_and_parse(text, self.tools) - - params = json.loads(result.calls[0].parameters) - self.assertIsInstance(params["days"], int) - self.assertEqual(params["days"], 5) - - def test_boolean_parameter_conversion(self): - """ - Test correct type conversion for boolean parameters. - - Scenario: Tool call with boolean parameter. 
- Purpose: Verify boolean values are correctly parsed. - """ - text = """ - -SELECT 1 -True - -""" - result = self.detector.detect_and_parse(text, self.tools) - - params = json.loads(result.calls[0].parameters) - self.assertIsInstance(params["dry_run"], bool) - self.assertEqual(params["dry_run"], True) - - def test_complex_array_parameter(self): - """ - Test parsing of complex array parameters. - - Scenario: Tool call with array of objects as parameter. - Purpose: Verify complex nested structures are correctly parsed. - """ - text = """ - - -[ - {"content": "Buy groceries", "status": "pending"}, - {"content": "Finish report", "status": "completed"} -] - - -""" - result = self.detector.detect_and_parse(text, self.tools) - - params = json.loads(result.calls[0].parameters) - self.assertIsInstance(params["todos"], list) - self.assertEqual(len(params["todos"]), 2) - self.assertEqual(params["todos"][0]["content"], "Buy groceries") - self.assertEqual(params["todos"][1]["status"], "completed") - - # ==================== Edge Cases ==================== - - def test_empty_parameter_value(self): - """ - Test handling of empty parameter values. - - Scenario: Tool call with empty parameter value. - Purpose: Verify empty values are handled gracefully. - """ - text = """ - - - -""" - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 1) - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["location"], "") - - def test_parameter_with_special_characters(self): - """ - Test handling of parameters with special characters. - - Scenario: Parameter value contains special characters like quotes, newlines. - Purpose: Verify special characters are correctly preserved. 
- """ - text = """ - -SELECT * FROM users WHERE name = 'John "Doe"' - -""" - result = self.detector.detect_and_parse(text, self.tools) - - params = json.loads(result.calls[0].parameters) - self.assertIn("John", params["query"]) - self.assertIn("Doe", params["query"]) - - def test_incomplete_tool_call(self): - """ - Test handling of incomplete tool call at end of stream. - - Scenario: Stream ends with an incomplete tool call (missing closing tag). - Purpose: Verify detector handles incomplete input gracefully without crashing. - """ - text = """ - -London""" - - # Should not crash - result = self.detector.detect_and_parse(text, self.tools) - self.assertIsInstance(result, StreamingParseResult) - - def test_has_tool_call_detection(self): - """ - Test the has_tool_call method for detecting tool call markers. - - Scenario: Various inputs with and without tool call markers. - Purpose: Verify correct detection of tool call presence. - """ - self.assertTrue(self.detector.has_tool_call("")) - self.assertTrue(self.detector.has_tool_call("text more")) - self.assertFalse(self.detector.has_tool_call("plain text only")) - self.assertFalse(self.detector.has_tool_call("")) - - -class TestGlm4MoeDetector(unittest.TestCase): - def setUp(self): - self.tools = [ - Tool( - type="function", - function=Function( - name="get_weather", - description="Get weather information", - parameters={ - "type": "object", - "properties": { - "city": {"type": "string", "description": "City name"}, - "date": {"type": "string", "description": "Date"}, - }, - "required": ["city", "date"], - }, - ), - ), - ] - self.detector = Glm4MoeDetector() - - def test_single_tool_call(self): - text = ( - "get_weather\n" - "city\nBeijing\n" - "date\n2024-06-27\n" - "" - ) - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual( - result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' 
- ) - self.assertEqual(result.normal_text, "") - - def test_multiple_tool_calls(self): - text = ( - "get_weather\n" - "city\nBeijing\n" - "date\n2024-06-27\n" - "" - "get_weather\n" - "city\nShanghai\n" - "date\n2024-06-28\n" - "" - ) - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 2) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual( - result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' - ) - self.assertEqual(result.calls[1].name, "get_weather") - self.assertEqual( - result.calls[1].parameters, '{"city": "Shanghai", "date": "2024-06-28"}' - ) - self.assertEqual(result.normal_text, "") - - def test_streaming_tool_call(self): - """Test streaming incremental parsing of a tool call.""" - chunks = [ - "get_weather\n", - "city\nBeijing\n", - "date\n2024-06-27\n", - "", - ] - tool_calls = [] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - for tool_call_chunk in result.calls: - if ( - hasattr(tool_call_chunk, "tool_index") - and tool_call_chunk.tool_index is not None - ): - while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - tc = tool_calls[tool_call_chunk.tool_index] - if tool_call_chunk.name: - tc["name"] = tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters - self.assertEqual(len(tool_calls), 1) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertEqual( - tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' - ) - - def test_streaming_multiple_tool_calls(self): - """Test streaming incremental parsing of multiple tool calls.""" - chunks = [ - "get_weather\n", - "city\nBeijing\n", - "date\n2024-06-27\n", - "get_weather\n", - "city\nShanghai\n", - "date\n2024-06-28\n", - "", - ] - tool_calls = [] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - 
for tool_call_chunk in result.calls: - if ( - hasattr(tool_call_chunk, "tool_index") - and tool_call_chunk.tool_index is not None - ): - while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - tc = tool_calls[tool_call_chunk.tool_index] - if tool_call_chunk.name: - tc["name"] = tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters - self.assertEqual(len(tool_calls), 2) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertEqual( - tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' - ) - self.assertEqual(tool_calls[1]["name"], "get_weather") - self.assertEqual( - tool_calls[1]["parameters"], '{"city": "Shanghai", "date": "2024-06-28"}' - ) - - def test_tool_call_id(self): - """Test that the buffer and state are reset after a tool call is completed.""" - chunks = [ - "get_weather\n", - "city\nBeijing\n", - "date\n2024-06-27\n", - "", - ] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - self.assertEqual(self.detector.current_tool_id, 1) - - def test_invalid_tool_call(self): - """Test that invalid tool calls are handled correctly.""" - text = "invalid_func\ncity\nBeijing\n" - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 0) - - def test_partial_tool_call(self): - """Test parsing a partial tool call that spans multiple chunks.""" - chunks = [ - "get_weather\n", - "city\nBeijing\n", - "date\n2024-06-27\n", - ] - - tool_calls = [] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - for tool_call_chunk in result.calls: - if ( - hasattr(tool_call_chunk, "tool_index") - and tool_call_chunk.tool_index is not None - ): - while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - tc = tool_calls[tool_call_chunk.tool_index] - if tool_call_chunk.name: - 
tc["name"] = tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters - - self.assertEqual(len(tool_calls), 1) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertEqual( - tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' - ) - - def test_array_argument_with_escaped_json(self): - """Test that array arguments with escaped JSON are properly handled without double-escaping.""" - # Add a tool with array parameter - tools_with_array = [ - Tool( - type="function", - function=Function( - name="todo_write", - description="Write todos", - parameters={ - "type": "object", - "properties": { - "todos": { - "type": "array", - "description": "The updated todo list", - } - }, - "required": ["todos"], - }, - ), - ), - ] - - def check_params(result): - self.assertEqual(1, len(result.calls)) - self.assertEqual("todo_write", result.calls[0].name) - params = json.loads(result.calls[0].parameters) - self.assertIsInstance(params["todos"], list) - self.assertEqual(4, len(params["todos"])) - self.assertEqual("1", params["todos"][0]["id"]) - self.assertEqual( - "Check for hard-coded issues in the backend code", - params["todos"][0]["task"], - ) - self.assertEqual("in_progress", params["todos"][0]["status"]) - self.assertEqual("2", params["todos"][1]["id"]) - self.assertEqual( - "Check for hard-coded issues in the frontend code", - params["todos"][1]["task"], - ) - self.assertEqual("pending", params["todos"][1]["status"]) - self.assertEqual("3", params["todos"][2]["id"]) - self.assertEqual( - "Check for code violating the Single Responsibility Principle", - params["todos"][2]["task"], - ) - self.assertEqual("pending", params["todos"][2]["status"]) - self.assertEqual("4", params["todos"][3]["id"]) - self.assertEqual( - "Generate a rectification proposal report", params["todos"][3]["task"] - ) - self.assertEqual("pending", params["todos"][3]["status"]) - - # Simulate the raw response from GLM-4.6 model with 
normal and escaped JSON in XML - result = self.detector.detect_and_parse( - """todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] -""", - tools_with_array, - ) - check_params(result) - result = self.detector.detect_and_parse( - r"""todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] -""", - tools_with_array, - ) - check_params(result) - - def check_single_todos(tool_result, expected): - self.assertEqual(1, len(tool_result.calls)) - self.assertEqual("todo_write", tool_result.calls[0].name) - params = json.loads(tool_result.calls[0].parameters) - self.assertIsInstance(params["todos"], list) - self.assertEqual(1, len(params["todos"])) - self.assertEqual("1", params["todos"][0]["id"]) - self.assertEqual(expected, params["todos"][0]["task"]) - self.assertEqual("pending", params["todos"][0]["status"]) - - # Test with escaped JSON containing backslashes in content (e.g., Windows paths) - expected_path = r"Check file at C:\Users\test.txt" - result = self.detector.detect_and_parse( - """todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""", - tools_with_array, - ) - check_single_todos(result, expected_path) - result = 
self.detector.detect_and_parse( - r"""todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""", - tools_with_array, - ) - check_single_todos(result, expected_path) - - # Should contain literal \n, not actual newline - expected_output = r"Print \n to see newline" - result = self.detector.detect_and_parse( - """todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""", - tools_with_array, - ) - check_single_todos(result, expected_output) - result = self.detector.detect_and_parse( - r"""todo_write\ntodos\n[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""", - tools_with_array, - ) - check_single_todos(result, expected_output) - - def test_empty_function_name_handling(self): - """Test that empty function name is handled gracefully without assertion error.""" - # This test simulates the issue where the model outputs only the start token without a function name - chunks = [ - "", # Start token only, no function name yet - "\n", # More content without function name - ] - - for chunk in chunks: - # Should not raise AssertionError: func_name should not be empty - result = self.detector.parse_streaming_increment(chunk, self.tools) - # Should return empty calls without error - self.assertIsInstance(result, StreamingParseResult) - self.assertEqual(result.calls, []) - - -class TestGlm47MoeDetector(unittest.TestCase): - def setUp(self): - self.tools = [ - Tool( - type="function", - function=Function( - name="get_weather", - description="Get weather information", - parameters={ - "type": "object", - "properties": { - "city": {"type": "string", "description": "City name"}, - "date": {"type": "string", "description": "Date"}, - }, - "required": ["city", "date"], - }, - ), - ), - ] - self.detector = Glm47MoeDetector() - - def test_single_tool_call(self): - text = ( - "get_weather" - "cityBeijing" - "date2024-06-27" - "" - ) - result = 
self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual( - result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' - ) - self.assertEqual(result.normal_text, "") - - def test_multiple_tool_calls(self): - text = ( - "get_weather" - "cityBeijing" - "date2024-06-27" - "" - "get_weather" - "cityShanghai" - "date2024-06-28" - "" - ) - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 2) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual( - result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' - ) - self.assertEqual(result.calls[1].name, "get_weather") - self.assertEqual( - result.calls[1].parameters, '{"city": "Shanghai", "date": "2024-06-28"}' - ) - self.assertEqual(result.normal_text, "") - - def test_streaming_tool_call(self): - """Test streaming incremental parsing of a tool call.""" - chunks = [ - "get_weather", - "cityBeijing", - "date2024-06-27", - "", - ] - tool_calls = [] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - for tool_call_chunk in result.calls: - if ( - hasattr(tool_call_chunk, "tool_index") - and tool_call_chunk.tool_index is not None - ): - while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - tc = tool_calls[tool_call_chunk.tool_index] - if tool_call_chunk.name: - tc["name"] = tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters - self.assertEqual(len(tool_calls), 1) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertEqual( - tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' - ) - - def test_streaming_multiple_tool_calls(self): - """Test streaming incremental parsing of multiple tool calls.""" - chunks = [ - "get_weather", - "cityBeijing", - 
"date2024-06-27", - "get_weather", - "cityShanghai", - "date2024-06-28", - "", - ] - tool_calls = [] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - for tool_call_chunk in result.calls: - if ( - hasattr(tool_call_chunk, "tool_index") - and tool_call_chunk.tool_index is not None - ): - while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - tc = tool_calls[tool_call_chunk.tool_index] - if tool_call_chunk.name: - tc["name"] = tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters - self.assertEqual(len(tool_calls), 2) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertEqual( - tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' - ) - self.assertEqual(tool_calls[1]["name"], "get_weather") - self.assertEqual( - tool_calls[1]["parameters"], '{"city": "Shanghai", "date": "2024-06-28"}' - ) - - def test_tool_call_id(self): - """Test that the buffer and state are reset after a tool call is completed.""" - chunks = [ - "get_weather", - "cityBeijing", - "date2024-06-27", - "", - ] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - self.assertEqual(self.detector.current_tool_id, 1) - - def test_invalid_tool_call(self): - """Test that invalid tool calls are handled correctly.""" - text = "invalid_funccityBeijing" - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 0) - - def test_partial_tool_call(self): - """Test parsing a partial tool call that spans multiple chunks.""" - chunks = [ - "get_weather", - "cityBeijing", - "date2024-06-27", - ] - - tool_calls = [] - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - for tool_call_chunk in result.calls: - if ( - hasattr(tool_call_chunk, "tool_index") - and tool_call_chunk.tool_index is not None - ): - while 
len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - tc = tool_calls[tool_call_chunk.tool_index] - if tool_call_chunk.name: - tc["name"] = tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters - - self.assertEqual(len(tool_calls), 1) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertEqual( - tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' - ) - - def test_array_argument_with_escaped_json(self): - """Test that array arguments with escaped JSON are properly handled without double-escaping.""" - # Add a tool with array parameter - tools_with_array = [ - Tool( - type="function", - function=Function( - name="todo_write", - description="Write todos", - parameters={ - "type": "object", - "properties": { - "todos": { - "type": "array", - "description": "The updated todo list", - } - }, - "required": ["todos"], - }, - ), - ), - ] - - def check_params(result): - self.assertEqual(1, len(result.calls)) - self.assertEqual("todo_write", result.calls[0].name) - params = json.loads(result.calls[0].parameters) - self.assertIsInstance(params["todos"], list) - self.assertEqual(4, len(params["todos"])) - self.assertEqual("1", params["todos"][0]["id"]) - self.assertEqual( - "Check for hard-coded issues in the backend code", - params["todos"][0]["task"], - ) - self.assertEqual("in_progress", params["todos"][0]["status"]) - self.assertEqual("2", params["todos"][1]["id"]) - self.assertEqual( - "Check for hard-coded issues in the frontend code", - params["todos"][1]["task"], - ) - self.assertEqual("pending", params["todos"][1]["status"]) - self.assertEqual("3", params["todos"][2]["id"]) - self.assertEqual( - "Check for code violating the Single Responsibility Principle", - params["todos"][2]["task"], - ) - self.assertEqual("pending", params["todos"][2]["status"]) - self.assertEqual("4", params["todos"][3]["id"]) - self.assertEqual( - "Generate a 
rectification proposal report", params["todos"][3]["task"] - ) - self.assertEqual("pending", params["todos"][3]["status"]) - - # Simulate the raw response from GLM-4.6 model with normal and escaped JSON in XML - result = self.detector.detect_and_parse( - """todo_writetodos[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] -""", - tools_with_array, - ) - check_params(result) - result = self.detector.detect_and_parse( - r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] -""", - tools_with_array, - ) - check_params(result) - - def check_single_todos(tool_result, expected): - self.assertEqual(1, len(tool_result.calls)) - self.assertEqual("todo_write", tool_result.calls[0].name) - params = json.loads(tool_result.calls[0].parameters) - self.assertIsInstance(params["todos"], list) - self.assertEqual(1, len(params["todos"])) - self.assertEqual("1", params["todos"][0]["id"]) - self.assertEqual(expected, params["todos"][0]["task"]) - self.assertEqual("pending", params["todos"][0]["status"]) - - # Test with escaped JSON containing backslashes in content (e.g., Windows paths) - expected_path = r"Check file at C:\Users\test.txt" - result = self.detector.detect_and_parse( - """todo_writetodos[{\"id\": \"1\", 
\"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""", - tools_with_array, - ) - check_single_todos(result, expected_path) - result = self.detector.detect_and_parse( - r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""", - tools_with_array, - ) - check_single_todos(result, expected_path) - - # Should contain literal \n, not actual newline - expected_output = r"Print \n to see newline" - result = self.detector.detect_and_parse( - """todo_writetodos[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""", - tools_with_array, - ) - check_single_todos(result, expected_output) - result = self.detector.detect_and_parse( - r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""", - tools_with_array, - ) - check_single_todos(result, expected_output) - - -class TestJsonArrayParser(unittest.TestCase): - def setUp(self): - # Create sample tools for testing - self.tools = [ - Tool( - type="function", - function=Function( - name="get_weather", - description="Get weather information", - parameters={ - "properties": { - "location": { - "type": "string", - "description": "Location to get weather for", - }, - "unit": { - "type": "string", - "description": "Temperature unit", - "enum": ["celsius", "fahrenheit"], - }, - }, - "required": ["location"], - }, - ), - ), - Tool( - type="function", - function=Function( - name="search", - description="Search for information", - parameters={ - "properties": { - "query": { - "type": "string", - "description": "Search query", - }, - }, - "required": ["query"], - }, - ), - ), - ] - self.detector = JsonArrayParser() - - def test_json_detector_has_no_ebnf(self): - """JsonArrayParser no longer exposes EBNF generation helpers.""" - self.assertFalse( - hasattr(self.detector, "build_ebnf"), - "JsonArrayParser should not expose EBNF helpers after cleanup", - ) - - def 
test_parse_streaming_increment_malformed_json(self): - """Test parsing with malformed JSON""" - # Test with malformed JSON - text = '[{"name": "get_weather", "parameters": {"location": "Tokyo"' - result = self.detector.parse_streaming_increment(text, self.tools) - - # Should not crash and return a valid result - self.assertIsInstance(result, StreamingParseResult) - - text = "[{}}}]" - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertIsInstance(result, StreamingParseResult) - - def test_parse_streaming_increment_empty_input(self): - """Test parsing with empty input""" - result = self.detector.parse_streaming_increment("", self.tools) - self.assertEqual(len(result.calls), 0) - self.assertEqual(result.normal_text, "") - - def test_parse_streaming_increment_whitespace_handling(self): - """Test parsing with various whitespace scenarios""" - # Test with leading/trailing whitespace split across chunks - chunk1 = ' [{"name": "get_weather", "parameters": ' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - chunk2 = '{"location": "Tokyo"}}] ' - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - - # The base class should handle this - self.assertIsInstance(result2, StreamingParseResult) - - def test_parse_streaming_increment_nested_objects(self): - """Test parsing with nested JSON objects""" - chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo", ' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - chunk2 = '"nested": {"key": "value"}}}]' - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - - # The base class should handle this - self.assertIsInstance(result2, StreamingParseResult) - - def test_json_parsing_with_commas(self): - """Test that JSON parsing works correctly with comma separators""" - # Stream two complete objects, at 
least 2 chunks per tool call - chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - chunk2 = 'yo"}},' - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - self.assertIsInstance(result2, StreamingParseResult) - - chunk3 = '{"name": "get_weather", "parameters": {"location": "Par' - result3 = self.detector.parse_streaming_increment(chunk3, self.tools) - self.assertIsInstance(result3, StreamingParseResult) - chunk4 = 'is"}}]' - result4 = self.detector.parse_streaming_increment(chunk4, self.tools) - self.assertIsInstance(result4, StreamingParseResult) - self.assertGreater( - len(result4.calls), 0, "Should parse tool calls from text with separators" - ) - - def test_braces_in_strings(self): - """Test that JSON with } characters inside strings works correctly""" - # Test case: JSON array with } inside string values - streamed across chunks - chunk1 = '[{"name": "get_weather", "parameters": {"location": "has } inside"' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - chunk2 = "}}" - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - self.assertIsInstance(result2, StreamingParseResult) - self.assertGreater( - len(result2.calls), 0, "Should parse tool call with } in string" - ) - - # Test with separator (streaming in progress) - chunk3 = '[{"name": "get_weather", "parameters": {"location": "has } inside"}' - result3 = self.detector.parse_streaming_increment(chunk3, self.tools) - self.assertIsInstance(result3, StreamingParseResult) - chunk4 = "}," - result4 = self.detector.parse_streaming_increment(chunk4, self.tools) - self.assertIsInstance(result4, StreamingParseResult) - chunk5 = '{"name": "get_weather"' - result5 = self.detector.parse_streaming_increment(chunk5, self.tools) - self.assertIsInstance(result5, 
StreamingParseResult) - self.assertGreater( - len(result5.calls), - 0, - "Should parse tool calls with separator and } in string", - ) - - def test_separator_in_same_chunk(self): - """Test that separator already present in chunk works correctly""" - # Test case: separator already in the chunk (streaming in progress) with 2+ chunks per tool call - chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - chunk2 = '}},{"name": "get_weather"' - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - self.assertIsInstance(result2, StreamingParseResult) - self.assertGreater( - len(result2.calls), - 0, - "Should parse tool calls with separator in same chunk", - ) - - def test_separator_in_separate_chunk(self): - """Test that separator in separate chunk works correctly""" - # Test case: separator in separate chunk - this tests streaming behavior - chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"}}' - chunk2 = "," - chunk3 = '{"name": "get_weather", "parameters": {"location": "Paris"}}' - - # Process first chunk - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - - # Process separator chunk - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - self.assertIsInstance(result2, StreamingParseResult) - - # Process second chunk (streaming in progress) - result3 = self.detector.parse_streaming_increment(chunk3, self.tools) - self.assertIsInstance(result3, StreamingParseResult) - - def test_incomplete_json_across_chunks(self): - """Test that incomplete JSON across chunks works correctly""" - # Test case: incomplete JSON across chunks - this tests streaming behavior - chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"' - chunk2 = '}},{"name": "get_weather"' - - # Process first chunk 
(incomplete) - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - - # Process second chunk (completes first object and starts second, streaming in progress) - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - self.assertIsInstance(result2, StreamingParseResult) - - def test_malformed_json_recovery(self): - """Test that malformed JSON recovers gracefully""" - # Test with malformed JSON - should handle gracefully - malformed_text = ( - '[{"name": "get_weather", "parameters": {"location": "unclosed string' - ) - - result1 = self.detector.parse_streaming_increment(malformed_text, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - - # Test valid JSON after malformed - streamed across 2 chunks (streaming in progress) - valid_chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok' - result2 = self.detector.parse_streaming_increment(valid_chunk1, self.tools) - self.assertIsInstance(result2, StreamingParseResult) - valid_chunk2 = 'yo"}}' - result3 = self.detector.parse_streaming_increment(valid_chunk2, self.tools) - self.assertIsInstance(result3, StreamingParseResult) - - def test_nested_objects_with_commas(self): - """Test that nested objects with commas inside work correctly""" - # Test with nested objects that have commas - should work with json.loads() - chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - chunk2 = 'yo", "unit": "celsius"}}' - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - self.assertIsInstance(result2, StreamingParseResult) - self.assertGreater( - len(result2.calls), 0, "Should parse tool call with nested objects" - ) - - def test_empty_objects(self): - """Test that empty objects work correctly""" - # Test with empty objects - should work with json.loads() - 
chunk1 = '[{"name": "get_weather", "parameters": ' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - chunk2 = "{}}" - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - self.assertIsInstance(result2, StreamingParseResult) - - def test_whitespace_handling(self): - """Test that various whitespace scenarios work correctly""" - # Test with various whitespace patterns - should work with json.loads() - chunk1 = ' \n\n [{"name": "get_weather", "parameters": ' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - chunk2 = '{"location": "Tokyo"}}' - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - self.assertIsInstance(result2, StreamingParseResult) - - def test_multiple_commas_in_chunk(self): - """Test that multiple commas in a single chunk work correctly""" - # Stream multiple tool calls ensuring at least 2 chunks per complete tool call - chunk1 = '[{"name": "get_weather", "parameters": {"location": "To' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - chunk2 = 'kyo"}},' - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - self.assertIsInstance(result2, StreamingParseResult) - - chunk3 = '{"name": "get_weather", "parameters": {"location": "Pa' - result3 = self.detector.parse_streaming_increment(chunk3, self.tools) - self.assertIsInstance(result3, StreamingParseResult) - chunk4 = 'ris"}},' - result4 = self.detector.parse_streaming_increment(chunk4, self.tools) - self.assertIsInstance(result4, StreamingParseResult) - - chunk5 = '{"name": "get_weather"' - result5 = self.detector.parse_streaming_increment(chunk5, self.tools) - self.assertIsInstance(result5, StreamingParseResult) - self.assertGreater( - len(result5.calls), 0, "Should parse tool calls with multiple commas" - ) - - def 
test_complete_tool_call_with_trailing_comma(self): - """Test that complete tool call with trailing comma parses correctly""" - # Test case: complete tool call followed by comma at end of chunk (split across 2 chunks) - chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"}' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - self.assertIsInstance(result1, StreamingParseResult) - chunk2 = "}, " - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - self.assertIsInstance(result2, StreamingParseResult) - self.assertGreater(len(result2.calls), 0, "Should parse complete tool call") - - # Test that next chunk with opening brace gets the separator prepended - next_chunk = '{"name": "get_weather", "parameters": {"location": "Paris"}}' - result_next = self.detector.parse_streaming_increment(next_chunk, self.tools) - self.assertIsInstance(result_next, StreamingParseResult) - self.assertGreater( - len(result_next.calls), 0, "Should parse subsequent tool call" - ) - - def test_three_tool_calls_separate_chunks_with_commas(self): - """Test parsing 3 tool calls in separate chunks with commas at the end""" - # First tool call: 2 chunks - chunk1_1 = '[{"name": "get_weather", "parameters": ' - result1_1 = self.detector.parse_streaming_increment(chunk1_1, self.tools) - chunk1_2 = '{"location": "Tokyo"}},' - result1_2 = self.detector.parse_streaming_increment(chunk1_2, self.tools) - self.assertIsInstance(result1_2, StreamingParseResult) - self.assertGreater(len(result1_2.calls), 0, "Should parse first tool call") - - # Second tool call: 2 chunks - chunk2_1 = '{"name": "search", "parameters": ' - result2_1 = self.detector.parse_streaming_increment(chunk2_1, self.tools) - chunk2_2 = '{"query": "restaurants"}},' - result2_2 = self.detector.parse_streaming_increment(chunk2_2, self.tools) - self.assertIsInstance(result2_2, StreamingParseResult) - self.assertGreater(len(result2_2.calls), 0, "Should parse second tool call") - - # Third 
tool call: 2 chunks - chunk3_1 = '{"name": "get_weather", "parameters": ' - result3_1 = self.detector.parse_streaming_increment(chunk3_1, self.tools) - chunk3_2 = '{"location": "Paris"}}]' - result3_2 = self.detector.parse_streaming_increment(chunk3_2, self.tools) - self.assertIsInstance(result3_2, StreamingParseResult) - self.assertGreater(len(result3_2.calls), 0, "Should parse third tool call") - # Verify all tool calls were parsed correctly - total_calls = len(result1_2.calls) + len(result2_2.calls) + len(result3_2.calls) - self.assertEqual(total_calls, 3, "Should have parsed exactly 3 tool calls") - - -class TestLfm2Detector(unittest.TestCase): - """Tests for LFM2 (Liquid Foundation Model 2) function call detector.""" - - def setUp(self): - """Set up test tools and detector.""" - self.tools = [ - Tool( - type="function", - function=Function( - name="get_weather", - description="Get weather information", - parameters={ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "City name", - }, - "unit": { - "type": "string", - "description": "Temperature unit", - "enum": ["celsius", "fahrenheit"], - }, - }, - "required": ["city"], - }, - ), - ), - Tool( - type="function", - function=Function( - name="search", - description="Search for information", - parameters={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Search query", - }, - }, - "required": ["query"], - }, - ), - ), - Tool( - type="function", - function=Function( - name="calculator", - description="Perform calculations", - parameters={ - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Math expression", - }, - }, - "required": ["expression"], - }, - ), - ), - ] - self.detector = Lfm2Detector() - - # ==================== has_tool_call tests ==================== - - def test_has_tool_call_true(self): - """Test detection of tool call markers.""" - text = 
'<|tool_call_start|>[get_weather(city="Paris")]<|tool_call_end|>' - self.assertTrue(self.detector.has_tool_call(text)) - - def test_has_tool_call_false(self): - """Test no false positives for regular text.""" - text = "The weather in Paris is nice today." - self.assertFalse(self.detector.has_tool_call(text)) - - def test_has_tool_call_partial_marker(self): - """Test that partial markers are detected (start token present).""" - text = '<|tool_call_start|>[get_weather(city="Paris")' - self.assertTrue(self.detector.has_tool_call(text)) - - # ==================== detect_and_parse tests (Pythonic format) ==================== - - def test_detect_and_parse_pythonic_simple(self): - """Test parsing a simple Pythonic format tool call.""" - text = '<|tool_call_start|>[get_weather(city="Paris")]<|tool_call_end|>' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual(result.calls[0].tool_index, 0) - - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["city"], "Paris") - - def test_detect_and_parse_pythonic_multiple_args(self): - """Test parsing with multiple arguments.""" - text = '<|tool_call_start|>[get_weather(city="London", unit="celsius")]<|tool_call_end|>' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["city"], "London") - self.assertEqual(params["unit"], "celsius") - - def test_detect_and_parse_pythonic_no_args(self): - """Test parsing function with no arguments.""" - # Add a no-arg tool for this test - tools_with_noarg = self.tools + [ - Tool( - type="function", - function=Function( - name="get_time", - description="Get current time", - parameters={"type": "object", "properties": {}}, - ), - ), - ] - text = 
"<|tool_call_start|>[get_time()]<|tool_call_end|>" - result = self.detector.detect_and_parse(text, tools_with_noarg) - - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_time") - - def test_detect_and_parse_pythonic_multiple_calls(self): - """Test parsing multiple tool calls in one block.""" - text = '<|tool_call_start|>[get_weather(city="Paris"), search(query="restaurants")]<|tool_call_end|>' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 2) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual(result.calls[1].name, "search") - - params1 = json.loads(result.calls[0].parameters) - params2 = json.loads(result.calls[1].parameters) - self.assertEqual(params1["city"], "Paris") - self.assertEqual(params2["query"], "restaurants") - - def test_detect_and_parse_with_normal_text_before(self): - """Test parsing with normal text before the tool call.""" - text = 'Let me check the weather for you. 
<|tool_call_start|>[get_weather(city="Tokyo")]<|tool_call_end|>' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(result.normal_text, "Let me check the weather for you.") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - - def test_detect_and_parse_special_characters_in_value(self): - """Test parsing with special characters in argument values.""" - text = ( - '<|tool_call_start|>[search(query="what\'s the weather?")]<|tool_call_end|>' - ) - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 1) - params = json.loads(result.calls[0].parameters) - self.assertIn("weather", params["query"]) - - def test_detect_and_parse_numeric_values(self): - """Test parsing with numeric argument values.""" - text = '<|tool_call_start|>[calculator(expression="5 * 7")]<|tool_call_end|>' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "calculator") - - # ==================== detect_and_parse tests (JSON format) ==================== - - def test_detect_and_parse_json_simple(self): - """Test parsing JSON format tool call.""" - text = '<|tool_call_start|>[{"name": "get_weather", "arguments": {"city": "Berlin"}}]<|tool_call_end|>' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["city"], "Berlin") - - def test_detect_and_parse_json_multiple_calls(self): - """Test parsing multiple JSON format tool calls.""" - text = '<|tool_call_start|>[{"name": "get_weather", "arguments": {"city": "Paris"}}, {"name": "search", "arguments": {"query": "hotels"}}]<|tool_call_end|>' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 2) - 
self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual(result.calls[1].name, "search") - - def test_detect_and_parse_json_with_parameters_key(self): - """Test parsing JSON format with 'parameters' key instead of 'arguments'.""" - text = '<|tool_call_start|>[{"name": "get_weather", "parameters": {"city": "Madrid"}}]<|tool_call_end|>' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 1) - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["city"], "Madrid") - - # ==================== Edge cases ==================== - - def test_detect_and_parse_no_tool_call(self): - """Test parsing text with no tool calls.""" - text = "This is just regular text without any tool calls." - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(result.normal_text, text) - self.assertEqual(result.calls, []) - - def test_detect_and_parse_unknown_function(self): - """Test parsing with unknown function name - skipped by default (SGLANG_FORWARD_UNKNOWN_TOOLS=false).""" - text = '<|tool_call_start|>[unknown_function(arg="value")]<|tool_call_end|>' - result = self.detector.detect_and_parse(text, self.tools) - - # By default, unknown functions are skipped (consistent with other detectors) - self.assertEqual(len(result.calls), 0) - - def test_detect_and_parse_empty_content(self): - """Test parsing with empty content between markers.""" - text = "<|tool_call_start|><|tool_call_end|>" - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(result.calls, []) - - def test_detect_and_parse_multiple_blocks(self): - """Test parsing multiple separate tool call blocks.""" - text = '<|tool_call_start|>[get_weather(city="Paris")]<|tool_call_end|> Some text <|tool_call_start|>[search(query="food")]<|tool_call_end|>' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(len(result.calls), 2) - self.assertEqual(result.calls[0].name, 
"get_weather") - self.assertEqual(result.calls[1].name, "search") - - # ==================== Streaming tests ==================== - # The LFM2 detector buffers until it sees complete <|tool_call_start|>...<|tool_call_end|> - # blocks, then parses the complete block. This allows proper handling of both - # JSON and Pythonic formats. - - def test_streaming_json_complete_in_one_chunk(self): - """Test streaming with complete JSON tool call in one chunk.""" - text = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": "Rome"}}<|tool_call_end|>' - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - - def test_streaming_json_split_across_chunks(self): - """Test streaming with JSON tool call split across multiple chunks - waits for complete block.""" - # Reset detector state - self.detector = Lfm2Detector() - - # First chunk: start marker and partial JSON (no end token) - chunk1 = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": ' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - - # Should buffer and not emit calls yet (waiting for complete block) - self.assertEqual(len(result1.calls), 0) - self.assertEqual(result1.normal_text, "") - - # Second chunk: complete the JSON and end token - chunk2 = '"Vienna"}}<|tool_call_end|>' - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - - # Now should have the complete tool call - self.assertEqual(len(result2.calls), 1) - self.assertEqual(result2.calls[0].name, "get_weather") - - def test_streaming_json_normal_text_before_tool_call(self): - """Test streaming with normal text before JSON tool call.""" - # Reset detector state - self.detector = Lfm2Detector() - - chunk1 = "I'll check the weather. 
" - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - - # Normal text should be returned - self.assertIn("check the weather", result1.normal_text) - - chunk2 = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": "Amsterdam"}}<|tool_call_end|>' - result2 = self.detector.parse_streaming_increment(chunk2, self.tools) - - self.assertEqual(len(result2.calls), 1) - - def test_streaming_eot_token_filtering(self): - """Test that end-of-turn token is filtered from normal text.""" - # Reset detector state - self.detector = Lfm2Detector() - - # Send text that ends with tool call end token (JSON format) - text = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": "Oslo"}}<|tool_call_end|>' - result = self.detector.parse_streaming_increment(text, self.tools) - - # The normal_text should not contain the eot_token - self.assertNotIn("<|tool_call_end|>", result.normal_text) - - # ==================== Pythonic streaming tests ==================== - - def test_streaming_pythonic_complete_in_one_chunk(self): - """Test streaming with complete Pythonic tool call in one chunk.""" - self.detector = Lfm2Detector() - text = '<|tool_call_start|>[get_weather(city="Berlin")]<|tool_call_end|>' - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual(json.loads(result.calls[0].parameters), {"city": "Berlin"}) - - def test_streaming_pythonic_split_across_chunks(self): - """Test streaming with Pythonic tool call split across multiple chunks.""" - self.detector = Lfm2Detector() - - # First chunk: start marker and partial call - chunk1 = '<|tool_call_start|>[get_weather(city="' - result1 = self.detector.parse_streaming_increment(chunk1, self.tools) - - # Should buffer and not emit calls yet - self.assertEqual(len(result1.calls), 0) - - # Second chunk: complete the call - chunk2 = 'Munich")]<|tool_call_end|>' - result2 = 
self.detector.parse_streaming_increment(chunk2, self.tools) - - # Now should have the complete tool call - self.assertEqual(len(result2.calls), 1) - self.assertEqual(result2.calls[0].name, "get_weather") - self.assertEqual(json.loads(result2.calls[0].parameters), {"city": "Munich"}) - - def test_streaming_pythonic_multiple_calls(self): - """Test streaming with multiple Pythonic tool calls.""" - self.detector = Lfm2Detector() - - text = '<|tool_call_start|>[get_weather(city="Paris"), search(query="hotels")]<|tool_call_end|>' - result = self.detector.parse_streaming_increment(text, self.tools) - - self.assertEqual(len(result.calls), 2) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual(result.calls[1].name, "search") - - # ==================== structure_info tests ==================== - - def test_supports_structural_tag(self): - """Test that LFM2 does not support structural tags (Pythonic format).""" - # LFM2 uses Pythonic format which is not JSON-compatible, - # so structural_tag constrained generation cannot be used - self.assertFalse(self.detector.supports_structural_tag()) - - def test_structure_info(self): - """Test structure info for constrained generation.""" - info_func = self.detector.structure_info() - info = info_func("get_weather") - - self.assertEqual(info.begin, "<|tool_call_start|>[get_weather(") - self.assertEqual(info.end, ")]<|tool_call_end|>") - self.assertEqual(info.trigger, "<|tool_call_start|>") - - -class TestGigaChat3Detector(unittest.TestCase): - def setUp(self): - self.tools = [ - Tool( - type="function", - function=Function( - name="manage_user_memory", - description="Create, update, or delete a user memory entry.", - parameters={ - "type": "object", - "properties": { - "content": { - "anyOf": [{"type": "string"}, {"type": "null"}], - "default": None, - }, - "action": { - "type": "string", - "enum": ["create", "update", "delete"], - "default": "create", - }, - "id": { - "anyOf": [ - {"type": "string", "format": 
"uuid"}, - {"type": "null"}, - ], - "default": None, - }, - }, - }, - ), - ), - Tool( - type="function", - function=Function( - name="get_weather", - description="Get weather information", - parameters={ - "type": "object", - "properties": { - "city": {"type": "string", "description": "City name"}, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - }, - }, - "required": ["city"], - }, - ), - ), - ] - self.detector = GigaChat3Detector() - - def test_has_tool_call(self): - """Test detection of tool call markers.""" - self.assertTrue(self.detector.has_tool_call("function call<|role_sep|>\n{}")) - self.assertFalse(self.detector.has_tool_call("No tool call here")) - - def test_detect_and_parse_no_tool_call(self): - """Test parsing text without tool calls.""" - text = "How can I help you today?" - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(result.normal_text, text) - self.assertEqual(len(result.calls), 0) - - def test_detect_and_parse_simple_tool_call(self): - """Test parsing a simple tool call without content.""" - text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences"}}' - result = self.detector.detect_and_parse(text, self.tools) - - # No content before tool call - self.assertEqual(result.normal_text, "") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "manage_user_memory") - - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["action"], "create") - self.assertEqual(params["id"], "preferences") - - def test_detect_and_parse_parameterless_tool_call(self): - """Test parsing a tool call with empty arguments.""" - text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {}}' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(result.normal_text, "") - self.assertEqual(len(result.calls), 1) - 
self.assertEqual(result.calls[0].name, "manage_user_memory") - - params = json.loads(result.calls[0].parameters) - self.assertEqual(params, {}) - - def test_detect_and_parse_complex_tool_call(self): - """Test parsing a tool call with nested objects.""" - text = """<|message_sep|> - -function call<|role_sep|> -{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences", "content": {"short_answers": true, "hate_emojis": true, "english_ui": false, "russian_math_explanations": true}}}""" - - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(result.normal_text, "") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "manage_user_memory") - - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["action"], "create") - self.assertEqual(params["id"], "preferences") - self.assertIsInstance(params["content"], dict) - self.assertEqual(params["content"]["short_answers"], True) - self.assertEqual(params["content"]["hate_emojis"], True) - - def test_detect_and_parse_with_content_before(self): - """Test parsing tool call with text content before it.""" - text = 'I\'ll check that for you.<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences"}}' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(result.normal_text, "I'll check that for you.") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "manage_user_memory") - - def test_detect_and_parse_with_eos_token(self): - """Test parsing tool call with EOS token at the end.""" - text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences"}}' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(result.normal_text, "") - self.assertEqual(len(result.calls), 1) - 
self.assertEqual(result.calls[0].name, "manage_user_memory") - - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["action"], "create") - self.assertEqual(params["id"], "preferences") - - def test_detect_and_parse_with_content_and_eos(self): - """Test parsing tool call with content and EOS token.""" - text = 'I\'ll remember that.<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {"action": "create", "id": "test"}}' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(result.normal_text, "I'll remember that.") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "manage_user_memory") - - def test_detect_and_parse_invalid_json(self): - """Test parsing with invalid JSON in function call.""" - text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": {invalid json}}' - result = self.detector.detect_and_parse(text, self.tools) - - # Should return the full text as content when JSON parsing fails - self.assertIn("function call", result.normal_text) - self.assertEqual(len(result.calls), 0) - - def test_detect_and_parse_missing_name(self): - """Test parsing with missing function name.""" - text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"arguments": {"action": "create"}}' - result = self.detector.detect_and_parse(text, self.tools) - - # Should not extract tool call if name is missing - self.assertEqual(len(result.calls), 0) - - def test_detect_and_parse_missing_arguments(self): - """Test parsing with missing arguments field.""" - text = '<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory"}' - result = self.detector.detect_and_parse(text, self.tools) - - # Should not extract tool call if arguments is missing - self.assertEqual(len(result.calls), 0) - - def test_detect_and_parse_arguments_not_dict(self): - """Test parsing with arguments that is not a dict.""" - text = 
'<|message_sep|>\n\nfunction call<|role_sep|>\n{"name": "manage_user_memory", "arguments": "string_args"}' - result = self.detector.detect_and_parse(text, self.tools) - - # Should not extract tool call if arguments is not a dict - self.assertEqual(len(result.calls), 0) - - def test_streaming_no_tool_call(self): - """Test streaming text without tool calls.""" - chunks = ["How ", "can ", "I ", "help ", "you?"] - - accumulated_text = "" - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - accumulated_text += result.normal_text - - self.assertEqual(accumulated_text, "How can I help you?") - self.assertEqual(len(result.calls), 0) - - def test_streaming_simple_tool_call(self): - """Test streaming a simple tool call.""" - chunks = [ - "<|message_sep|>\n\n", - "function call", - "<|role_sep|>\n", - '{"name": "manage_user_memory", ', - '"arguments": {"action": "create"', - ', "id": "preferences"}}', - ] - - tool_calls_by_index = {} - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - - for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - self.assertEqual(len(tool_calls_by_index), 1) - self.assertEqual(tool_calls_by_index[0]["name"], "manage_user_memory") - - params = json.loads(tool_calls_by_index[0]["parameters"]) - self.assertEqual(params["action"], "create") - self.assertEqual(params["id"], "preferences") - - def test_streaming_with_content_before(self): - """Test streaming with content before tool call.""" - chunks = [ - "I'll ", - "help ", - "you.", - "<|message_sep|>\n\n", - "function call", - "<|role_sep|>\n", - '{"name": "get_weather", ', - '"arguments": {"city": 
"Tokyo"}}', - ] - - accumulated_text = "" - tool_calls_by_index = {} - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - accumulated_text += result.normal_text - - for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - self.assertEqual(accumulated_text, "I'll help you.") - self.assertEqual(len(tool_calls_by_index), 1) - self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") - - params = json.loads(tool_calls_by_index[0]["parameters"]) - self.assertEqual(params["city"], "Tokyo") - - def test_streaming_complex_arguments(self): - """Test streaming with complex nested arguments.""" - chunks = [ - "<|message_sep|>\n\n", - "functi", - "on call<|role_sep|>\n", - '{"name": "manage_user_memory", "arguments": ', - '{"action": "create", "id": "prefs", ', - '"content": {"likes": ["short", "clear"], ', - '"dislikes": ["emojis", "verbose"]}', - "}}", - ] - - tool_calls_by_index = {} - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - - for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - self.assertEqual(len(tool_calls_by_index), 1) - self.assertEqual(tool_calls_by_index[0]["name"], "manage_user_memory") - - params = json.loads(tool_calls_by_index[0]["parameters"]) - self.assertEqual(params["action"], "create") - self.assertEqual(params["content"]["likes"], 
["short", "clear"]) - self.assertEqual(params["content"]["dislikes"], ["emojis", "verbose"]) - - def test_streaming_with_eos_token(self): - """Test streaming with EOS token at the end.""" - chunks = [ - "<|message_sep|>\n\n", - "function c", - "all<|role_sep|>\n", - '{"name": "get_weather", ', - '"arguments": {"city": "Paris"}}', - "", - ] - - tool_calls_by_index = {} - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - - for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - self.assertEqual(len(tool_calls_by_index), 1) - self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") - - params = json.loads(tool_calls_by_index[0]["parameters"]) - self.assertEqual(params["city"], "Paris") - - def test_streaming_incomplete_json(self): - """Test streaming with incomplete JSON (no closing brace).""" - chunks = [ - "<|message_sep|>\n\n", - "fun", - "ction call<|role_sep|>\n", - '{"name": "get_weather", ', - '"arguments": {"city": "London"', - # Missing closing braces - ] - - tool_calls_by_index = {} - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - - for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - # Should have name but incomplete parameters - self.assertEqual(len(tool_calls_by_index), 1) - self.assertEqual(tool_calls_by_index[0]["name"], 
"get_weather") - self.assertTrue(tool_calls_by_index[0]["parameters"].startswith('{"city":')) - - def test_streaming_large_steps(self): - """Test streaming with large chunks that complete in fewer steps.""" - chunks = [ - "I'll remember that.", - "<|message_sep|>\n\nfuncti", - "on call<|role_sep|>\n", - '{"name": "manage_user_memory", "arguments": {"action": "create", "id": "preferences", "content": {"short_answers": true, "hate_emojis": true, ', - '"english_ui": false, "russian_math_explanations": true}', - "}}", - ] - - accumulated_text = "" - tool_calls_by_index = {} - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - accumulated_text += result.normal_text - - for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - self.assertEqual(accumulated_text, "I'll remember that.") - self.assertEqual(len(tool_calls_by_index), 1) - self.assertEqual(tool_calls_by_index[0]["name"], "manage_user_memory") - - params = json.loads(tool_calls_by_index[0]["parameters"]) - self.assertEqual(params["action"], "create") - self.assertEqual(params["content"]["short_answers"], True) - self.assertEqual(params["content"]["russian_math_explanations"], True) - - def test_streaming_very_small_chunks(self): - """Test streaming with very small chunks (character by character).""" - text = '{"name": "get_weather", "arguments": {"city": "NYC"}}' - - # Split into very small chunks (every 5 characters) - chunk_size = 5 - chunked_text = [ - text[i : i + chunk_size] for i in range(0, len(text), chunk_size) - ] - chunks = [ - "<|message_sep|>\n\n", - "func", - "tion call", - "<|role_sep|>\n", - *chunked_text, - ] - tool_calls_by_index = {} - - for 
chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - - for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - self.assertEqual(len(tool_calls_by_index), 1) - self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") - - params = json.loads(tool_calls_by_index[0]["parameters"]) - self.assertEqual(params["city"], "NYC") - - def test_streaming_json_split_at_quotes(self): - """Test streaming when JSON is split at quote boundaries.""" - chunks = [ - "<|message_sep|>\n\nfunction call<|role_sep|>\n", - '{"name', - '": "', - "get_weather", - '", "arguments', - '": {"city', - '": "', - "Rome", - '"}}', - ] - - tool_calls_by_index = {} - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - - for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters - - self.assertEqual(len(tool_calls_by_index), 1) - self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") - - params = json.loads(tool_calls_by_index[0]["parameters"]) - self.assertEqual(params["city"], "Rome") - - -class TestGemma4Detector(unittest.TestCase): - def setUp(self): - self.tools = [ - Tool( - type="function", - function=Function( - name="get_weather", - description="Get weather information", - parameters={ - "type": "object", - "properties": { - "location": {"type": "string"}, - "unit": { - "type": "string", - 
"enum": ["celsius", "fahrenheit"], - }, - }, - "required": ["location"], - }, - ), - ) - ] - self.detector = Gemma4Detector() - - def test_detect_and_parse(self): - text = 'Some text before <|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' - result = self.detector.detect_and_parse(text, self.tools) - - self.assertEqual(result.normal_text, "Some text before ") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_weather") - - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["location"], "Tokyo") - - def test_parse_streaming_increment(self): - chunks = [ - "Some text ", - "before <|tool", - "_call>call:get_we", - "ather{location:<|", - '"|>Tokyo<|"|>} after", - ] - - all_results = [] - for chunk in chunks: - res = self.detector.parse_streaming_increment(chunk, self.tools) - all_results.append(res) - - combined_normal_text = "".join(r.normal_text for r in all_results) - self.assertEqual(combined_normal_text, "Some text before after") - - found_name = False - found_params = False - for res in all_results: - for call in res.calls: - if call.name == "get_weather": - found_name = True - if call.parameters: - params = json.loads(call.parameters) - if params == {"location": "Tokyo"}: - found_params = True - - self.assertTrue(found_name) - self.assertTrue(found_params) - - def test_nested_array_streaming(self): - # Additional coverage for complex structure - chunks = [ - '<|tool_call>call:get_weather{location:<|"', - '|>New York<|"|>,nested:[1, 2, {inner:<|"|>', - 'val<|"|>}]}', - ] - - all_results = [] - for chunk in chunks: - res = self.detector.parse_streaming_increment(chunk, self.tools) - all_results.append(res) - - found_params = False - for res in all_results: - for call in res.calls: - if call.parameters: - params = json.loads(call.parameters) - if "location" in params and params["location"] == "New York": - if "nested" in params and params["nested"] == [ - 1, - 2, - {"inner": "val"}, - ]: - 
found_params = True - - self.assertTrue(found_params) - - def test_has_tool_call(self): - self.assertTrue( - self.detector.has_tool_call( - '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' - ) - ) - self.assertFalse(self.detector.has_tool_call("no tool call here")) - - def test_detect_and_parse_no_tool_call(self): - text = "This is plain text without any tool calls." - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(result.normal_text, text) - self.assertEqual(len(result.calls), 0) - - def test_detect_and_parse_tool_index(self): - text = '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].tool_index, 0) - self.assertEqual(result.calls[0].name, "get_weather") - - def test_detect_and_parse_unknown_tool_index(self): - text = '<|tool_call>call:unknown_func{arg:<|"|>val<|"|>}' - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].tool_index, -1) - - def test_detect_and_parse_nested_object(self): - text = '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,details:{temp:25,unit:<|"|>celsius<|"|>}}' - result = self.detector.detect_and_parse(text, self.tools) - self.assertEqual(len(result.calls), 1) - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["location"], "Tokyo") - self.assertIsInstance(params["details"], dict) - self.assertEqual(params["details"]["temp"], 25) - self.assertEqual(params["details"]["unit"], "celsius") - - def test_detect_and_parse_multiple_calls(self): - extra_tools = self.tools + [ - Tool( - type="function", - function=Function( - name="get_time", - description="Get current time", - parameters={ - "type": "object", - "properties": {"timezone": {"type": "string"}}, - }, - ), - ) - ] - text = ( - 'Some text <|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' - ' 
more text <|tool_call>call:get_time{timezone:<|"|>UTC<|"|>}' - ) - result = self.detector.detect_and_parse(text, extra_tools) - self.assertEqual(len(result.calls), 2) - self.assertEqual(result.calls[0].name, "get_weather") - self.assertEqual(result.calls[1].name, "get_time") - self.assertEqual(result.normal_text, "Some text ") - - def test_parse_gemma4_args_empty(self): - self.assertEqual(_parse_gemma4_args(""), {}) - self.assertEqual(_parse_gemma4_args(" "), {}) - - def test_parse_gemma4_args_booleans(self): - result = _parse_gemma4_args("flag:true,other:false") - self.assertIs(result["flag"], True) - self.assertIs(result["other"], False) - - def test_parse_gemma4_args_numbers(self): - result = _parse_gemma4_args("count:42,ratio:3.14") - self.assertEqual(result["count"], 42) - self.assertAlmostEqual(result["ratio"], 3.14) - - def test_parse_gemma4_args_string_with_colon(self): - result = _parse_gemma4_args('url:<|"|>http://example.com<|"|>') - self.assertEqual(result["url"], "http://example.com") - - def test_parse_gemma4_args_nested_object(self): - result = _parse_gemma4_args('outer:{inner:<|"|>val<|"|>,num:5}') - self.assertIsInstance(result["outer"], dict) - self.assertEqual(result["outer"]["inner"], "val") - self.assertEqual(result["outer"]["num"], 5) - - def test_parse_gemma4_array_mixed_types(self): - result = _parse_gemma4_array('<|"|>hello<|"|>, 42, true, {key:<|"|>val<|"|>}') - self.assertEqual(result[0], "hello") - self.assertEqual(result[1], 42) - self.assertIs(result[2], True) - self.assertIsInstance(result[3], dict) - self.assertEqual(result[3]["key"], "val") - - def test_parse_gemma4_value_types(self): - self.assertIs(_parse_gemma4_value("true"), True) - self.assertIs(_parse_gemma4_value("false"), False) - self.assertEqual(_parse_gemma4_value("42"), 42) - self.assertAlmostEqual(_parse_gemma4_value("3.14"), 3.14) - self.assertEqual(_parse_gemma4_value("hello"), "hello") - self.assertEqual(_parse_gemma4_value(""), "") - - -if __name__ == "__main__": - 
unittest.main() diff --git a/test/registered/unit/function_call/test_function_call_parser.py b/test/registered/unit/function_call/test_function_call_parser.py index bd0e6fd3dfac..5ed54d3e9e56 100644 --- a/test/registered/unit/function_call/test_function_call_parser.py +++ b/test/registered/unit/function_call/test_function_call_parser.py @@ -6,7 +6,12 @@ from sglang.srt.function_call.core_types import StreamingParseResult from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector -from sglang.srt.function_call.gemma4_detector import Gemma4Detector +from sglang.srt.function_call.gemma4_detector import ( + Gemma4Detector, + _parse_gemma4_args, + _parse_gemma4_array, + _parse_gemma4_value, +) from sglang.srt.function_call.gigachat3_detector import GigaChat3Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector @@ -3949,6 +3954,107 @@ def test_nested_array_streaming(self): self.assertTrue(found_params) + def test_has_tool_call(self): + self.assertTrue( + self.detector.has_tool_call( + '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + ) + ) + self.assertFalse(self.detector.has_tool_call("no tool call here")) + + def test_detect_and_parse_no_tool_call(self): + text = "This is plain text without any tool calls." 
+ result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(result.normal_text, text) + self.assertEqual(len(result.calls), 0) + + def test_detect_and_parse_tool_index(self): + text = '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].tool_index, 0) + self.assertEqual(result.calls[0].name, "get_weather") + + def test_detect_and_parse_unknown_tool_index(self): + text = '<|tool_call>call:unknown_func{arg:<|"|>val<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].tool_index, -1) + + def test_detect_and_parse_nested_object(self): + text = '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,details:{temp:25,unit:<|"|>celsius<|"|>}}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "Tokyo") + self.assertIsInstance(params["details"], dict) + self.assertEqual(params["details"]["temp"], 25) + self.assertEqual(params["details"]["unit"], "celsius") + + def test_detect_and_parse_multiple_calls(self): + extra_tools = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={ + "type": "object", + "properties": {"timezone": {"type": "string"}}, + }, + ), + ) + ] + text = ( + 'Some text <|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + ' more text <|tool_call>call:get_time{timezone:<|"|>UTC<|"|>}' + ) + result = self.detector.detect_and_parse(text, extra_tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "get_time") + self.assertEqual(result.normal_text, "Some text ") + + def test_parse_gemma4_args_empty(self): + 
self.assertEqual(_parse_gemma4_args(""), {}) + self.assertEqual(_parse_gemma4_args(" "), {}) + + def test_parse_gemma4_args_booleans(self): + result = _parse_gemma4_args("flag:true,other:false") + self.assertIs(result["flag"], True) + self.assertIs(result["other"], False) + + def test_parse_gemma4_args_numbers(self): + result = _parse_gemma4_args("count:42,ratio:3.14") + self.assertEqual(result["count"], 42) + self.assertAlmostEqual(result["ratio"], 3.14) + + def test_parse_gemma4_args_string_with_colon(self): + result = _parse_gemma4_args('url:<|"|>http://example.com<|"|>') + self.assertEqual(result["url"], "http://example.com") + + def test_parse_gemma4_args_nested_object(self): + result = _parse_gemma4_args('outer:{inner:<|"|>val<|"|>,num:5}') + self.assertIsInstance(result["outer"], dict) + self.assertEqual(result["outer"]["inner"], "val") + self.assertEqual(result["outer"]["num"], 5) + + def test_parse_gemma4_array_mixed_types(self): + result = _parse_gemma4_array('<|"|>hello<|"|>, 42, true, {key:<|"|>val<|"|>}') + self.assertEqual(result[0], "hello") + self.assertEqual(result[1], 42) + self.assertIs(result[2], True) + self.assertIsInstance(result[3], dict) + self.assertEqual(result[3]["key"], "val") + + def test_parse_gemma4_value_types(self): + self.assertIs(_parse_gemma4_value("true"), True) + self.assertIs(_parse_gemma4_value("false"), False) + self.assertEqual(_parse_gemma4_value("42"), 42) + self.assertAlmostEqual(_parse_gemma4_value("3.14"), 3.14) + self.assertEqual(_parse_gemma4_value("hello"), "hello") + self.assertEqual(_parse_gemma4_value(""), "") + if __name__ == "__main__": unittest.main() From 8fac51b5487ea348eed693c5378b1c234180e08f Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Sat, 4 Apr 2026 04:33:27 +0000 Subject: [PATCH 109/112] restore Qwen25 detector --- .../test_function_call_parser.py | 155 ++++++++++++++++++ 1 file changed, 155 insertions(+) diff --git a/test/registered/unit/function_call/test_function_call_parser.py 
b/test/registered/unit/function_call/test_function_call_parser.py index 5ed54d3e9e56..9e805741513e 100644 --- a/test/registered/unit/function_call/test_function_call_parser.py +++ b/test/registered/unit/function_call/test_function_call_parser.py @@ -3859,6 +3859,161 @@ def test_streaming_function_call_marker_json_split_at_quotes(self): self.assertEqual(params["city"], "Rome") +class TestQwen25Detector(unittest.TestCase): + """Test Qwen25Detector streaming and non-streaming multi-tool-call parsing.""" + + def setUp(self): + from sglang.srt.function_call.qwen25_detector import Qwen25Detector + + self.detector = Qwen25Detector() + self.tools = [ + Tool( + type="function", + function=Function( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + }, + "state": { + "type": "string", + "description": "Two-letter state abbreviation", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "state", "unit"], + }, + ), + ), + ] + + # -- Non-streaming tests -- + + def test_detect_and_parse_single_tool_call(self): + text = '\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}\n' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_current_weather") + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "NYC") + + def test_detect_and_parse_multiple_tool_calls(self): + text = ( + '\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}\n\n' + '\n{"name": "get_current_weather", "arguments": {"city": "Baltimore", "state": "MD", "unit": "fahrenheit"}}\n\n' + '\n{"name": "get_current_weather", "arguments": {"city": "Minneapolis", "state": "MN", "unit": 
"fahrenheit"}}\n\n' + '\n{"name": "get_current_weather", "arguments": {"city": "Los Angeles", "state": "CA", "unit": "fahrenheit"}}\n' + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 4) + cities = [json.loads(c.parameters)["city"] for c in result.calls] + self.assertEqual(cities, ["NYC", "Baltimore", "Minneapolis", "Los Angeles"]) + + def test_detect_and_parse_with_normal_text_prefix(self): + text = ( + "Sure, let me check the weather.\n" + '\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "celsius"}}\n' + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertIn("let me check", result.normal_text) + + # -- Streaming tests -- + + def _collect_streaming_tool_calls(self, chunks): + """Helper: feed chunks through streaming parser and collect tool calls by index.""" + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + return tool_calls_by_index + + def test_streaming_single_tool_call(self): + chunks = [ + "\n", + '{"name": "get_current_weather",', + ' "arguments": {"city": "NYC",', + ' "state": "NY",', + ' "unit": "fahrenheit"}}', + "\n", + ] + result = self._collect_streaming_tool_calls(chunks) + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["name"], "get_current_weather") + params = json.loads(result[0]["parameters"]) + self.assertEqual(params["city"], "NYC") + + def test_streaming_multiple_tool_calls(self): + """Core regression test: multiple tool calls must all be parsed 
in streaming mode.""" + chunks = [ + "\n", + '{"name": "get_current_weather",', + ' "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}', + "\n\n", + "\n", + '{"name": "get_current_weather",', + ' "arguments": {"city": "Baltimore", "state": "MD", "unit": "fahrenheit"}}', + "\n\n", + "\n", + '{"name": "get_current_weather",', + ' "arguments": {"city": "LA", "state": "CA", "unit": "fahrenheit"}}', + "\n", + ] + result = self._collect_streaming_tool_calls(chunks) + self.assertEqual(len(result), 3, f"Expected 3 tool calls, got {len(result)}") + cities = [json.loads(result[i]["parameters"])["city"] for i in sorted(result)] + self.assertEqual(cities, ["NYC", "Baltimore", "LA"]) + + def test_streaming_multiple_tool_calls_fused_chunks(self): + """Test when separator and next bot_token arrive in a single chunk.""" + chunks = [ + '\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}', + '\n\n\n{"name": "get_current_weather",', + ' "arguments": {"city": "LA", "state": "CA", "unit": "fahrenheit"}}', + "\n", + ] + result = self._collect_streaming_tool_calls(chunks) + self.assertEqual(len(result), 2, f"Expected 2 tool calls, got {len(result)}") + cities = [json.loads(result[i]["parameters"])["city"] for i in sorted(result)] + self.assertEqual(cities, ["NYC", "LA"]) + + def test_streaming_multiple_tool_calls_char_by_char_separator(self): + """Test when the separator between tool calls arrives character by character.""" + call1 = '{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}' + call2 = '{"name": "get_current_weather", "arguments": {"city": "LA", "state": "CA", "unit": "celsius"}}' + separator = "\n\n\n" + + chunks = ["\n", call1] + for ch in separator: + chunks.append(ch) + chunks.append(call2) + chunks.append("\n") + + result = self._collect_streaming_tool_calls(chunks) + self.assertEqual(len(result), 2, f"Expected 2 tool calls, got {len(result)}") + cities = 
[json.loads(result[i]["parameters"])["city"] for i in sorted(result)] + self.assertEqual(cities, ["NYC", "LA"]) + + class TestGemma4Detector(unittest.TestCase): def setUp(self): self.tools = [ From a6cef69dc2ff12a0dcae40eecc1a074c579a6b74 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Sat, 4 Apr 2026 04:40:38 +0000 Subject: [PATCH 110/112] single line removal --- python/sglang/srt/managers/detokenizer_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 8104929f89e8..ce27113845c7 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -315,6 +315,7 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput): incremental_output = output_str[s.sent_offset :] s.sent_offset = len(output_str) output_strs.append(incremental_output) + return output_strs def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): From 71b907ffb3675aa1bef8342744cdfeb12894049c Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Sat, 4 Apr 2026 04:46:36 +0000 Subject: [PATCH 111/112] hardening gemma 4 tool call and reasoning parser tests --- .../test_function_call_parser.py | 138 ++++++++++++++++++ .../unit/parser/test_reasoning_parser.py | 57 ++++++++ 2 files changed, 195 insertions(+) diff --git a/test/registered/unit/function_call/test_function_call_parser.py b/test/registered/unit/function_call/test_function_call_parser.py index 9e805741513e..01aa99904072 100644 --- a/test/registered/unit/function_call/test_function_call_parser.py +++ b/test/registered/unit/function_call/test_function_call_parser.py @@ -4210,6 +4210,144 @@ def test_parse_gemma4_value_types(self): self.assertEqual(_parse_gemma4_value("hello"), "hello") self.assertEqual(_parse_gemma4_value(""), "") + def _collect_streaming(self, chunks): + """Helper: feed chunks and collect normal text + tool calls by index.""" + normal_text = "" 
+ tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + normal_text += result.normal_text + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + return normal_text, tool_calls_by_index + + def test_streaming_multiple_tool_calls(self): + """Test streaming with two consecutive tool calls.""" + extra_tools = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={ + "type": "object", + "properties": {"timezone": {"type": "string"}}, + }, + ), + ) + ] + chunks = [ + '<|tool_call>call:get_weather{location:<|"|>', + 'Tokyo<|"|>}', + ' <|tool_call>call:get_time{timezone:<|"|>', + 'UTC<|"|>}', + ] + normal_text = "" + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, extra_tools) + normal_text += result.normal_text + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual(len(tool_calls_by_index), 2) + self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") + self.assertEqual(tool_calls_by_index[1]["name"], "get_time") + params0 = json.loads(tool_calls_by_index[0]["parameters"]) + params1 = json.loads(tool_calls_by_index[1]["parameters"]) + self.assertEqual(params0["location"], "Tokyo") + self.assertEqual(params1["timezone"], "UTC") + + def 
test_streaming_very_small_chunks(self): + """Test streaming with character-by-character chunks.""" + full_text = '<|tool_call>call:get_weather{location:<|"|>Rome<|"|>}' + chunks = list(full_text) + + normal_text, tool_calls = self._collect_streaming(chunks) + + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + params = json.loads(tool_calls[0]["parameters"]) + self.assertEqual(params["location"], "Rome") + + def test_streaming_empty_args(self): + """Test streaming a tool call with no arguments.""" + chunks = ["<|tool_call>call:get_weather{}", ""] + normal_text, tool_calls = self._collect_streaming(chunks) + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + + def test_streaming_text_between_tool_calls(self): + """Test streaming with normal text interleaved between two different tool calls.""" + extra_tools = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={ + "type": "object", + "properties": {"timezone": {"type": "string"}}, + }, + ), + ) + ] + chunks = [ + "Hello! 
", + '<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}', + " Let me also check ", + '<|tool_call>call:get_time{timezone:<|"|>UTC<|"|>}', + ] + normal_text = "" + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, extra_tools) + normal_text += result.normal_text + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + self.assertIn("Hello!", normal_text) + self.assertIn("Let me also check", normal_text) + self.assertEqual(len(tool_calls_by_index), 2) + self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") + self.assertEqual(tool_calls_by_index[1]["name"], "get_time") + params0 = json.loads(tool_calls_by_index[0]["parameters"]) + params1 = json.loads(tool_calls_by_index[1]["parameters"]) + self.assertEqual(params0["location"], "Paris") + self.assertEqual(params1["timezone"], "UTC") + if __name__ == "__main__": unittest.main() diff --git a/test/registered/unit/parser/test_reasoning_parser.py b/test/registered/unit/parser/test_reasoning_parser.py index cdf7d7c9bd38..8f05d7903e9b 100644 --- a/test/registered/unit/parser/test_reasoning_parser.py +++ b/test/registered/unit/parser/test_reasoning_parser.py @@ -656,6 +656,63 @@ def test_streaming_partial_start_buffered(self): self.assertEqual(result.normal_text, "") self.assertEqual(result.reasoning_text, "") + def test_streaming_end_token_mid_chunk(self): + """Test end token arriving in the same chunk as reasoning content.""" + self.detector.parse_streaming_increment("<|channel>thought\n") + result = self.detector.parse_streaming_increment( + "some reasoningthe answer" + ) + self.assertEqual(result.reasoning_text, "some reasoning") + 
self.assertEqual(result.normal_text, "the answer") + self.assertFalse(self.detector._in_reasoning) + + def test_streaming_split_end_token(self): + """Test end token split across two chunks.""" + self.detector.parse_streaming_increment("<|channel>thought\n") + self.detector.parse_streaming_increment("reasoning content") + + result1 = self.detector.parse_streaming_increment("final answer") + self.assertFalse(self.detector._in_reasoning) + self.assertIn("final answer", result2.normal_text) + + def test_streaming_self_label_split_across_chunks(self): + """Test self_label ('thought\\n') arriving separately from start token.""" + result1 = self.detector.parse_streaming_increment("<|channel>") + self.assertEqual(result1.reasoning_text, "") + self.assertEqual(result1.normal_text, "") + + result2 = self.detector.parse_streaming_increment("thought\n") + self.assertTrue(self.detector._in_reasoning) + + result3 = self.detector.parse_streaming_increment("reasoning here") + self.assertEqual(result3.reasoning_text, "reasoning here") + + def test_streaming_force_reasoning(self): + """Test streaming with force_reasoning=True (no start token needed).""" + detector = Gemma4Detector(force_reasoning=True) + + result1 = detector.parse_streaming_increment("reasoning content") + self.assertEqual(result1.reasoning_text, "reasoning content") + self.assertEqual(result1.normal_text, "") + + result2 = detector.parse_streaming_increment("the answer") + self.assertFalse(detector._in_reasoning) + self.assertIn("the answer", result2.normal_text) + + def test_streaming_multiple_reasoning_chunks(self): + """Test reasoning content arriving in many small chunks.""" + self.detector.parse_streaming_increment("<|channel>thought\n") + + all_reasoning = "" + for chunk in ["Think", "ing ", "step ", "by ", "step."]: + result = self.detector.parse_streaming_increment(chunk) + all_reasoning += result.reasoning_text + self.assertEqual(result.normal_text, "") + self.assertEqual(all_reasoning, "Thinking step by 
step.") + def test_force_reasoning(self): """Test Gemma4Detector with force_reasoning=True.""" detector = Gemma4Detector(force_reasoning=True) From 54ac982967c3a1ee764bd4eba439050c49928ab8 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Sat, 4 Apr 2026 06:09:06 +0000 Subject: [PATCH 112/112] nit --- python/sglang/srt/mem_cache/swa_memory_pool.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index ab2cda35ee6f..80f24e6dfa56 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -50,7 +50,6 @@ def __init__( self.device = device self.swa_layer_nums = len(swa_attention_layer_ids) self.full_layer_nums = len(full_attention_layer_ids) - self.start_layer = 0 self.page_size = page_size self.swa_loc = None @@ -165,7 +164,6 @@ def set_kv_buffer( layer_id = layer.layer_id layer_id_pool, is_swa_layer = self.layers_mapping[layer_id] - if is_swa_layer: if self.swa_loc is not None: loc = self.swa_loc