Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@
except ImportError:
LMCacheConnectorMetadata = None

_GDN_MAMBA_TYPES: tuple[object, ...] = ("gdn_attention", "linear_attention")
try:
from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum
_GDN_MAMBA_TYPES = (MambaAttentionBackendEnum.GDN_ATTN, MambaAttentionBackendEnum.LINEAR, "gdn_attention",
"linear_attention")
except (ImportError, AttributeError):
pass

_TYPE_CACHE: dict[str, dict[str, Any]] = {}

HPU_TORCH_DTYPE_TO_STR_DTYPE = {
Expand Down Expand Up @@ -5880,8 +5888,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
if self.num_mamba_like_layers > 0 and self._compact_gdn_enabled:
self._num_gdn_groups = sum(
1 for g in kv_cache_config.kv_cache_groups
if isinstance(g.kv_cache_spec, MambaSpec) and g.kv_cache_spec.mamba_type in ("gdn_attention",
"linear_attention"))
if isinstance(g.kv_cache_spec, MambaSpec) and g.kv_cache_spec.mamba_type in _GDN_MAMBA_TYPES)
# Profiling may request more sequences than max_num_seqs
# (e.g. VLLM_PROFILE_DECODE=16,64 with max_num_seqs=1).
# Ensure GDN compact tensors and free-list are large enough.
Expand Down Expand Up @@ -5916,7 +5923,7 @@ def _needs_raw_buffer(kv_cache_tensor) -> bool:
if isinstance(spec, FullAttentionSpec):
continue
if isinstance(spec, MambaSpec) and \
spec.mamba_type in ("gdn_attention", "linear_attention"):
spec.mamba_type in _GDN_MAMBA_TYPES:
continue
# Standard Mamba2 or unknown spec — needs raw buffer
return True
Expand Down Expand Up @@ -5964,7 +5971,7 @@ def _needs_raw_buffer(kv_cache_tensor) -> bool:
vc = torch.zeros(kv_cache_shape, dtype=kv_cache_spec.dtype, device=self.device)
kv_caches[layer_name] = (kc, vc, None, None)
elif isinstance(kv_cache_spec, MambaSpec) and \
kv_cache_spec.mamba_type in ("gdn_attention", "linear_attention") and \
kv_cache_spec.mamba_type in _GDN_MAMBA_TYPES and \
self._compact_gdn_enabled:
# GDN/linear_attention: compact allocation.
# All GDN groups share the same state tensor, so each
Expand All @@ -5991,7 +5998,7 @@ def _needs_raw_buffer(kv_cache_tensor) -> bool:
kv_caches[shared_layer] = tuple(state_tensors)
break
elif isinstance(kv_cache_spec, MambaSpec) and \
kv_cache_spec.mamba_type in ("gdn_attention", "linear_attention"):
kv_cache_spec.mamba_type in _GDN_MAMBA_TYPES:
# GDN/linear_attention: non-compact (baseline) allocation
# using contiguous tensors with num_blocks+1 slots.
if isinstance(kv_caches.get(layer_name), tuple):
Expand Down Expand Up @@ -6056,7 +6063,7 @@ def _needs_raw_buffer(kv_cache_tensor) -> bool:
vc = torch.zeros(kv_cache_shape, dtype=kv_cache_spec.dtype, device=self.device)
kv_caches[layer_name] = (kc, vc, None, None)
elif isinstance(kv_cache_spec, MambaSpec) and \
kv_cache_spec.mamba_type in ("gdn_attention", "linear_attention") and \
kv_cache_spec.mamba_type in _GDN_MAMBA_TYPES and \
self._compact_gdn_enabled:
# GDN/linear_attention: compact allocation.
self._compact_gdn_group_ids.add(group_idx)
Expand All @@ -6076,7 +6083,7 @@ def _needs_raw_buffer(kv_cache_tensor) -> bool:
kv_caches[shared_layer] = tuple(state_tensors)
break
elif isinstance(kv_cache_spec, MambaSpec) and \
kv_cache_spec.mamba_type in ("gdn_attention", "linear_attention"):
kv_cache_spec.mamba_type in _GDN_MAMBA_TYPES:
# GDN/linear_attention: non-compact (baseline) allocation.
if isinstance(kv_caches.get(layer_name), tuple):
continue
Expand Down
15 changes: 5 additions & 10 deletions vllm_gaudi/v1/worker/hpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from vllm.v1.outputs import (DraftTokenIds, AsyncModelRunnerOutput, ModelRunnerOutput)
from vllm.v1.worker.utils import bind_kv_cache
from vllm_gaudi.utils import is_fake_hpu
from vllm_gaudi.v1.worker.hpu_model_runner import HPUModelRunner
from vllm_gaudi.v1.worker.hpu_model_runner import HPUModelRunner, _GDN_MAMBA_TYPES
from vllm.v1.worker.worker_base import CompilationTimes, WorkerBase

from vllm_gaudi.extension.logger import logger as init_logger
Expand Down Expand Up @@ -441,12 +441,9 @@ def determine_available_memory(self) -> int:
# Reduce reported memory so the scheduler computes fewer
# num_blocks that fit the HPU separate-allocation model.
has_attn = any(isinstance(s, FullAttentionSpec) for s in kv_cache_spec.values())
has_gdn = any(
isinstance(s, MambaSpec) and s.mamba_type in ("gdn_attention", "linear_attention")
for s in kv_cache_spec.values())
has_gdn = any(isinstance(s, MambaSpec) and s.mamba_type in _GDN_MAMBA_TYPES for s in kv_cache_spec.values())
has_standard_mamba = any(
isinstance(s, MambaSpec) and s.mamba_type not in ("gdn_attention", "linear_attention")
for s in kv_cache_spec.values())
isinstance(s, MambaSpec) and s.mamba_type not in _GDN_MAMBA_TYPES for s in kv_cache_spec.values())
compact_gdn = os.environ.get("VLLM_COMPACT_GDN", "0").strip().lower() in ("1", "true")
if has_attn and has_gdn and not compact_gdn:
# When compact GDN is OFF, GDN state scales with num_blocks
Expand All @@ -462,8 +459,7 @@ def determine_available_memory(self) -> int:
real_attn = next(s.real_page_size_bytes for s in kv_cache_spec.values() if isinstance(s, FullAttentionSpec))
real_mamba = next(
sum(math.prod(sh) * get_dtype_size(dt) for sh, dt in zip(s.shapes, s.dtypes))
for s in kv_cache_spec.values()
if isinstance(s, MambaSpec) and s.mamba_type in ("gdn_attention", "linear_attention"))
for s in kv_cache_spec.values() if isinstance(s, MambaSpec) and s.mamba_type in _GDN_MAMBA_TYPES)
total_real = real_attn + real_mamba
if total_real > padded_page:
factor = padded_page / total_real
Expand All @@ -484,8 +480,7 @@ def determine_available_memory(self) -> int:
attn_page_size = next(s.page_size_bytes for s in kv_cache_spec.values() if isinstance(s, FullAttentionSpec))
mamba_state_per_block = next(
sum(math.prod(sh) * get_dtype_size(dt) for sh, dt in zip(s.shapes, s.dtypes))
for s in kv_cache_spec.values()
if isinstance(s, MambaSpec) and s.mamba_type not in ("gdn_attention", "linear_attention"))
for s in kv_cache_spec.values() if isinstance(s, MambaSpec) and s.mamba_type not in _GDN_MAMBA_TYPES)
if attn_page_size > 0:
ratio = attn_page_size / (attn_page_size + mamba_state_per_block)
adjusted = int(available * ratio)
Expand Down
Loading