Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 0 additions & 30 deletions tests/models/vllm/test_pallas_torchax.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,21 +214,6 @@ def test_init_with_fp8_kv_cache_raises_error(self):
attn_type=AttentionType.DECODER,
)

def test_init_with_blocksparse_raises_error(self):
with pytest.raises(ValueError,
match="does not support block-sparse attention"):
PallasAttentionBackendImpl(
num_heads=32,
head_size=128,
scale=0.088,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
blocksparse_params={"block_size": 16},
attn_type=AttentionType.DECODER,
)

def test_init_with_encoder_attention_raises_error(self):
with pytest.raises(NotImplementedError,
match="Encoder self-attention"):
Expand Down Expand Up @@ -326,21 +311,6 @@ def test_init_with_irope_warning(self):
"Using irope in Pallas is not supported yet, it will fall back "
"to global attention for long context.")

def test_init_with_blocksparse_constructor_error(self):
with pytest.raises(ValueError,
match="does not support block-sparse attention"):
PallasAttentionBackendImpl(
num_heads=32,
head_size=128,
scale=0.088,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
blocksparse_params={"block_size": 16},
attn_type=AttentionType.DECODER,
)

@patch(
'tpu_commons.attention.backends.pallas_torchax.ragged_paged_attention')
@patch('tpu_commons.attention.backends.pallas_torchax.get_forward_context')
Expand Down
6 changes: 1 addition & 5 deletions tpu_commons/attention/backends/pallas_torchax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional
from typing import Optional

import torch
from jax.tree_util import register_pytree_node_class
Expand Down Expand Up @@ -132,7 +132,6 @@ def __init__(
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
Expand All @@ -142,9 +141,6 @@ def __init__(
logger.warning_once(
"Using irope in Pallas is not supported yet, it will fall back "
"to global attention for long context.")
if blocksparse_params is not None:
raise ValueError("Paged attention Pallas kernel does "
"not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
Expand Down
4 changes: 2 additions & 2 deletions tpu_commons/worker/tpu_worker_torchax.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,5 +292,5 @@ def _init_tpu_worker_distributed_environment(
backend="gloo",
)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
Loading