plugin: hybrid Gemma4 wiring - block_table padding + model registration #390
Merged
viktorpusTT
approved these changes
May 13, 2026
The comment and explanation regarding model registration seem really suspicious and at least partially wrong: other TT models don't get away without explicit registration.
Two changes, both required for vLLM to serve Gemma4 through the hybrid kv-cache-groups path:
``model_runner.py``: pad legacy ``block_tables_per_group[0]`` to ``max_num_blocks_per_req``. vLLM's per-block-byte unifier produces per-group block_table native widths of ``cdiv(max_model_len, group_block_size)``. For Gemma4-E2B with ``cache_config.block_size=64``, sliding ends up at block_size=128 (width=cdiv(131072,128)=1024) and full at block_size=64 (width=2048). The legacy single-tensor view (``block_tables_per_group[0]``) is sliced to ``:max_num_blocks_per_req`` but was not padded, so on single-card runs where sliding wins group 0 the runtime tensor was narrower than ``warmup_model_decode``'s page_table. ``copy_host_to_device`` then tripped its shape assertion on the second decode trace replay.
``platform.py``: register ``Gemma4ForCausalLM`` / ``Gemma4ForConditionalGeneration`` arch names so vLLM resolves them to the TT bridge in ``models/demos/gemma4/tt/generator_vllm.py``. Other TT models get away without explicit registration because upstream vLLM has a torch impl that the platform layer finds first; Gemma4 has none yet.
The bridge itself lands in tt-metal (separate PR). Without that PR this plugin change is a no-op (no Gemma4 model code to dispatch to); without this plugin change the bridge can't be reached from vLLM. They go in together.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
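The width arithmetic in the commit message can be checked with a small sketch. This is illustrative only and is not code from the plugin: the `cdiv` helper and the dictionary layout are written here for the example, and it assumes that `max_num_blocks_per_req` is derived from `cache_config.block_size=64`, which the commit message implies but does not state outright.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for non-negative a and positive b."""
    return -(-a // b)


MAX_MODEL_LEN = 131072
CACHE_BLOCK_SIZE = 64  # cache_config.block_size from the commit message

# Per-group block sizes after the unifier, as quoted in the commit message.
GROUP_BLOCK_SIZES = {"sliding": 128, "full": 64}

# Native block_table width of each group.
native_widths = {
    name: cdiv(MAX_MODEL_LEN, block_size)
    for name, block_size in GROUP_BLOCK_SIZES.items()
}

# Assumption: the width the warmup trace expects comes from cache_config.block_size.
max_num_blocks_per_req = cdiv(MAX_MODEL_LEN, CACHE_BLOCK_SIZE)

# Number of columns each group is short of the expected width.
shortfall = {name: max_num_blocks_per_req - w for name, w in native_widths.items()}

print(native_widths)
print(shortfall)
```

Under these assumptions the sliding group is the only one with a non-zero shortfall, which matches the described failure when sliding occupies group 0.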
baa503d to 17b75df
Summary
Two changes required for vLLM to serve Gemma4 through the hybrid kv-cache-groups path. Companion to tt-metal PR vllm-project#44265 (the bridge itself); these go in together.
1. model_runner.py: pad block_tables_per_group[i] to max_num_blocks_per_req
vLLM's per-block-byte unifier (unify_kv_cache_spec_page_size) produces per-group block_table native widths of cdiv(max_model_len, group_block_size). For Gemma4-E2B with cache_config.block_size=64, the unifier doubles sliding's block_size to 128 so per-block bytes match full's: sliding ends up with width cdiv(131072, 128) = 1024, full with width cdiv(131072, 64) = 2048.
The legacy single-tensor view (block_tables_per_group[0]) was sliced to :max_num_blocks_per_req but never padded. On single-card runs where sliding wins group 0, the runtime tensor was narrower than the page_table that warmup_model_decode captured the trace against, and copy_host_to_device then tripped its shape assertion on the second decode trace replay. The fix pads each group's block_table up to max_num_blocks_per_req width with zeros; zeros are safe because the kernel only reads up to each layer's active block count.
The per-layer expansion in _block_tables_per_layer already padded; the legacy single-tensor view was the gap.
2. platform.py: register Gemma4 arch names
Gemma4ForCausalLM and Gemma4ForConditionalGeneration are added to the TT model registry so vLLM dispatches them to the bridge in models/demos/gemma4/tt/generator_vllm.py. Other TT models (Gemma3, GptOss, ...) get away without explicit registration because upstream vLLM has a torch impl that the platform layer finds first via the inferred registry path; Gemma4 has no upstream vLLM implementation yet, so the inspection has to land on the TT class directly.
Dependencies
Test plan
vllm serve with --model google/gemma-4-E2B-it --max_num_seqs 1 produces coherent chat completions ("Paris", working jokes/haiku).
block_table pad is a no-op when per-group native widths already equal max_num_blocks_per_req, and the registration additions don't touch existing entries.
🤖 Generated with Claude Code
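The no-op property claimed in the test plan can be illustrated with a minimal sketch of slice-then-pad on plain Python lists. This is not the plugin's implementation, which operates on tensors; `pad_block_table` and the sample rows are hypothetical names and data chosen for the example.

```python
from typing import List


def pad_block_table(rows: List[List[int]], width: int, pad_value: int = 0) -> List[List[int]]:
    """Slice each row to at most `width` entries, then right-pad with `pad_value`."""
    padded = []
    for row in rows:
        clipped = list(row[:width])
        padded.append(clipped + [pad_value] * (width - len(clipped)))
    return padded


# A group whose native width (3) is narrower than the expected width (4).
narrow = [[7, 8, 9], [4, 5, 6]]
# A group whose native width already equals the expected width.
full_width = [[1, 2, 3, 4], [5, 6, 7, 8]]

padded_narrow = pad_block_table(narrow, 4)
unchanged = pad_block_table(full_width, 4)

print(padded_narrow)
print(unchanged == full_width)
```

Rows that are already at the target width come back with the same contents, so only groups narrower than max_num_blocks_per_req are affected by the pad.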