diff --git a/requirements.txt b/requirements.txt
index 9491c6f308..0fc09d7c96 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,6 @@ ray>=2.48.0
 pandas>=2.2.3
 numba>=0.58.0
 numpy>=1.26.0
-transformers>= 4.56.0, <5
+transformers >= 4.56.0, != 5.0.*, != 5.1.*, != 5.2.*, != 5.3.*, != 5.4.*, != 5.5.0, != 5.6.*
 kaldi-native-fbank >= 1.18.7
 tblib==3.1.0
diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py
index 7c66003572..da03369475 100644
--- a/tests/models/language/generation/test_common.py
+++ b/tests/models/language/generation/test_common.py
@@ -26,7 +26,6 @@ def launch_lm_eval(eval_config):
         'async_scheduling': async_scheduling,
         'enforce_eager': enforce_eager,
         'enable_prefix_caching': enable_apc,
-        'add_bos_token': True,
         'dtype': dtype,
         'max_model_len': max_model_len,
         'max_num_seqs': max_num_seqs,
diff --git a/tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py b/tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py
index 10595437e7..13e531e766 100644
--- a/tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py
+++ b/tests/unit_tests/kv_offload/offloading_connector/test_scheduler.py
@@ -91,20 +91,19 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
     runner.new_request(token_ids=[1] * offloaded_block_size)
     runner.manager.prepare_store.side_effect = (lambda block_hashes, req_context: generate_store_output([]))
     runner.run(decoded_tokens=[EOS_TOKEN_ID])
-    runner.manager.lookup.assert_called()
-    assert len(list(runner.manager.lookup.call_args.args[0])) == 1
+    runner.manager.lookup.assert_called_once()

     # single block lookup with a hit
     runner.scheduler.reset_prefix_cache()
     runner.new_request(token_ids=[0] * offloaded_block_size)
     runner.manager.prepare_store.side_effect = (lambda block_hashes, req_context: generate_store_output([]))
-    runner.manager.lookup.return_value = 1
+    runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
     runner.run(decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2))

     # single block lookup with a hit in a middle block
     runner.new_request(token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size)
     runner.manager.prepare_store.side_effect = (lambda block_hashes, req_context: generate_store_output([]))
-    runner.manager.lookup.return_value = 1
+    runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
     runner.run(decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5))

     # test take_events
@@ -182,7 +181,7 @@ def test_request_preemption(request_runner, async_scheduling: bool):

     # request should now return from preemption
     # re-load [0, ..., 8] from the CPU and store [9, 10, 11]
