diff --git a/.buildkite/scripts/setup_docker_env.sh b/.buildkite/scripts/setup_docker_env.sh
index 738d8a37dc..f9560dd819 100644
--- a/.buildkite/scripts/setup_docker_env.sh
+++ b/.buildkite/scripts/setup_docker_env.sh
@@ -81,9 +81,6 @@ setup_environment() {
         TPU_INFERENCE_HASH="$BUILDKITE_COMMIT"
     fi
 
-    # TODO (ranlihao): unpin after the upstream is stable.
-    VLLM_COMMIT_HASH="0dd5dee9b9bc88453f5f3eacfde751e6b9ba4871"
-
     docker build \
         --build-arg VLLM_COMMIT_HASH="${VLLM_COMMIT_HASH}" \
         --build-arg IS_FOR_V7X="${IS_FOR_V7X:-false}" \
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 21714ad6e3..4a149e745e 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -14,8 +14,7 @@
 ARG BASE_IMAGE="python:3.12-slim-bookworm"
 
 # The latest main will be used if arg unspecified
-# TODO (ranlihao): unpin after the upstream is stable.
-ARG VLLM_COMMIT_HASH="0dd5dee9b9bc88453f5f3eacfde751e6b9ba4871"
+ARG VLLM_COMMIT_HASH=""
 
 FROM $BASE_IMAGE
 
diff --git a/tests/e2e/test_speculative_decoding.py b/tests/e2e/test_speculative_decoding.py
index 8be5797765..af31adbf36 100644
--- a/tests/e2e/test_speculative_decoding.py
+++ b/tests/e2e/test_speculative_decoding.py
@@ -207,7 +207,7 @@ def _test_performance_helper(
     del ref_llm
 
     # Waiting for TPUs to be released
-    time.sleep(10)
+    time.sleep(30)
 
     # Test speculative LLM timing with max_num_seqs=1
     spec_llm = LLM(model=model_name,
@@ -223,7 +223,7 @@ def _test_performance_helper(
     del spec_llm
 
     # Waiting for TPUs to be released
-    time.sleep(10)
+    time.sleep(30)
 
     speedup = ref_time / spec_time
     print(f"Reference LLM time: {ref_time:.2f}s")
diff --git a/tests/layers/vllm/test_attention.py b/tests/layers/vllm/test_attention.py
index 4275b37642..9e4b14288e 100644
--- a/tests/layers/vllm/test_attention.py
+++ b/tests/layers/vllm/test_attention.py
@@ -116,7 +116,7 @@ def mesh():
 class TestPallasAttentionBackend:
 
     def test_get_name(self):
-        assert PallasAttentionBackend.get_name() == "PALLAS"
+        assert PallasAttentionBackend.get_name() == "FLASH_ATTN"
 
     def test_get_impl_cls(self):
         assert PallasAttentionBackend.get_impl_cls(
diff --git a/tests/worker/tpu_worker_test.py b/tests/worker/tpu_worker_test.py
index 54d7710a7b..185c6955f0 100644
--- a/tests/worker/tpu_worker_test.py
+++ b/tests/worker/tpu_worker_test.py
@@ -174,9 +174,7 @@ def test_init_device_autodetects_devices(
                                          expected_is_last_rank)
 
     @patch('tpu_inference.worker.tpu_worker.utils')
-    @patch('tpu_inference.worker.tpu_worker.jax')
-    def test_determine_available_memory(self, mock_jax, mock_utils,
-                                        mock_vllm_config):
+    def test_determine_available_memory(self, mock_utils, mock_vllm_config):
         """Tests the available HBM memory calculation."""
         # Setup mock return for hbm_usage_bytes: [(used_bytes, limit_bytes), ...]
         mock_utils.hbm_usage_bytes.return_value = [
@@ -188,11 +186,9 @@ def test_determine_available_memory(self, mock_jax, mock_utils,
                            rank=0,
                            distributed_init_method="test_method",
                            devices=mock_devices)
-        mock_jax.local_devices.return_value = mock_devices
 
         available_mem = worker.determine_available_memory()
 
-        mock_jax.local_devices.assert_called_once()
         mock_utils.hbm_usage_bytes.assert_called_once_with(mock_devices)
         # Total limit: 1000 + 1000 = 2000 GiB
         # Total cap: 2000 * 0.9 = 1800 GiB
diff --git a/tpu_inference/layers/vllm/attention.py b/tpu_inference/layers/vllm/attention.py
index 7709757182..707bb81180 100644
--- a/tpu_inference/layers/vllm/attention.py
+++ b/tpu_inference/layers/vllm/attention.py
@@ -11,6 +11,10 @@
 from torchax.ops.mappings import t2j
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
+from vllm.attention.backends.registry import (AttentionBackendEnum,
+                                              register_backend)
+from vllm.config import VllmConfig
+from vllm.utils.math_utils import cdiv, next_power_of_2
 
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
@@ -22,17 +26,79 @@
 
 logger = init_logger(__name__)
 
+# TPU requires the head size to be a multiple of 128.
+TPU_HEAD_SIZE_ALIGNMENT = 128
+
 
+@register_backend(AttentionBackendEnum.FLASH_ATTN)
 class PallasAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "PALLAS"
+        return "FLASH_ATTN"
 
     @staticmethod
     def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
         return PallasAttentionBackendImpl
 
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+        cache_dtype_str: str = "auto",
+    ) -> tuple[int, ...]:
+        padded_head_size = (cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) *
+                            TPU_HEAD_SIZE_ALIGNMENT)
+        return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        raise RuntimeError("swap_blocks is not used for the TPU backend.")
+
+    # In recent TPU generations, up to v6e, the SMEM size is 1MB. The
+    # block_tables within the PallasMetadata constitute almost the entire SMEM
+    # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
+    # we simply make sure that the size is smaller than half of SMEM capacity.
+    @staticmethod
+    def get_min_page_size(vllm_config: VllmConfig) -> int:
+        max_num_page_per_req = (1024 * 1024 // 2 //
+                                vllm_config.scheduler_config.max_num_seqs // 4)
+        min_page_size = cdiv(vllm_config.model_config.max_model_len,
+                             max_num_page_per_req)
+        min_page_size = 1 << (min_page_size - 1).bit_length()
+        return min_page_size
+
+    @staticmethod
+    def get_max_num_seqs(model_len: int, page_size: int) -> int:
+        num_page_per_req = cdiv(model_len, page_size)
+        return 1024 * 1024 // 2 // num_page_per_req // 4
+
+    # TPU has limited SREGs (scalar registers), if page_size is too small, we
+    # can spill SREGs easily which leads to bad performance. The strategy we
+    # apply here is trying to split max-model-len to 16 pages which make the
+    # spill less likely. Meanwhile we make sure the page size is in [16, 256].
+    @staticmethod
+    def get_page_size(vllm_config: VllmConfig) -> int:
+        # TODO: This is a temporary fix for vmem OOM.
+        # For long model length, we use 16 page-size to avoid too much
+        # VMEM spill. A more robust solution should be implemented to
+        # handle VREG spills.
+        if vllm_config.model_config.max_model_len > 8192:
+            return 16
+        page_size = next_power_of_2(
+            vllm_config.model_config.max_model_len) // 16
+        if page_size <= 16:
+            return 16
+        if page_size >= 256:
+            return 256
+        return page_size
+
 
 class PallasAttentionBackendImpl(AttentionImpl):
 
diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
index 1caf311a5c..44682ed7a1 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -19,7 +19,8 @@
 from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
-from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
+                                                  FusedMoERouter)
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
     CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)
 
@@ -185,7 +186,8 @@ def process_fp8_moe_weights(
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
index 7e5aa796c5..3ccc2a5a2b 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Callable
 from typing import Optional
 
 import jax
@@ -26,6 +27,10 @@
     CompressedTensorsW8A8Int8
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
     convert_to_channelwise
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           ModelWeightParameter,
+                                           PerTensorScaleParameter)
 
 from tpu_inference.layers.common.utils import \
     slice_sharded_tensor_for_concatenation
@@ -50,6 +55,62 @@ def __init__(self, strategy: str, is_static_input_scheme: bool,
         self.linear_config = linear_config
         self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL)
 
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        output_partition_sizes: list[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        layer.logical_widths = output_partition_sizes
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(sum(output_partition_sizes),
+                             input_size_per_partition,
+                             dtype=torch.int8),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1),
+                                 dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+        else:
+            assert self.strategy == QuantizationStrategy.TENSOR
+            weight_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes),
+                                 dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE
+        if self.is_static_input_scheme:
+            input_scale = BasevLLMParameter(data=torch.empty(
+                1, dtype=torch.float32),
+                                            weight_loader=weight_loader)
+            layer.register_parameter("input_scale", input_scale)
+
+            if not self.input_symmetric:
+                # Note: compressed-tensors stores the zp using the same dtype
+                # as the weights
+                # AZP loaded as int8 but used as int32
+                input_zero_point = BasevLLMParameter(
+                    data=torch.empty(1, dtype=torch.int8),
+                    weight_loader=weight_loader)
+                layer.register_parameter("input_zero_point", input_zero_point)
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight = t2j(layer.weight, use_dlpack=False)
         delattr(layer, "weight")
diff --git a/tpu_inference/layers/vllm/quantization/fp8.py b/tpu_inference/layers/vllm/quantization/fp8.py
index 9a34aac406..09b3b794a9 100644
--- a/tpu_inference/layers/vllm/quantization/fp8.py
+++ b/tpu_inference/layers/vllm/quantization/fp8.py
@@ -22,7 +22,7 @@
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoERouter
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
@@ -329,7 +329,8 @@ def process_fp8_moe_weights(
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
diff --git a/tpu_inference/layers/vllm/quantization/mxfp4.py b/tpu_inference/layers/vllm/quantization/mxfp4.py
index 0d9c2b7bbd..2ebcd783f5 100644
--- a/tpu_inference/layers/vllm/quantization/mxfp4.py
+++ b/tpu_inference/layers/vllm/quantization/mxfp4.py
@@ -25,7 +25,8 @@
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
-                                                        FusedMoEMethodBase)
+                                                        FusedMoEMethodBase,
+                                                        FusedMoERouter)
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
@@ -210,7 +211,8 @@ def process_mxfp4_moe_weights(
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
diff --git a/tpu_inference/layers/vllm/quantization/unquantized.py b/tpu_inference/layers/vllm/quantization/unquantized.py
index decfd4a3dc..ee61ba5e9c 100644
--- a/tpu_inference/layers/vllm/quantization/unquantized.py
+++ b/tpu_inference/layers/vllm/quantization/unquantized.py
@@ -23,7 +23,7 @@
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEConfig, UnquantizedFusedMoEMethod)
+    FusedMoE, FusedMoEConfig, FusedMoERouter, UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import \
@@ -283,7 +283,8 @@ def process_unquantized_moe_weights(
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index b9bc52d2b6..50ea86a78e 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -56,11 +56,14 @@ def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
                              **kwargs) -> str:
         from vllm.attention.backends.registry import AttentionBackendEnum
 
-        if selected_backend != AttentionBackendEnum.PALLAS:
-            logger.info("Cannot use %s backend on TPU.", selected_backend)
-
-        logger.info("Using Pallas V1 backend.")
-        return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
+        # Invoke @register_backend in the module.
+        import tpu_inference.layers.vllm.attention  # noqa: F401
+        if selected_backend != AttentionBackendEnum.FLASH_ATTN:
+            logger.info("Cannot use %s backend on TPU. Setting to FLASH_ATTN.",
+                        selected_backend)
+            selected_backend = AttentionBackendEnum.FLASH_ATTN
+        logger.info("Using %s backend.", selected_backend.name)
+        return selected_backend.get_path()
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
@@ -144,9 +147,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
 
-        # TODO(cuiq): remove this dependency.
         if vllm_config.model_config:
-            from vllm.v1.attention.backends.pallas import \
+            from tpu_inference.layers.vllm.attention import \
                 PallasAttentionBackend
             cache_config.block_size = PallasAttentionBackend.get_page_size(
                 vllm_config)  # type: ignore[assignment]
diff --git a/tpu_inference/runner/input_batch.py b/tpu_inference/runner/input_batch.py
index 79f28ebb0c..839a38f642 100644
--- a/tpu_inference/runner/input_batch.py
+++ b/tpu_inference/runner/input_batch.py
@@ -11,7 +11,6 @@
 from vllm.sampling_params import SamplingType
 from vllm.utils.collection_utils import swap_dict_values
 from vllm.v1.core.sched.output import NewRequestData
-from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
 
 from tpu_inference.runner.block_table import MultiGroupBlockTable
 
@@ -177,10 +176,6 @@ def add_request(
 
         sampling_params = request.sampling_params
 
-        if (self.is_spec_decode
-                and is_spec_decode_unsupported(sampling_params)):
-            self.spec_decode_unsupported_reqs.add(req_id)
-
         if sampling_params.sampling_type == SamplingType.GREEDY:
             # Avoid later division by zero.
             self.temperature_cpu[req_index] = -1.0
diff --git a/tpu_inference/runner/speculative_decoding_manager.py b/tpu_inference/runner/speculative_decoding_manager.py
index be9b77ad24..08ee78e9f9 100644
--- a/tpu_inference/runner/speculative_decoding_manager.py
+++ b/tpu_inference/runner/speculative_decoding_manager.py
@@ -72,10 +72,8 @@ def propose_draft_token_ids(
             assert isinstance(self.runner.drafter, NgramProposer)
             self._draft_token_ids = self.runner.drafter.propose(
                 sampled_token_ids[:self.runner.input_batch.num_reqs],
-                self.runner.input_batch.req_ids,
                 self.runner.input_batch.num_tokens_no_spec,
-                self.runner.input_batch.token_ids_cpu,
-                self.runner.input_batch.spec_decode_unsupported_reqs)
+                self.runner.input_batch.token_ids_cpu)
         elif self.runner.speculative_config.method == "eagle3":
             self._draft_token_ids = self.propose_eagle3_draft_token_ids(
                 sampled_token_ids,
diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py
index ac0aa5f80d..9bd823d734 100644
--- a/tpu_inference/worker/tpu_worker.py
+++ b/tpu_inference/worker/tpu_worker.py
@@ -266,7 +266,7 @@ def initialize_pp_transfer_connect(self):
 
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
-        hbm_usage = utils.hbm_usage_bytes(jax.local_devices())
+        hbm_usage = utils.hbm_usage_bytes(self.devices)
         total_hbm_limit = total_hbm_used = 0
         for used, limit in hbm_usage:
             total_hbm_used += used