Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 40 additions & 23 deletions plugins/vllm-tt-plugin/src/vllm_tt_plugin/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,22 +365,20 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
# group structure during input prep / forward.
self.kv_cache_config = kv_cache_config

# Build the persistent input batch. Upstream's hybrid kv cache
# manager enforces a uniform page_size across groups, and to keep
# the existing block-table plumbing intact we additionally require
# a single block_size across groups for now. Different block_sizes
# per group will require the input batch + block table consumers
# to learn to address per-group block tables with mixed block
# sizes.
block_size = kv_cache_groups[0].kv_cache_spec.block_size
assert all(g.kv_cache_spec.block_size == block_size for g in kv_cache_groups), (
"Mixed block sizes across kv_cache_groups not yet supported"
)

# One block table per group, all with the same block_size. The
# MultiGroupBlockTable then has ``len(kv_cache_groups)`` entries
# — single-group models keep the previous shape (length 1).
per_group_block_sizes = [block_size] * len(kv_cache_groups)
# Upstream's hybrid kv cache manager equalises *page size*
# (block_size × num_kv_heads × head_size × dtype_bytes) across
# groups, not block_size itself: when groups have different
# ``num_kv_heads × head_size`` (e.g. Gemma4's full layers use
# head_dim=512 vs sliding head_dim=256), upstream's
# ``unify_kv_cache_spec_page_size`` adjusts ``block_size`` per
# spec instead. Use each group's own ``block_size`` here; the
# input batch / MultiGroupBlockTable already takes a per-group
# list. ``self.cache_config.block_size`` (the user-specified
# value) is still used elsewhere for per-request bounds — that's
# the smaller of the unified sizes, which conservatively
# overestimates for the larger-block groups (extra block-table
# rows allocated, never indexed).
per_group_block_sizes = [g.kv_cache_spec.block_size for g in kv_cache_groups]

max_num_reqs = self.scheduler_config.max_num_seqs
max_model_len = self.model_config.max_model_len
Expand Down Expand Up @@ -841,13 +839,32 @@ def _prepare_model_inputs(
assert num_reqs > 0

# Second dim of each block table is (ceil(max_model_len / block_size)).
# Slice to self.max_num_blocks_per_req which also takes into
# account max num blocks in KV cache in case max KV blocks is smaller.
# Constant shape is required for ttnn tracing to work.
block_tables_per_group = [
bt.get_cpu_tensor()[:num_reqs, : self.max_num_blocks_per_req]
for bt in input_batch.block_table.block_tables
]
# Slice/pad to self.max_num_blocks_per_req: slicing handles
# over-wide tables (a group's native width can exceed the global
# cap when ``max_num_blocks_per_req`` is bound by total KV cache
# size rather than max_model_len), and padding handles
# under-wide ones (hybrid kv-cache-groups with unified page sizes
# produce per-group block_tables of different native widths —
# e.g. Gemma4-E2B with ``cache_config.block_size=64`` ends up
# with sliding's group at 128 block_size and full's at 64,
# giving widths cdiv(max_model_len, 128) and
# cdiv(max_model_len, 64) respectively). The TT side captures
# decode traces against ``max_num_blocks_per_req`` (see
# ``warmup_model_decode``) and ``copy_host_to_device`` asserts
# shape-equality on replay, so runtime block_tables must match
# that width even when their underlying group is narrower.
target_width = self.max_num_blocks_per_req
block_tables_per_group = []
for bt in input_batch.block_table.block_tables:
bt_cpu = bt.get_cpu_tensor()[:num_reqs, :target_width]
if bt_cpu.shape[1] < target_width:
pad = torch.zeros(
bt_cpu.shape[0],
target_width - bt_cpu.shape[1],
dtype=bt_cpu.dtype,
)
bt_cpu = torch.cat([bt_cpu, pad], dim=1)
block_tables_per_group.append(bt_cpu)

# DP optimization: don't send padding blocks if possible to reduce
# overhead from gathering inputs to rank 0 and rely on DP concat
Expand Down
27 changes: 27 additions & 0 deletions plugins/vllm-tt-plugin/src/vllm_tt_plugin/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,33 @@ def register_tt_models(register_test_models=False) -> None:
"models.tt_transformers.tt.generator_vllm:Gemma3ForConditionalGeneration",
)

# Gemma4 — text-only TT bridge.
#
# Gemma4 isn't in vLLM's upstream registry, so without an entry here
# the upstream architecture resolver falls back to
# ``TransformersMultiModalForCausalLM`` (because ``hf_config !=
# hf_text_config`` for Gemma4's nested config — see
# ``ModelConfig._get_transformers_backend_cls``) and crashes on the
# ``_processor_factory`` assertion in the multimodal registry. The
# plugin's later ``TT``-prefix logic runs after that resolution, so
# it can't help.
#
# We register the plain HF arch names directly so upstream resolution
# finds our class. Since ``Gemma4ForCausalLM`` (the TT class) does not
# use ``SupportsMultiModal``, vLLM's ``_model_info.supports_multimodal``
# is False, ``multimodal_config`` is not populated, and the request
# path stays text-only — which matches what the TT model implements.
# The ``TT``-prefixed aliases satisfy the plugin's later validation
# in ``check_and_update_config`` so no override is needed.
_gemma4_target = "models.demos.gemma4.tt.generator_vllm:Gemma4ForCausalLM"
for arch in (
"Gemma4ForCausalLM",
"Gemma4ForConditionalGeneration",
"TTGemma4ForCausalLM",
"TTGemma4ForConditionalGeneration",
):
_register_model_if_missing(ModelRegistry, arch, _gemma4_target)

# DeepseekV3
_register_model_if_missing(
ModelRegistry,
Expand Down
Loading