-    runner.manager.lookup.return_value = 3
+    runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 3
     runner.manager.prepare_store.side_effect = (lambda block_hashes, req_context: generate_store_output(block_hashes))
     runner.run(
         decoded_tokens=[0] * gpu_block_size,
@@ -219,7 +218,7 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
     # start a request to load the first block, but don't complete
     runner.scheduler.reset_prefix_cache()
     runner.new_request(token_ids=[0] * offloaded_block_size)
-    runner.manager.lookup.return_value = 1
+    runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
     runner.run(
         decoded_tokens=[],
         complete_transfers=False,
@@ -231,7 +230,7 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:

     # start a new request to load the same first block
     runner.new_request(token_ids=[0] * offloaded_block_size)
-    runner.manager.lookup.return_value = 1
+    runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
     runner.run(
         decoded_tokens=[],
         complete_transfers=False,
@@ -275,7 +274,7 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool):
     # start a request to load the first block, but don't complete
     runner.scheduler.reset_prefix_cache()
     runner.new_request(token_ids=[0] * offloaded_block_size)
-    runner.manager.lookup.return_value = 1
+    runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
     runner.run(
         decoded_tokens=[],
         complete_transfers=False,
diff --git a/tests/unit_tests/kv_offload/offloading_connector/utils.py b/tests/unit_tests/kv_offload/offloading_connector/utils.py
index 54f287e658..83e7e2c8f5 100644
--- a/tests/unit_tests/kv_offload/offloading_connector/utils.py
+++ b/tests/unit_tests/kv_offload/offloading_connector/utils.py
@@ -210,14 +210,14 @@ def __init__(self,
         self.scheduler_connector: OffloadingConnector = scheduler_connector

         # extract mocked OffloadingManager of scheduler connector
-        connector_scheduler = scheduler_connector.connector_scheduler
-        assert connector_scheduler is not None
-        manager = connector_scheduler.manager
+        self.connector_scheduler = scheduler_connector.connector_scheduler
+        assert self.connector_scheduler is not None
+        manager = self.connector_scheduler.manager
         assert isinstance(manager, MagicMock)
         self.manager: MagicMock = manager

-        assert len(connector_scheduler.config.kv_group_configs) == 1
-        kv_group_config = connector_scheduler.config.kv_group_configs[0]
+        assert len(self.connector_scheduler.config.kv_group_configs) == 1
+        kv_group_config = self.connector_scheduler.config.kv_group_configs[0]
         assert kv_group_config.gpu_block_size == gpu_block_size
         assert kv_group_config.offloaded_block_size == offloaded_block_size
diff --git a/tests/unit_tests/ops/test_hpu_compressed_tensors.py b/tests/unit_tests/ops/test_hpu_compressed_tensors.py
index 439049e90a..d6d71584f9 100644
--- a/tests/unit_tests/ops/test_hpu_compressed_tensors.py
+++ b/tests/unit_tests/ops/test_hpu_compressed_tensors.py
@@ -390,7 +390,7 @@ def test_compressed_tensors_wna16_moe_method(default_vllm_config: None, dist_ini
     mock_ctx = MagicMock(spec=["dp_metadata"])
     mock_ctx.dp_metadata = None
     with override_forward_context(mock_ctx):
-        out = oot_op.runner.forward_dispatch(oot_op, hidden_states, router_logits, hidden_states)
+        out = oot_op.runner._forward_dispatch(oot_op, hidden_states, router_logits, hidden_states)

     # Check correctness
     torch.testing.assert_close(ref_output, out, atol=1e-4, rtol=1e-4)
diff --git a/tests/unit_tests/ops/test_hpu_fused_moe.py b/tests/unit_tests/ops/test_hpu_fused_moe.py
index b0c67d5f20..6aab174f02 100644
--- a/tests/unit_tests/ops/test_hpu_fused_moe.py
+++ b/tests/unit_tests/ops/test_hpu_fused_moe.py
@@ -41,7 +41,7 @@ def test_unquantized_fused_moe_method(default_vllm_config: None, dist_init):
     mock_ctx = MagicMock(spec=["dp_metadata"])
     mock_ctx.dp_metadata = None
     with override_forward_context(mock_ctx):
-        out = oot_op.runner.forward_dispatch(oot_op, hidden_states, router_logits, hidden_states)
+        out = oot_op.runner._forward_dispatch(oot_op, hidden_states, router_logits, hidden_states)

     # Check correctness
     torch.testing.assert_close(ref_output, out, atol=1e-4, rtol=1e-4)
diff --git a/tests/unit_tests/ops/utils.py b/tests/unit_tests/ops/utils.py
index b05cef308d..3d252c64aa 100644
--- a/tests/unit_tests/ops/utils.py
+++ b/tests/unit_tests/ops/utils.py
@@ -53,7 +53,6 @@ def create_fused_moe(quant_config=None):
         hidden_size=512,
         intermediate_size=256,
         params_dtype=torch.bfloat16,
-        reduce_results=True,
         renormalize=True,
         use_grouped_topk=False,
         num_expert_group=None,
diff --git a/vllm_gaudi/__init__.py b/vllm_gaudi/__init__.py
index d382a7e515..a81d324d84 100755
--- a/vllm_gaudi/__init__.py
+++ b/vllm_gaudi/__init__.py
@@ -72,11 +72,15 @@ def register():
 def register_utils():
     """Register utility functions for the HPU platform."""
     import vllm_gaudi.utils  # noqa: F401
+
+    vllm_gaudi.utils.patch_nixl_utils_for_hpu()
     # Install the in-process EngineCore reconfigure hook only when
     # multi-model mode is requested, to avoid heavy imports for all users.
     import os
+
     if os.environ.get("VLLM_HPU_MULTI_MODEL_CONFIG"):
         from vllm_gaudi.v1.engine.core_patch import install_engine_core_patch
+
         install_engine_core_patch()
@@ -86,7 +90,8 @@ def register_ops():
     """Register custom ops for the HPU platform."""
     import vllm_gaudi.v1.sample.hpu_rejection_sampler  # noqa: F401
     import vllm_gaudi.distributed.kv_transfer.kv_connector.v1.hpu_nixl_connector  # noqa: F401
-    if os.getenv('VLLM_HPU_HETERO_KV_LAYOUT', 'false').lower() == 'true':
+
+    if os.getenv("VLLM_HPU_HETERO_KV_LAYOUT", "false").lower() == "true":
         import vllm_gaudi.distributed.kv_transfer.kv_connector.v1.hetero_hpu_nixl_connector  # noqa: F401
     import vllm_gaudi.v1.kv_offload.worker.cpu_hpu  # noqa: F401
     import vllm_gaudi.ops.hpu_attention  # noqa: F401
@@ -107,16 +112,18 @@ def register_ops():
     # Conditionally register HPURowParallelLinear only when chunking is enabled
     from vllm_gaudi.ops.hpu_row_parallel_linear import register as register_row_parallel
+
     register_row_parallel()

     # Register HPU LoRA layers only when row parallel chunking is active
-    env_value = os.environ.get('VLLM_ROW_PARALLEL_CHUNKS', '1')
+    env_value = os.environ.get("VLLM_ROW_PARALLEL_CHUNKS", "1")
     try:
         row_parallel_chunks = int(env_value)
     except ValueError:
         row_parallel_chunks = 1
     if row_parallel_chunks > 1:
         from vllm_gaudi.lora.layers.hpu_row_parallel_linear import register_hpu_lora_layers
+
         register_hpu_lora_layers()
@@ -125,4 +132,5 @@ def register_models():
     import vllm_gaudi.models.interfaces  # noqa: F401
     import vllm_gaudi.models.bert  # noqa: F401
     from .models import register_model
+
     register_model()
diff --git a/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hetero_hpu_nixl_connector.py b/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hetero_hpu_nixl_connector.py
index 9174142ed6..9adf8b468d 100644
--- a/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hetero_hpu_nixl_connector.py
+++ b/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hetero_hpu_nixl_connector.py
@@ -47,15 +47,14 @@
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     EngineId,
-    TpKVTopology,
+    TransferTopology,
     get_current_attn_backend,
     yield_req_data,
 )
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from typing import Any
-from nixl._api import nixl_agent as NixlWrapper
-from nixl._api import nixl_agent_config
+from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.request import Request
@@ -122,7 +121,7 @@ def kv_postprocess_layout_on_save(cache, indices):
     blocks_to_update = cache.index_select(0, indices)
     target_shape = blocks_to_update.shape
     # NHD => HND
-    blocks_processed = (blocks_to_update.permute(0, 2, 1, 3).contiguous().view(target_shape))
+    blocks_processed = blocks_to_update.permute(0, 2, 1, 3).contiguous().view(target_shape)
     cache.index_copy_(0, indices, blocks_processed)
@@ -136,13 +135,11 @@ def if_postprocess_kvcache_on_save(vllm_config, current_block_size, current_kv_c
     agreed_block_size = int(
         vllm_config.kv_transfer_config.get_from_extra_config("agreed_block_size", current_block_size))
     # Only allow save to smaller block size (larger required additional allocation)
-    block_size_on_save = (agreed_block_size if agreed_block_size <= current_block_size else current_block_size)
-    if (kv_cache_layout_on_save != current_kv_cache_layout or block_size_on_save != current_block_size):
+    block_size_on_save = agreed_block_size if agreed_block_size <= current_block_size else current_block_size
+    if kv_cache_layout_on_save != current_kv_cache_layout or block_size_on_save != current_block_size:
         postprocess_kv_caches_on_save = True
         logger.info(
-            "KV cache postprocess on save is enabled. "
-            "Local kv cache layout: %s -> %s, "
-            "block size: %d -> %d",
+            "KV cache postprocess on save is enabled. Local kv cache layout: %s -> %s, block size: %d -> %d",
             current_kv_cache_layout,
             kv_cache_layout_on_save,
             current_block_size,
@@ -189,12 +186,12 @@ def NixlConnectorScheduler_init_(self, vllm_config: VllmConfig, engine_id: str):
     self.kv_cache_layout = get_kv_cache_layout()
     self.engine_id: EngineId = engine_id  # type: ignore[misc]
     self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
-    self.side_channel_port = (envs.VLLM_NIXL_SIDE_CHANNEL_PORT + vllm_config.parallel_config.data_parallel_index)
+    self.side_channel_port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + vllm_config.parallel_config.data_parallel_index
     assert vllm_config.kv_transfer_config is not None
     if current_platform.device_type == "cpu":
         self.use_host_buffer = False
     else:
-        self.use_host_buffer = (vllm_config.kv_transfer_config.kv_buffer_device == "cpu")
+        self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
     self.postprocess_kv_caches_on_save = False
     self.kv_cache_layout_on_save = self.kv_cache_layout
@@ -232,8 +229,7 @@ def NixlConnectorScheduler_init_(self, vllm_config: VllmConfig, engine_id: str):
 def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
     params = request.kv_transfer_params
     logger.debug(
-        "NIXLConnector update_state_after_alloc: "
-        "num_external_tokens=%s, kv_transfer_params=%s",
+        "NIXLConnector update_state_after_alloc: num_external_tokens=%s, kv_transfer_params=%s",
         num_external_tokens,
         params,
     )
@@ -258,7 +254,7 @@ def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks",
             # If remote_blocks and num_external_tokens = 0, we have
             # a full prefix cache hit on the D worker. We need to call
             # send_notif in _read_blocks to free the memory on the P.
-            local_block_ids = (blocks.get_unhashed_block_ids() if num_external_tokens > 0 else [])
+            local_block_ids = blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
             # Get unhashed blocks to pull from remote.
             self._reqs_need_recv[request.request_id] = (
                 request,
@@ -267,8 +263,7 @@ def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks",
         else:
             logger.warning(
-                "Got invalid KVTransferParams: %s. This "
-                "request will not utilize KVTransfer",
+                "Got invalid KVTransferParams: %s. This request will not utilize KVTransfer",
                 params,
             )
     else:
@@ -355,8 +350,7 @@ def request_finished(
     params = request.kv_transfer_params
     logger.debug(
-        "NIXLConnector request_finished(%s), request_status=%s, "
-        "kv_transfer_params=%s",
+        "NIXLConnector request_finished(%s), request_status=%s, kv_transfer_params=%s",
         request.request_id,
         request.status,
         params,
@@ -394,12 +388,11 @@ def request_finished(
     if delay_free_blocks:
         # Prefill request on remote. It will be read from D upon completion
         logger.debug(
-            "NIXLConnector request_finished(%s) waiting for %d seconds "
-            "for remote decode to fetch blocks",
+            "NIXLConnector request_finished(%s) waiting for %d seconds for remote decode to fetch blocks",
            request.request_id,
            envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
         )
-        self._reqs_need_send[request.request_id] = (time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
+        self._reqs_need_send[request.request_id] = time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT

     block_size_ratio = self.block_size // self.block_size_on_save
     block_ids_on_save = block_ids
@@ -477,8 +470,7 @@ def NixlConnectorWorker_init_(self, vllm_config: VllmConfig, engine_id: str):
     if self.device_type not in _NIXL_SUPPORTED_DEVICE:
         raise RuntimeError(f"{self.device_type} is not supported.")
     elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[self.device_type]:
-        raise RuntimeError(f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
-                           "is not supported.")
+        raise RuntimeError(f"{self.device_type} with {self.kv_buffer_device} kv_buffer is not supported.")
     self.device_kv_caches: dict[str, torch.Tensor] = {}  # type: ignore[misc]

     # cpu kv buffer for xfer
@@ -498,8 +490,7 @@ def NixlConnectorWorker_init_(self, vllm_config: VllmConfig, engine_id: str):
     elif self.kv_buffer_device == "cpu":
         nixl_memory_type = "DRAM"
     if nixl_memory_type is None:
-        raise RuntimeError(f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
-                           "is not supported.")
+        raise RuntimeError(f"{self.device_type} with {self.kv_buffer_device} kv_buffer is not supported.")
     self.nixl_memory_type = nixl_memory_type

     # Note: host xfer buffer ops when use_host_buffer is True
@@ -593,17 +584,18 @@ def NixlConnectorWorker_init_(self, vllm_config: VllmConfig, engine_id: str):
     self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
     self.xfer_stats = NixlKVConnectorStats()

-    self.kv_topo = TpKVTopology(
+    self.transfer_topo = TransferTopology(
         tp_rank=self.tp_rank,
+        tp_size=self.world_size,
+        block_size=self.block_size,
         engine_id=self.engine_id,
-        remote_tp_size=self._tp_size,  # shared state
-        remote_block_size=self._block_size,  # shared state
         is_mla=self.use_mla,
+        is_mamba=False,
         total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
-        attn_backend=backend,
+        attn_backends=[backend],
     )
     self.compat_hash = compute_nixl_compatibility_hash(self.vllm_config, self.backend_name,
-                                                       self.kv_topo.cross_layers_blocks)
+                                                       self.transfer_topo.cross_layers_blocks)

     self._physical_blocks_per_logical_kv_block = 1
@@ -612,17 +604,16 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
     if self.use_host_buffer:
         self.initialize_host_xfer_buffer(kv_caches=kv_caches)
-        assert len(self.host_xfer_buffers) == len(kv_caches), (f"host_buffer: {len(self.host_xfer_buffers)}, "
-                                                               f"kv_caches: {len(kv_caches)}")
+        assert len(self.host_xfer_buffers) == len(kv_caches), (
+            f"host_buffer: {len(self.host_xfer_buffers)}, kv_caches: {len(kv_caches)}")
         xfer_buffers = self.host_xfer_buffers
     else:
         xfer_buffers = kv_caches
-        assert not self.host_xfer_buffers, ("host_xfer_buffer should not be initialized when "
-                                            f"kv_buffer_device is {self.kv_buffer_device}")
+        assert not self.host_xfer_buffers, (
+            f"host_xfer_buffer should not be initialized when kv_buffer_device is {self.kv_buffer_device}")

     logger.info(
-        "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
-        "use_host_buffer: %s",
+        "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, use_host_buffer: %s",
         self.use_mla,
         self.kv_buffer_device,
         self.use_host_buffer,
@@ -639,7 +630,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
     # (roughly 8KB vs 5KB).
     # Conversely for FlashInfer, K and V are registered in the same region
     # to better exploit the memory layout (ie num_blocks is the first dim).
-    split_k_and_v = self.kv_topo.split_k_and_v
+    split_k_and_v = self.transfer_topo.split_k_and_v
     tensor_size_bytes = None

     # TODO (NickLucche): Get kernel_block_size in a cleaner way
@@ -667,7 +658,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 self.block_size,
                 kernel_block_size,
             )
-            self._physical_blocks_per_logical_kv_block = (self.block_size // kernel_block_size)
+            self._physical_blocks_per_logical_kv_block = self.block_size // kernel_block_size
             self.block_size = kernel_block_size
             self._block_size[self.engine_id] = kernel_block_size
@@ -678,18 +669,18 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             tensor_size_bytes = curr_tensor_size_bytes
             self.num_blocks = cache.shape[0]

-        assert cache.shape[0] == self.num_blocks, ("All kv cache tensors must have the same number of blocks")
+        assert cache.shape[0] == self.num_blocks, "All kv cache tensors must have the same number of blocks"

         block_size_ratio_on_save = self.block_size // self.block_size_on_save
         self.block_len_per_layer.append(curr_tensor_size_bytes // self.num_blocks)
         self.slot_size_per_layer.append(self.block_len_per_layer[-1] // self.block_size)
         block_len_per_layer_on_save.append(curr_tensor_size_bytes // self.num_blocks // block_size_ratio_on_save)
-        num_blocks_on_save = (curr_tensor_size_bytes // block_len_per_layer_on_save[-1])
+        num_blocks_on_save = curr_tensor_size_bytes // block_len_per_layer_on_save[-1]

         if not self.use_mla:
             # Different kv cache shape is not supported by HeteroTP
-            assert tensor_size_bytes == curr_tensor_size_bytes, ("All kv cache tensors must have the same size")
+            assert tensor_size_bytes == curr_tensor_size_bytes, "All kv cache tensors must have the same size"

         # Need to make sure the device ID is non-negative for NIXL,
         # Torch uses -1 to indicate CPU tensors.
         self.device_id = max(cache.get_device(), 0)
@@ -711,7 +702,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
     self.device_kv_caches = kv_caches
     self.dst_num_blocks[self.engine_id] = self.num_blocks

-    if self.kv_topo.is_kv_layout_blocks_first:
+    if self.transfer_topo.is_kv_layout_blocks_first:
         for i in range(len(self.slot_size_per_layer)):
             assert self.slot_size_per_layer[i] % 2 == 0
             self.slot_size_per_layer[i] //= 2
@@ -790,16 +781,16 @@ def register_local_xfer_handler(
     for i, base_addr in enumerate(self.seen_base_addresses):
         # The new block_len is using prefill block_len;
         # and num_blocks is multiple with N
-        kv_block_len = (self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio)
+        kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio
         block_len_per_layer = self.block_len_per_layer[i] // block_size_ratio
-        num_blocks = (self.num_blocks * self.block_len_per_layer[i] // block_len_per_layer)
+        num_blocks = self.num_blocks * self.block_len_per_layer[i] // block_len_per_layer
         for block_id in range(num_blocks):
             block_offset = block_id * block_len_per_layer
             addr = base_addr + block_offset
             # (addr, len, device id)
             blocks_data.append((addr, kv_block_len, self.device_id))

-    if self.kv_topo.is_kv_layout_blocks_first:
+    if self.transfer_topo.is_kv_layout_blocks_first:
         # Separate and interleave K/V regions to maintain the same
         # descs ordering. This is needed for selecting contiguous heads
         # when split across TP ranks.
@@ -847,14 +838,14 @@ def post_process_device_kv_on_save(self, block_ids: list[int]):
     if len(block_ids) == 0:
         return
     target_block_size = self.block_size_on_save
-    split_k_and_v = self.kv_topo.split_k_and_v
+    split_k_and_v = self.transfer_topo.split_k_and_v
     sample_cache = list(self.device_kv_caches.values())[0][0]
     indices = torch.tensor(block_ids, device=sample_cache.device)

     for _, cache_or_caches in self.device_kv_caches.items():
         cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
         for cache in cache_list:
-            if (self.kv_cache_layout_on_save != self.kv_cache_layout and self.block_size_on_save != self.block_size):
+            if self.kv_cache_layout_on_save != self.kv_cache_layout and self.block_size_on_save != self.block_size:
                 kv_postprocess_layout_and_blksize_on_save(cache, indices, target_block_size)
             elif self.kv_cache_layout_on_save != self.kv_cache_layout:
                 kv_postprocess_layout_on_save(cache, indices)
@@ -877,7 +868,7 @@ def _read_blocks(
     Post a READ point-to-point xfer request from a single local worker to a
     single remote worker.
     """
-    block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
+    block_size_ratio = self.transfer_topo.block_size_ratio(self._block_size[dst_engine_id])
     if block_size_ratio > 1:
         # NOTE:
         # get_mapped_blocks will always expand block_ids for n times.
@@ -912,8 +903,7 @@ def _read_blocks(
         except Exception as e:
             self._log_failure(
                 failure_type="notification_failed",
-                msg="P worker blocks will be freed after timeout. "
-                "This may indicate network issues.",
+                msg="P worker blocks will be freed after timeout. This may indicate network issues.",
                 req_id=request_id,
                 error=e,
                 dst_engine_id=dst_engine_id,
diff --git a/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py b/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py
index e616e4e758..2461ee8a78 100644
--- a/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py
+++ b/vllm_gaudi/distributed/kv_transfer/kv_connector/v1/hpu_nixl_connector.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import torch
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl import NixlConnectorWorker
-from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
+from vllm.distributed.kv_transfer.kv_connector.utils import TransferTopology
 from vllm_gaudi.platform import logger

 import habana_frameworks.torch.utils.experimental as htexp
@@ -68,7 +68,7 @@ def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> Non
 NixlConnectorWorker.initialize_host_xfer_buffer = initialize_host_xfer_buffer

 # ── HPU cross-layer-block false-positive fix ───────────────────────────────── #
-# TpKVTopology.__post_init__ infers _cross_layers_blocks from tensor shape:
+# TransferTopology.__post_init__ infers _cross_layers_blocks from tensor shape:
 #   _cross_layers_blocks = (len(tensor_shape) == len(kv_cache_shape) + 1)
 # On HPU, get_kv_cache_shape() returns a 3-D shape instead of the 5-D shape
 # expected by CUDA FlashAttn. For DeepSeek MLA, the host buffer is 4-D, so
@@ -76,16 +76,16 @@ def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> Non
 # by the number of attention layers (~27× for DeepSeek-V2-Lite-Chat),
 # producing out-of-bounds NIXL transfers and KV cache corruption.
 # MLA models never use cross-layer layout, so guard the heuristic with is_mla.
-_original_tpkvtopo_post_init = TpKVTopology.__post_init__
+_original_transfer_topo_post_init = TransferTopology.__post_init__


-def _hpu_tpkvtopo_post_init(self):
-    _original_tpkvtopo_post_init(self)
+def _hpu_transfer_topo_post_init(self):
+    _original_transfer_topo_post_init(self)
     if self.is_mla and self._cross_layers_blocks:
-        logger.warning("[HPU] TpKVTopology: overriding false-positive _cross_layers_blocks=True "
+        logger.warning("[HPU] TransferTopology: overriding false-positive _cross_layers_blocks=True "
                        "for MLA model. HPU get_kv_cache_shape() returns 3-D tensors, causing "
                        "the dim-count heuristic to misfire. Forcing _cross_layers_blocks=False.")
         self._cross_layers_blocks = False


-TpKVTopology.__post_init__ = _hpu_tpkvtopo_post_init
+TransferTopology.__post_init__ = _hpu_transfer_topo_post_init
diff --git a/vllm_gaudi/models/ernie45_vl.py b/vllm_gaudi/models/ernie45_vl.py
index 7d3757a1b7..90c3ccf81d 100644
--- a/vllm_gaudi/models/ernie45_vl.py
+++ b/vllm_gaudi/models/ernie45_vl.py
@@ -1,4 +1,5 @@
 import torch
+from vllm.distributed import tensor_model_parallel_all_reduce
 from vllm.model_executor.models.ernie45_vl import (
     Ernie4_5VLMultiModalProcessor,
     Ernie4_5_VLProcessingInfo,
@@ -74,7 +75,7 @@ def ernie4_5_vlmoemoe_forward_hpu(
         final_hidden_states = final_hidden_states[1]

     if self.tp_size > 1:
-        final_hidden_states = (self.text_experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states))
+        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

     return final_hidden_states.view(orig_shape)
diff --git a/vllm_gaudi/models/qwen3_moe.py b/vllm_gaudi/models/qwen3_moe.py
index fc40f6ab90..a6b3fd4d37 100644
--- a/vllm_gaudi/models/qwen3_moe.py
+++ b/vllm_gaudi/models/qwen3_moe.py
@@ -50,11 +50,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
     if is_seq_parallel:
         out = tensor_model_parallel_all_gather(out, 0)
         out = out[:num_tokens]
-    else:
-        # from upstream : TP>1 may require a reduction here.
-        tp_size = getattr(self, "tp_size", 1)
-        if tp_size > 1 and hasattr(self.experts, "maybe_all_reduce_tensor_model_parallel"):
-            out = self.experts.maybe_all_reduce_tensor_model_parallel(out)

     return out.reshape(*orig_shape[:-1], hidden_dim)
diff --git a/vllm_gaudi/models/qwen3_next.py b/vllm_gaudi/models/qwen3_next.py
index 279cf3e01f..5abbbba8fd 100644
--- a/vllm_gaudi/models/qwen3_next.py
+++ b/vllm_gaudi/models/qwen3_next.py
@@ -88,14 +88,9 @@ def _hpu_qwen3next_sparse_moe_forward(
     router_logits, _ = self.gate(hidden_states)
     final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits)

-    if self.shared_expert is not None:
-        final_hidden_states = (final_hidden_states[0] + final_hidden_states[1])
-
     if self.is_sequence_parallel:
         final_hidden_states = tensor_model_parallel_all_gather(final_hidden_states, 0)
         final_hidden_states = final_hidden_states[:num_tokens]
-    elif self.tp_size > 1:
-        final_hidden_states = (self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states))

     return final_hidden_states.reshape(orig_shape)
diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py
index 28937c804f..fd980fe88f 100755
--- a/vllm_gaudi/ops/hpu_fused_moe.py
+++ b/vllm_gaudi/ops/hpu_fused_moe.py
@@ -4,8 +4,8 @@
 import os
 from typing import Union

-from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
-    DefaultMoERunner, )
+from vllm.model_executor.layers.fused_moe.runner.moe_runner_base import (
+    MoERunnerBase, )
 import torch
 import vllm
 import vllm.envs as envs
@@ -261,13 +261,6 @@ def forward_oot(
     return output.view(*input_shape)


-def reduce_output(self, states: torch.Tensor) -> torch.Tensor:
-    if (not self.moe_config.is_sequence_parallel and not self.use_dp_chunking and self.reduce_results
-            and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1)):
-        states = self.maybe_all_reduce_tensor_model_parallel(states)
-    return states
-
-
 def patched_fused_moe_forward(
     self,
     hidden_states: torch.Tensor,
@@ -279,6 +272,11 @@
     ensure_moe_quant_config_init, and _sequence_parallel_context — all of
     which access ForwardContext and cause torch.compile graph breaks), we
     cache the layer reference and call _forward_impl directly.
+
+    The post-forward reduction sequence mirrors upstream
+    MoERunnerBase.forward (vllm/model_executor/layers/fused_moe/runner/
+    moe_runner_base.py) so we stay in sync with the new shared/fused
+    output combination logic introduced by upstream PR #35949.
     """
     hidden_states, shared_experts_input = self.apply_routed_input_transform(hidden_states)
     hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(shared_experts_input, hidden_states)
@@ -296,11 +294,24 @@ def patched_fused_moe_forward(

     if self.gate is not None:
         router_logits, _ = self.gate(hidden_states)
-        fused_output = self._forward_impl(self._hpu_cached_layer, hidden_states, router_logits, shared_experts_input)
+        result = self._forward_impl(self._hpu_cached_layer, hidden_states, router_logits, shared_experts_input)
     else:
-        fused_output = self.forward_entry(hidden_states, router_logits, shared_experts_input, self._encode_layer_name())
+        result = self._forward_entry(hidden_states, router_logits, shared_experts_input, self._encode_layer_name())
+
+    # Mirror upstream MoERunnerBase.forward post-_forward_entry pipeline.
+    if isinstance(result, tuple):
+        shared_output, fused_output = result
+    else:
+        shared_output, fused_output = None, result
+
+    shared_output = self._maybe_reduce_shared_expert_output(shared_output)
+    shared_output, fused_output = self._maybe_apply_routed_scale_to_output(shared_output, fused_output)
+    fused_output = self.apply_routed_output_transform(fused_output)
+
+    combined = (shared_output + fused_output) if shared_output is not None else fused_output

-    return self._maybe_reduce_output(fused_output, og_hidden_dims)
+    combined = self._maybe_reduce_final_output(combined, og_hidden_dims)
+    return self._maybe_add_zero_expert_output(combined)


 def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
@@ -489,8 +500,8 @@ def create_fused_moe_router(

 # Apply patches
 # Keep runner forward patch compatible with upstream layer_name-based dispatch.
-_orig_default_moe_runner_init = DefaultMoERunner.__init__
-_orig_default_moe_runner_forward = DefaultMoERunner.forward
+_orig_default_moe_runner_init = MoERunnerBase.__init__
+_orig_default_moe_runner_forward = MoERunnerBase.forward

 # When enabled, bypasses the opaque torch.ops.vllm.moe_forward_shared custom
 # op wrapper so that torch.ops.hpu.mixture_of_experts is captured directly in
@@ -510,9 +521,9 @@ def _patched_default_moe_runner_forward(self, *args, **kwargs):
     return _orig_default_moe_runner_forward(self, *args, **kwargs)


-DefaultMoERunner.__init__ = _patched_default_moe_runner_init
+MoERunnerBase.__init__ = _patched_default_moe_runner_init

-DefaultMoERunner.forward = _patched_default_moe_runner_forward
+MoERunnerBase.forward = _patched_default_moe_runner_forward

 vllm.model_executor.layers.fused_moe.layer.get_compressed_expert_map = \
     get_compressed_expert_map
diff --git a/vllm_gaudi/utils.py b/vllm_gaudi/utils.py
index 035ce4eecb..ecceb0e4bf 100644
--- a/vllm_gaudi/utils.py
+++ b/vllm_gaudi/utils.py
@@ -5,7 +5,7 @@
 from vllm_gaudi.extension.runtime import get_config
 import vllm.v1.core.sched.async_scheduler as _async_sched_module
 from vllm_gaudi.v1.core.sched.hpu_async_scheduler import HPUAsyncScheduler
-from typing import (Any, Optional, TypeVar, Union)
+from typing import Any, Optional, TypeVar, Union
 import torch
 import habana_frameworks.torch as htorch
 import numpy as np
@@ -18,18 +18,18 @@

 @cache
 def is_fake_hpu() -> bool:
-    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'
+    return os.environ.get("VLLM_USE_FAKE_HPU", "0") != "0"


 @cache
 def hpu_device_string():
-    device_string = 'hpu' if not is_fake_hpu() else 'cpu'
+    device_string = "hpu" if not is_fake_hpu() else "cpu"
     return device_string


 @cache
 def hpu_backend_string():
-    backend_string = 'hccl' if not is_fake_hpu() else 'gloo'
+    backend_string = "hccl" if not is_fake_hpu() else "gloo"
     return backend_string
@@ -37,7 +37,7 @@ def has_quant_config(model_config: ModelConfig) -> bool:
     return model_config.quantization == "inc" or os.getenv("QUANT_CONFIG", None) is not None


-def async_h2d_copy(source, dest_tensor=None, dtype=None, device='hpu'):
+def async_h2d_copy(source, dest_tensor=None, dtype=None, device="hpu"):
     """
     Asynchronously transfer data from host to device.
@@ -55,18 +55,17 @@ def async_h2d_copy(source, dest_tensor=None, dtype=None, device='hpu'):
             # Copy into pre-allocated destination tensor
             return dest_tensor.copy_(source, non_blocking=True)
         # Create new device tensor and copy
-        assert source.device.type == 'cpu', \
-            "Source tensor must be on CPU for asynchronous transfer"
+        assert source.device.type == "cpu", "Source tensor must be on CPU for asynchronous transfer"
         target = torch.empty_like(source, device=device)
         return target.copy_(source, non_blocking=True)

     # Create tensor from data and transfer to device
     if dtype is None:
         raise ValueError("dtype must be specified when source is not a tensor")
-    cpu_tensor = torch.tensor(source, dtype=dtype, device='cpu')
+    cpu_tensor = torch.tensor(source, dtype=dtype, device="cpu")
     return cpu_tensor.to(device, non_blocking=True)


-def async_h2d_update(source: torch.Tensor, dest: torch.Tensor, indices: list[int], device='hpu'):
+def async_h2d_update(source: torch.Tensor, dest: torch.Tensor, indices: list[int], device="hpu"):
     """
     Asynchronously update specific rows of a device tensor from a CPU tensor.
@@ -170,7 +169,7 @@ def make_tensor_with_pad_align(
     """
     Make a padded tensor from 2D inputs.
     The padding is applied to the end of each inner list until it reaches
-    max_len_aligned, max_len_aligned is max_len rounding to the nearest
+    max_len_aligned, max_len_aligned is max_len rounded to the nearest
     `max_len_align`.
     """
     np_dtype = torch_utils.TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
@@ -243,7 +242,7 @@ def make_mrope_positions_tensor_with_pad(input_positions: list[list[int]], input
                                      max_len=max_prompt_len,
                                      pad=0,
                                      dtype=torch.long,
-                                     device='cpu').flatten()
+                                     device="cpu").flatten()
     # Otherwise, Qwen2.5-VL expects positions in a (3, seq_len)
     # we are going to pad each seq_data in the list
     # using either MRope values or regular position
@@ -253,10 +252,9 @@
             positions = input_mrope_position[idx] if input_mrope_position is not None else input_positions[b_idx]
             padding_size = max_prompt_len - len(positions)
             assert padding_size >= 0
-            padded_positions = positions \
-                + (max_prompt_len - len(positions)) * [pad]
+            padded_positions = positions + (max_prompt_len - len(positions)) * [pad]
             mrope_input_positions[idx].extend(padded_positions)
-    return torch.tensor(mrope_input_positions, dtype=torch.long, device='cpu')
+    return torch.tensor(mrope_input_positions, dtype=torch.long, device="cpu")


 class HPUCompileConfig:
@@ -272,10 +270,8 @@ def __init__(self, fullgraph: Optional[bool] = None, dynamic: Optional[bool] = N
        Env variables should not be overwritten when it comes
        to compilation of the whole model.
        """
-        self.fullgraph = fullgraph if fullgraph is not None else \
-            get_config().fullgraph_compilation
-        self.dynamic = dynamic if dynamic is not None else \
-            get_config().dynamic_shapes_compilation
+        self.fullgraph = fullgraph if fullgraph is not None else get_config().fullgraph_compilation
+        self.dynamic = dynamic if dynamic is not None else get_config().dynamic_shapes_compilation
         self.regional_compilation = get_config().regional_compilation
@@ -284,9 +280,38 @@ def get_compile_args(self) -> dict[str, Any]:
        with torch.compile method or decorator
        """
         if self.dynamic:
-            return {'backend': 'hpu_backend', 'fullgraph': self.fullgraph, 'options': {"force_static_compile": True}}
+            return {"backend": "hpu_backend", "fullgraph": self.fullgraph, "options": {"force_static_compile": True}}
         else:
-            return {'backend': 'hpu_backend', 'fullgraph': self.fullgraph, 'dynamic': False}
+            return {"backend": "hpu_backend", "fullgraph": self.fullgraph, "dynamic": False}


 _async_sched_module.AsyncScheduler = HPUAsyncScheduler  # type: ignore[misc]
+
+
+def patch_nixl_utils_for_hpu():
+    """Patch vllm.distributed.nixl_utils to use nixl._api instead of rixl._api.
+
+    Upstream vLLM gates NIXL imports on is_cuda(), falling back to rixl._api
+    for all other platforms. HPU needs nixl._api (same as CUDA), so we
+    monkey-patch the module-level symbols before anything else imports them.
+    """
+    import logging
+
+    logger = logging.getLogger(__name__)
+    try:
+        from nixl._api import nixl_agent as _NixlWrapper
+        from nixl._api import nixl_agent_config as _nixl_agent_config
+    except ImportError:
+        return
+    try:
+        from nixl._bindings import nixlXferTelemetry as _nixlXferTelemetry
+    except ImportError:
+        _nixlXferTelemetry = None  # type: ignore[assignment]
+
+    import vllm.distributed.nixl_utils as _nixl_mod
+
+    _nixl_mod.NixlWrapper = _NixlWrapper  # type: ignore[attr-defined]
+    _nixl_mod.nixl_agent_config = _nixl_agent_config  # type: ignore[attr-defined]
+    if _nixlXferTelemetry is not None:
+        _nixl_mod.nixlXferTelemetry = _nixlXferTelemetry  # type: ignore[attr-defined]
+    logger.info("Patched vllm.distributed.nixl_utils for HPU (nixl._api)")