diff --git a/tests/models/vllm/test_pallas_torchax.py b/tests/models/vllm/test_pallas_torchax.py
index a207d12bd9..18b9ec6f46 100644
--- a/tests/models/vllm/test_pallas_torchax.py
+++ b/tests/models/vllm/test_pallas_torchax.py
@@ -214,21 +214,6 @@ def test_init_with_fp8_kv_cache_raises_error(self):
                 attn_type=AttentionType.DECODER,
             )
 
-    def test_init_with_blocksparse_raises_error(self):
-        with pytest.raises(ValueError,
-                           match="does not support block-sparse attention"):
-            PallasAttentionBackendImpl(
-                num_heads=32,
-                head_size=128,
-                scale=0.088,
-                num_kv_heads=8,
-                alibi_slopes=None,
-                sliding_window=None,
-                kv_cache_dtype="auto",
-                blocksparse_params={"block_size": 16},
-                attn_type=AttentionType.DECODER,
-            )
-
     def test_init_with_encoder_attention_raises_error(self):
         with pytest.raises(NotImplementedError,
                            match="Encoder self-attention"):
@@ -326,21 +311,6 @@ def test_init_with_irope_warning(self):
             "Using irope in Pallas is not supported yet, it will fall back "
             "to global attention for long context.")
 
-    def test_init_with_blocksparse_constructor_error(self):
-        with pytest.raises(ValueError,
-                           match="does not support block-sparse attention"):
-            PallasAttentionBackendImpl(
-                num_heads=32,
-                head_size=128,
-                scale=0.088,
-                num_kv_heads=8,
-                alibi_slopes=None,
-                sliding_window=None,
-                kv_cache_dtype="auto",
-                blocksparse_params={"block_size": 16},
-                attn_type=AttentionType.DECODER,
-            )
-
     @patch(
         'tpu_commons.attention.backends.pallas_torchax.ragged_paged_attention')
     @patch('tpu_commons.attention.backends.pallas_torchax.get_forward_context')
diff --git a/tpu_commons/attention/backends/pallas_torchax.py b/tpu_commons/attention/backends/pallas_torchax.py
index 54cc3f7f23..a88b99705d 100644
--- a/tpu_commons/attention/backends/pallas_torchax.py
+++ b/tpu_commons/attention/backends/pallas_torchax.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Optional
 
 import torch
 from jax.tree_util import register_pytree_node_class
@@ -132,7 +132,6 @@ def __init__(
         alibi_slopes: Optional[list[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
-        blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
@@ -142,9 +141,6 @@ def __init__(
             logger.warning_once(
                 "Using irope in Pallas is not supported yet, it will fall back "
                 "to global attention for long context.")
-        if blocksparse_params is not None:
-            raise ValueError("Paged attention Pallas kernel does "
-                             "not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
diff --git a/tpu_commons/worker/tpu_worker_torchax.py b/tpu_commons/worker/tpu_worker_torchax.py
index c665af5c6d..4204c565c3 100644
--- a/tpu_commons/worker/tpu_worker_torchax.py
+++ b/tpu_commons/worker/tpu_worker_torchax.py
@@ -292,5 +292,5 @@ def _init_tpu_worker_distributed_environment(
         backend="gloo",
     )
     ensure_model_parallel_initialized(
-        parallel_config.tensor_parallel_size,
-        parallel_config.pipeline_parallel_size)
+        self.parallel_config.tensor_parallel_size,
+        self.parallel_config.pipeline_parallel_size)