Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .buildkite/scripts/setup_docker_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,6 @@ setup_environment() {
TPU_INFERENCE_HASH="$BUILDKITE_COMMIT"
fi

# TODO (ranlihao): unpin after the upstream is stable.
VLLM_COMMIT_HASH="0dd5dee9b9bc88453f5f3eacfde751e6b9ba4871"

docker build \
--build-arg VLLM_COMMIT_HASH="${VLLM_COMMIT_HASH}" \
--build-arg IS_FOR_V7X="${IS_FOR_V7X:-false}" \
Expand Down
3 changes: 1 addition & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

ARG BASE_IMAGE="python:3.12-slim-bookworm"
# The latest main will be used if arg unspecified
# TODO (ranlihao): unpin after the upstream is stable.
ARG VLLM_COMMIT_HASH="0dd5dee9b9bc88453f5f3eacfde751e6b9ba4871"
ARG VLLM_COMMIT_HASH=""

FROM $BASE_IMAGE

Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/test_speculative_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _test_performance_helper(
del ref_llm

# Waiting for TPUs to be released
time.sleep(10)
time.sleep(30)

# Test speculative LLM timing with max_num_seqs=1
spec_llm = LLM(model=model_name,
Expand All @@ -223,7 +223,7 @@ def _test_performance_helper(

del spec_llm
# Waiting for TPUs to be released
time.sleep(10)
time.sleep(30)

speedup = ref_time / spec_time
print(f"Reference LLM time: {ref_time:.2f}s")
Expand Down
2 changes: 1 addition & 1 deletion tests/layers/vllm/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def mesh():
class TestPallasAttentionBackend:

def test_get_name(self):
assert PallasAttentionBackend.get_name() == "PALLAS"
assert PallasAttentionBackend.get_name() == "FLASH_ATTN"

def test_get_impl_cls(self):
assert PallasAttentionBackend.get_impl_cls(
Expand Down
6 changes: 1 addition & 5 deletions tests/worker/tpu_worker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ def test_init_device_autodetects_devices(
expected_is_last_rank)

@patch('tpu_inference.worker.tpu_worker.utils')
@patch('tpu_inference.worker.tpu_worker.jax')
def test_determine_available_memory(self, mock_jax, mock_utils,
mock_vllm_config):
def test_determine_available_memory(self, mock_utils, mock_vllm_config):
"""Tests the available HBM memory calculation."""
# Setup mock return for hbm_usage_bytes: [(used_bytes, limit_bytes), ...]
mock_utils.hbm_usage_bytes.return_value = [
Expand All @@ -188,11 +186,9 @@ def test_determine_available_memory(self, mock_jax, mock_utils,
rank=0,
distributed_init_method="test_method",
devices=mock_devices)
mock_jax.local_devices.return_value = mock_devices

available_mem = worker.determine_available_memory()

mock_jax.local_devices.assert_called_once()
mock_utils.hbm_usage_bytes.assert_called_once_with(mock_devices)
# Total limit: 1000 + 1000 = 2000 GiB
# Total cap: 2000 * 0.9 = 1800 GiB
Expand Down
68 changes: 67 additions & 1 deletion tpu_inference/layers/vllm/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from torchax.ops.mappings import t2j
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.registry import (AttentionBackendEnum,
register_backend)
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv, next_power_of_2

from tpu_inference import utils
from tpu_inference.layers.common.attention_interface import attention
Expand All @@ -22,17 +26,79 @@

logger = init_logger(__name__)

# TPU requires the head size to be a multiple of 128.
TPU_HEAD_SIZE_ALIGNMENT = 128


@register_backend(AttentionBackendEnum.FLASH_ATTN)
class PallasAttentionBackend(AttentionBackend):
    """TPU attention backend, registered as ``AttentionBackendEnum.FLASH_ATTN``.

    Besides the name/implementation lookups, this class holds the helpers
    that size the KV cache on TPU: cache shape, page size and the sequence
    limit derived from the SMEM budget.
    """

    @staticmethod
    def get_name() -> str:
        """Return the backend name.

        Must agree with the enum member passed to ``@register_backend``.
        """
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
        """Return the implementation class used by this backend."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape with the head size padded for TPU.

        Args:
            num_blocks: Number of cache blocks.
            block_size: Number of slots per block.
            num_kv_heads: Number of KV heads; doubled in the returned shape.
            head_size: Unpadded head size.
            cache_dtype_str: Accepted for interface compatibility; not used
                by this method.

        Returns:
            ``(num_blocks, block_size, num_kv_heads * 2, padded_head_size)``
            where ``padded_head_size`` is ``head_size`` rounded up to a
            multiple of ``TPU_HEAD_SIZE_ALIGNMENT``.
        """
        aligned_chunks = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT)
        padded_head_size = aligned_chunks * TPU_HEAD_SIZE_ALIGNMENT
        return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Always raise: block swapping is not used by the TPU backend.

        Raises:
            RuntimeError: On every call.
        """
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    # In recent TPU generations, up to v6e, the SMEM size is 1MB. The
    # block_tables within the PallasMetadata constitute almost the entire SMEM
    # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
    # we simply make sure that the size is smaller than half of SMEM capacity.
    @staticmethod
    def get_min_page_size(vllm_config: VllmConfig) -> int:
        """Return the smallest power-of-two page size fitting the SMEM budget.

        Reads ``scheduler_config.max_num_seqs`` and
        ``model_config.max_model_len`` from ``vllm_config``.
        """
        # Half of the 1MB SMEM, in bytes.
        smem_budget_bytes = 1024 * 1024 // 2
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        # Each block-table entry takes 4 bytes.
        # NOTE(review): for max_num_seqs above 131072 this becomes 0 and the
        # cdiv below presumably divides by zero -- confirm whether such
        # configurations are rejected earlier.
        max_num_page_per_req = smem_budget_bytes // max_num_seqs // 4
        min_page_size = cdiv(vllm_config.model_config.max_model_len,
                             max_num_page_per_req)
        # Round up to the next power of two.
        return 1 << (min_page_size - 1).bit_length()

    @staticmethod
    def get_max_num_seqs(model_len: int, page_size: int) -> int:
        """Return how many sequences fit in half of SMEM.

        Args:
            model_len: Model length in tokens.
            page_size: Tokens per page.

        Returns:
            Half of the 1MB SMEM divided by the per-request block-table
            size (``cdiv(model_len, page_size)`` entries of 4 bytes each).
        """
        smem_budget_bytes = 1024 * 1024 // 2
        num_page_per_req = cdiv(model_len, page_size)
        return smem_budget_bytes // num_page_per_req // 4

    # TPU has limited SREGs (scalar registers); if page_size is too small, we
    # can spill SREGs easily, which leads to bad performance. The strategy we
    # apply here is trying to split max-model-len into 16 pages, which makes
    # the spill less likely. Meanwhile we make sure the page size is in
    # [16, 256].
    @staticmethod
    def get_page_size(vllm_config: VllmConfig) -> int:
        """Return the page size derived from ``model_config.max_model_len``.

        Returns 16 when ``max_model_len`` exceeds 8192; otherwise
        ``next_power_of_2(max_model_len) // 16`` clamped to ``[16, 256]``.
        """
        max_model_len = vllm_config.model_config.max_model_len
        # TODO: This is a temporary fix for vmem OOM.
        # For long model length, we use 16 page-size to avoid too much
        # VMEM spill. A more robust solution should be implemented to
        # handle VREG spills.
        if max_model_len > 8192:
            return 16
        page_size = next_power_of_2(max_model_len) // 16
        # Clamp into [16, 256].
        return min(max(page_size, 16), 256)


class PallasAttentionBackendImpl(AttentionImpl):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from torch.nn.parameter import Parameter
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import t2j
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoERouter)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)

Expand Down Expand Up @@ -185,7 +186,8 @@ def process_fp8_moe_weights(

def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable
from typing import Optional

import jax
Expand All @@ -26,6 +27,10 @@
CompressedTensorsW8A8Int8
from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
convert_to_channelwise
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)

from tpu_inference.layers.common.utils import \
slice_sharded_tensor_for_concatenation
Expand All @@ -50,6 +55,62 @@ def __init__(self, strategy: str, is_static_input_scheme: bool,
self.linear_config = linear_config
self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL)

def create_weights(
    self,
    layer: torch.nn.Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: torch.dtype,
    weight_loader: Callable,
    **kwargs,
):
    """Register the int8 weight and its quantization parameters on ``layer``.

    Parameters are registered in this order: ``weight``, ``weight_scale``,
    then, for the static input scheme, ``input_scale`` and (when the input
    is not symmetric) ``input_zero_point``.

    Args:
        layer: Module that receives the parameters; its ``logical_widths``
            attribute is set to ``output_partition_sizes``.
        output_partition_sizes: Output sizes of the logical partitions;
            their sum is the weight's output dimension.
        input_size_per_partition: Input dimension of the weight.
        params_dtype: Accepted for interface compatibility; not used here
            (the weight is always allocated as ``torch.int8``).
        weight_loader: Loader callable attached to every parameter.
        **kwargs: Ignored.
    """
    layer.logical_widths = output_partition_sizes
    total_output_size = sum(output_partition_sizes)

    # WEIGHT
    weight = ModelWeightParameter(
        data=torch.empty(total_output_size,
                         input_size_per_partition,
                         dtype=torch.int8),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight", weight)

    # WEIGHT SCALE
    if self.strategy == QuantizationStrategy.CHANNEL:
        # One scale per output channel.
        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((total_output_size, 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
    else:
        assert self.strategy == QuantizationStrategy.TENSOR
        # One scale per logical partition.
        weight_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes),
                             dtype=torch.float32),
            weight_loader=weight_loader,
        )
    layer.register_parameter("weight_scale", weight_scale)

    # INPUT SCALE
    if self.is_static_input_scheme:
        input_scale = BasevLLMParameter(
            data=torch.empty(1, dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("input_scale", input_scale)

        # NOTE(review): the zero point is registered only under the static
        # input scheme -- confirm this nesting against the original file.
        if not self.input_symmetric:
            # Note: compressed-tensors stores the zp using the same dtype
            # as the weights
            # AZP loaded as int8 but used as int32
            input_zero_point = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.int8),
                weight_loader=weight_loader)
            layer.register_parameter("input_zero_point", input_zero_point)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = t2j(layer.weight, use_dlpack=False)
delattr(layer, "weight")
Expand Down
5 changes: 3 additions & 2 deletions tpu_inference/layers/vllm/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import t2j
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoERouter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import \
register_quantization_config
Expand Down Expand Up @@ -329,7 +329,8 @@ def process_fp8_moe_weights(

def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
Expand Down
6 changes: 4 additions & 2 deletions tpu_inference/layers/vllm/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
FusedMoEMethodBase,
FusedMoERouter)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization import \
register_quantization_config
Expand Down Expand Up @@ -210,7 +211,8 @@ def process_mxfp4_moe_weights(

def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
Expand Down
5 changes: 3 additions & 2 deletions tpu_inference/layers/vllm/quantization/unquantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torchax.ops.mappings import t2j
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, UnquantizedFusedMoEMethod)
FusedMoE, FusedMoEConfig, FusedMoERouter, UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import \
Expand Down Expand Up @@ -283,7 +283,8 @@ def process_unquantized_moe_weights(

def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
Expand Down
16 changes: 9 additions & 7 deletions tpu_inference/platforms/tpu_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,14 @@ def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
**kwargs) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum

if selected_backend != AttentionBackendEnum.PALLAS:
Comment thread
weiyu0824 marked this conversation as resolved.
logger.info("Cannot use %s backend on TPU.", selected_backend)

logger.info("Using Pallas V1 backend.")
return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
# Invoke @register_backend in the module.
import tpu_inference.layers.vllm.attention # noqa: F401
if selected_backend != AttentionBackendEnum.FLASH_ATTN:
logger.info("Cannot use %s backend on TPU. Setting to FLASH_ATTN.",
selected_backend)
selected_backend = AttentionBackendEnum.FLASH_ATTN
logger.info("Using %s backend.", selected_backend.name)
return selected_backend.get_path()

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
Expand Down Expand Up @@ -144,9 +147,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if compilation_config.backend == "":
compilation_config.backend = "openxla"

# TODO(cuiq): remove this dependency.
if vllm_config.model_config:
from vllm.v1.attention.backends.pallas import \
from tpu_inference.layers.vllm.attention import \
PallasAttentionBackend
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config) # type: ignore[assignment]
Expand Down
5 changes: 0 additions & 5 deletions tpu_inference/runner/input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from vllm.sampling_params import SamplingType
from vllm.utils.collection_utils import swap_dict_values
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported

from tpu_inference.runner.block_table import MultiGroupBlockTable

Expand Down Expand Up @@ -177,10 +176,6 @@ def add_request(

sampling_params = request.sampling_params

if (self.is_spec_decode
and is_spec_decode_unsupported(sampling_params)):
self.spec_decode_unsupported_reqs.add(req_id)

if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
Expand Down
4 changes: 1 addition & 3 deletions tpu_inference/runner/speculative_decoding_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,8 @@ def propose_draft_token_ids(
assert isinstance(self.runner.drafter, NgramProposer)
self._draft_token_ids = self.runner.drafter.propose(
sampled_token_ids[:self.runner.input_batch.num_reqs],
self.runner.input_batch.req_ids,
self.runner.input_batch.num_tokens_no_spec,
self.runner.input_batch.token_ids_cpu,
self.runner.input_batch.spec_decode_unsupported_reqs)
self.runner.input_batch.token_ids_cpu)
elif self.runner.speculative_config.method == "eagle3":
self._draft_token_ids = self.propose_eagle3_draft_token_ids(
sampled_token_ids,
Expand Down
2 changes: 1 addition & 1 deletion tpu_inference/worker/tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def initialize_pp_transfer_connect(self):

def determine_available_memory(self) -> int:
gpu_memory_utilization = self.cache_config.gpu_memory_utilization
hbm_usage = utils.hbm_usage_bytes(jax.local_devices())
hbm_usage = utils.hbm_usage_bytes(self.devices)
total_hbm_limit = total_hbm_used = 0
for used, limit in hbm_usage:
total_hbm_used += used
Expand Down