diff --git a/tests/distributed/omni_connectors/test_kv_flow.py b/tests/distributed/omni_connectors/test_kv_flow.py index 2bb06d4e00d..8c7ff79ca54 100644 --- a/tests/distributed/omni_connectors/test_kv_flow.py +++ b/tests/distributed/omni_connectors/test_kv_flow.py @@ -1,194 +1,251 @@ -import unittest -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - +import pytest import torch -from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline +from tests.utils import hardware_test from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.distributed.omni_connectors.kv_transfer_manager import ( + OmniKVCacheConfig, + OmniKVTransferManager, +) from vllm_omni.inputs.data import OmniDiffusionSamplingParams -from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner - - -class MockGPUARModelRunner(GPUARModelRunner): - """Subclass to bypass heavy initialization.""" - - def __init__(self, kv_caches, input_batch): - self.kv_caches = kv_caches - self.input_batch = input_batch - self.device = "cpu" - self.cache_config = MagicMock() - self.cache_config.block_size = 16 - self.cache_config.cache_dtype = "auto" - self.logger = MagicMock() - -class MockBagelPipeline(BagelPipeline): - """Subclass to bypass heavy initialization.""" +class MockConnector: def __init__(self): - self.device = "cpu" - self.od_config = MagicMock() - self.bagel = MagicMock() - self.tokenizer = MagicMock() - self.language_model = MagicMock() - self.new_token_ids = {} - - -class TestKVFlow(unittest.TestCase): - def setUp(self): - # Common constants - self.num_layers = 2 - self.num_heads = 4 - self.head_dim = 16 - self.block_size = 8 - self.x = 8 # PagedAttention factor - self.seq_len = 20 - self.req_id = "req_test_1" - - def test_sender_extraction_logic(self): - """Test extraction logic in GPUARModelRunner.""" - if GPUARModelRunner is object: - self.skipTest("vLLM not installed") - - # 1. Setup KV Cache (List of tuples (K, V)) - # Shape: [num_blocks, block_size, num_heads, head_dim] matching 4D expectation - num_blocks = 10 - kv_caches = [] - for _ in range(self.num_layers): - k_cache = torch.randn(num_blocks, self.block_size, self.num_heads, self.head_dim) - v_cache = torch.randn(num_blocks, self.block_size, self.num_heads, self.head_dim) - # Stack K and V to create [2, num_blocks, block_size, n_heads, head_dim] - layer_cache = torch.stack([k_cache, v_cache], dim=0) - kv_caches.append(layer_cache) - - # 2. Setup Input Batch Mock - block_ids = [1, 3, 5] - mock_input_batch = MagicMock() - mock_input_batch.req_id_to_index = {self.req_id: 0} - mock_input_batch.block_table.get_row.return_value = block_ids - - # 3. Instantiate Runner - runner = MockGPUARModelRunner(kv_caches, mock_input_batch) - - # 4. Run Extraction - # calling _extract_kv_cache directly: (req_id, block_ids, seq_len) - result = runner._extract_kv_cache(self.req_id, block_ids, self.seq_len) - - # 5. Verify Result - self.assertIsNotNone(result) - self.assertEqual(result.request_id, self.req_id) - - # Check keys "key_cache" and "value_cache" exist (length 2) - self.assertEqual(len(result.layer_blocks["key_cache"]), 2) - self.assertEqual(len(result.layer_blocks["value_cache"]), 2) - - # Check Tensor Shape: [seq_len, num_heads, head_dim] - # result.layer_blocks["key_cache"] is a list of tensors - expected_shape = (self.seq_len, self.num_heads, self.head_dim) - self.assertEqual(result.layer_blocks["key_cache"][0].shape, expected_shape) - - return result # Return for use in next test - - def test_receiver_injection_logic(self): - """Test injection logic in BagelPipeline.""" - if BagelPipeline is object: - self.skipTest("vLLM not installed") - - # 1. Get Data from Sender Test (simulate transfer) - # Re-create data manually to be independent - key_cache = [] - value_cache = [] - expected_shape = (self.seq_len, self.num_heads, self.head_dim) - for i in range(self.num_layers): - key_cache.append(torch.randn(expected_shape)) - value_cache.append(torch.randn(expected_shape)) - - layer_blocks = {"key_cache": key_cache, "value_cache": value_cache} - - transfer_data = MagicMock() # Mock KVCacheTransferData - transfer_data.layer_blocks = layer_blocks - transfer_data.metadata = {"kv_lens": [self.seq_len], "ropes": [0]} - - # 2. Setup Request with Injected Data - sp = OmniDiffusionSamplingParams( - past_key_values=SimpleNamespace(**layer_blocks), - kv_metadata=transfer_data.metadata, - ) - - req = OmniDiffusionRequest(["test"], sp) - - # 3. Setup Pipeline - pipeline = MockBagelPipeline() - - # Mock Bagel's NaiveCache - class RealNaiveCache: # Minimal impl - def __init__(self, n): - self.key_cache = {i: None for i in range(n)} - self.value_cache = {i: None for i in range(n)} - - pipeline.bagel.config.llm_config.num_hidden_layers = self.num_layers - - captured_context = {} - - def mock_prepare_prompts(curr_kvlens, curr_rope, **kwargs): - # Capture the state passed to prepare_prompts - captured_context["kv_lens"] = curr_kvlens - captured_context["ropes"] = curr_rope - return {}, [0], [0] - - pipeline.bagel.prepare_prompts = MagicMock(side_effect=mock_prepare_prompts) - with patch("vllm_omni.diffusion.models.bagel.pipeline_bagel.NaiveCache") as MockNaiveCacheCls: - # Setup the instance returned by the constructor - mock_cache_instance = RealNaiveCache(self.num_layers) - MockNaiveCacheCls.return_value = mock_cache_instance - - # Verification of Injection Logic (Simulation) - current_cache = RealNaiveCache(self.num_layers) - - # --- Logic from Source Code --- - injected_kv = req.sampling_params.past_key_values - if isinstance(current_cache, RealNaiveCache) and hasattr(injected_kv, "key_cache"): - # Assuming injected_kv is SimpleNamespace or object with list attrs - for layer_idx in range(len(injected_kv.key_cache)): - if injected_kv.key_cache[layer_idx] is not None: - k_tensor = injected_kv.key_cache[layer_idx] - v_tensor = injected_kv.value_cache[layer_idx] - - if k_tensor.device != pipeline.device: - k_tensor = k_tensor.to(pipeline.device) - if v_tensor.device != pipeline.device: - v_tensor = v_tensor.to(pipeline.device) - - current_cache.key_cache[layer_idx] = k_tensor - current_cache.value_cache[layer_idx] = v_tensor - - self.assertTrue(torch.allclose(current_cache.key_cache[0], layer_blocks["key_cache"][0])) - self.assertTrue(torch.allclose(current_cache.value_cache[1], layer_blocks["value_cache"][1])) - - def test_integration(self): - """Simulate the flow from Sender -> Connector (Dict) -> Receiver.""" - if GPUARModelRunner is object or BagelPipeline is object: - self.skipTest("vLLM not installed") - - # 1. Sender (Extraction) - runner_test_result = self.test_sender_extraction_logic() # Get KVCacheTransferData - - # 2. Connector (Serialize/Deserialize Simulation) - # KVCacheTransferData has to_dict - data_dict = runner_test_result.to_dict() - - # 3. Receiver (Request Setup) - sp = OmniDiffusionSamplingParams( - past_key_values=data_dict["layer_blocks"], - kv_metadata=data_dict["metadata"], - ) - req = OmniDiffusionRequest(["integration_test"], sp) # noqa: F841 - - # 4. Receiver (Injection Simulation) - # Use the logic verification again - pass - - -if __name__ == "__main__": - unittest.main() + self.store = {} + + def put(self, from_stage, to_stage, put_key, data): + # The manager now passes full key as put_key + key = f"{from_stage}->{to_stage}:{put_key}" + self.store[key] = data + return True, len(str(data)), None # (success, size, metadata) + + def get(self, from_stage, to_stage, get_key, metadata=None): + # The manager now passes full key as get_key + key = f"{from_stage}->{to_stage}:{get_key}" + if key in self.store: + return self.store[key], len(str(self.store[key])) + return None + + +@pytest.fixture +def mock_connector(): + return MockConnector() + + +@pytest.fixture +def kv_config(): + return OmniKVCacheConfig( + connector_config={"type": "mock"}, + from_stage="stage1", + to_stage="stage2", + stage_id="stage2", # Acting as receiver for some tests + need_recv_cache=True, + need_send_cache=True, + recv_timeout=1.0, # Short timeout for tests + ) + + +@pytest.fixture +def common_constants(): + return { + "num_layers": 2, + "num_heads": 4, + "head_dim": 16, + "block_size": 8, + "seq_len": 20, + "req_id": "req_test_1", + } + + +@pytest.mark.cache +@hardware_test( + res={"cuda": "L4"}, + num_cards=2, +) +def test_manager_extraction(kv_config, mock_connector, common_constants): + """Test extraction and sending logic in OmniKVTransferManager.""" + num_layers = common_constants["num_layers"] + block_size = common_constants["block_size"] + num_heads = common_constants["num_heads"] + head_dim = common_constants["head_dim"] + seq_len = common_constants["seq_len"] + req_id = common_constants["req_id"] + + num_blocks = 10 + kv_caches = [] + for _ in range(num_layers): + k_cache = torch.randn(num_blocks, block_size, num_heads, head_dim) + v_cache = torch.randn(num_blocks, block_size, num_heads, head_dim) + # Stack K and V to create [2, num_blocks, block_size, n_heads, head_dim] + layer_cache = torch.stack([k_cache, v_cache], dim=0) + kv_caches.append(layer_cache) + + block_ids = [1, 3, 5] + finished_reqs = {req_id: {"block_ids": block_ids, "seq_len": seq_len}} + + manager = OmniKVTransferManager(kv_config) + # Mock the connector factory or injection + manager._connector = mock_connector + + processed = manager.handle_finished_requests_kv_transfer(finished_reqs, kv_caches, block_size, "float32") + + assert req_id in processed + + # Check if data was put into connector + # Manager builds full key: omni_{from}_to_{to}_kv_cache_{req_id} + full_request_id = f"omni_stage1_to_stage2_kv_cache_{req_id}" + expected_key = f"stage1->stage2:{full_request_id}" + assert expected_key in mock_connector.store + + data = mock_connector.store[expected_key] + assert data["request_id"] == req_id + assert "layer_blocks" in data + assert len(data["layer_blocks"]["key_cache"]) == num_layers + + # Verify shape of extracted tensor: [seq_len, heads, dim] + # Note: Manager detaches and moves to CPU + expected_shape = (seq_len, num_heads, head_dim) + assert data["layer_blocks"]["key_cache"][0].shape == expected_shape + + +@pytest.mark.cache +@hardware_test( + res={"cuda": "L4"}, + num_cards=2, +) +def test_manager_reception(kv_config, mock_connector, common_constants): + """Test reception and injection logic in OmniKVTransferManager.""" + num_layers = common_constants["num_layers"] + block_size = common_constants["block_size"] + num_heads = common_constants["num_heads"] + head_dim = common_constants["head_dim"] + seq_len = common_constants["seq_len"] + req_id = common_constants["req_id"] + + expected_shape = (seq_len, num_heads, head_dim) + key_cache = [torch.randn(expected_shape) for _ in range(num_layers)] + value_cache = [torch.randn(expected_shape) for _ in range(num_layers)] + + layer_blocks = {"key_cache": key_cache, "value_cache": value_cache} + metadata = { + "block_size": block_size, + "num_layers": num_layers, + "dtype": "float32", + "seq_len": seq_len, + } + + data_to_receive = { + "request_id": req_id, + "layer_blocks": layer_blocks, + "metadata": metadata, + "block_ids": [], + } + + # In setUp, from_stage="stage1", stage_id="stage2". recv_stages=("stage1", "stage2") + + manager = OmniKVTransferManager(kv_config) + manager._connector = mock_connector + + # Pre-populate connector with data + # Manager builds full key: omni_{from}_to_{to}_kv_cache_{req_id} + full_request_id = f"omni_stage1_to_stage2_kv_cache_{req_id}" + store_key = f"stage1->stage2:{full_request_id}" + mock_connector.store[store_key] = data_to_receive + + req = OmniDiffusionRequest( + prompts=["test_recv"], + sampling_params=OmniDiffusionSamplingParams(), + request_ids=[req_id], + ) + # req.need_kv_receive = True # Implicitly handled by receive_kv_cache check? No, manager doesn't check it, runner does. + # But receive_kv_cache in manager checks request_id. Which we need to fix in manager next. + success = manager.receive_kv_cache(req, target_device=torch.device("cpu")) + + assert success + assert hasattr(req, "past_key_values") + assert hasattr(req, "kv_metadata") + + assert len(req.past_key_values.key_cache) == num_layers + assert torch.allclose(req.past_key_values.key_cache[0], key_cache[0]) + assert req.kv_metadata["seq_len"] == seq_len + + +@pytest.mark.cache +@hardware_test( + res={"cuda": "L4"}, + num_cards=2, +) +def test_integration_flow(common_constants): + """Simulate extraction -> connector -> reception.""" + num_layers = common_constants["num_layers"] + block_size = common_constants["block_size"] + num_heads = common_constants["num_heads"] + head_dim = common_constants["head_dim"] + req_id = common_constants["req_id"] + + sender_config = OmniKVCacheConfig( + connector_config={"type": "mock"}, from_stage="sender", to_stage="receiver", need_send_cache=True + ) + sender_manager = OmniKVTransferManager(sender_config) + connector = MockConnector() + sender_manager._connector = connector # Shared connector + + # Create Data + num_blocks = 5 + kv_caches = [] + for _ in range(num_layers): + layer = torch.randn(2, num_blocks, block_size, num_heads, head_dim) + kv_caches.append(layer) + + finished_reqs = {req_id: {"block_ids": [0, 1], "seq_len": 10}} + + # Send + sender_manager.handle_finished_requests_kv_transfer(finished_reqs, kv_caches, block_size, "float32") + + receiver_config = OmniKVCacheConfig( + connector_config={"type": "mock"}, + from_stage="sender", + stage_id="receiver", + need_recv_cache=True, + recv_timeout=1.0, + ) + receiver_manager = OmniKVTransferManager(receiver_config) + receiver_manager._connector = connector # Share the same mock connector instance + + req = OmniDiffusionRequest( + prompts=["test_integ"], + sampling_params=OmniDiffusionSamplingParams(), + request_ids=[req_id], + ) + + # Receive + success = receiver_manager.receive_kv_cache(req) + + # Verify + assert success + assert req.past_key_values is not None + assert req.kv_metadata["seq_len"] == 10 + + +@pytest.mark.cache +@hardware_test( + res={"cuda": "L4"}, + num_cards=2, +) +def test_manager_extraction_no_connector(kv_config, common_constants): + """Test extraction when connector is unavailable (should still return IDs).""" + block_size = common_constants["block_size"] + req_id = common_constants["req_id"] + + manager = OmniKVTransferManager(kv_config) + # Force connector to be None + manager._connector = None + manager.config.connector_config = None + finished_reqs = {req_id: {"block_ids": [1, 2], "seq_len": 10}} + + processed = manager.handle_finished_requests_kv_transfer( + finished_reqs, kv_caches=[], block_size=block_size, cache_dtype="float32" + ) + + assert req_id in processed diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 7335162b0ee..b60d2c2f5b9 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -28,8 +28,7 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.offload import apply_offload_hooks from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory -from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec +from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager logger = init_logger(__name__) @@ -62,10 +61,9 @@ def __init__( self.device = device self.pipeline = None self.cache_backend = None - self.connector = None - # Initialize OmniConnector after vllm_config is available (via init_device_and_model) - self._init_omni_connector() + # Initialize KV cache manager for connector management + self.kv_transfer_manager = OmniKVTransferManager.from_od_config(od_config) def load_model( self, @@ -138,126 +136,6 @@ def get_memory_context(): logger.info("Model runner: Initialization complete.") - def _init_omni_connector(self) -> None: - # TODO(wzliu)! get real connector from yaml file instead of hardcode - """Initialize OmniConnector for KV cache transfer.""" - try: - connector_config = None - - # 1. Try to get from omni_kv_config (injected from YAML) - # Use self.od_config because self.vllm_config is a dummy VllmConfig without model_config - if self.od_config.omni_kv_config: - connector_config = self.od_config.omni_kv_config.get("connector_config") - - if not connector_config: - logger.warning("No OmniConnector config found, skipping initialization") - return - - logger.info(f"Initializing OmniConnector with config: {connector_config}") - - c_type = connector_config.get("type") - if not c_type: - logger.error("Connector config missing 'type'") - return - - c_extra = {k: v for k, v in connector_config.items() if k != "type"} - connector_spec = ConnectorSpec(name=c_type, extra=c_extra) - - self.connector = OmniConnectorFactory.create_connector(connector_spec) - - except Exception as e: - logger.error(f"Failed to initialize OmniConnector: {e}") - import traceback - - traceback.print_exc() - - def _receive_kv_cache_for_request(self, req: OmniDiffusionRequest) -> None: - """Receive KV cache for a request via OmniConnector.""" - # TODO(wzliu)! must get control info from stage queue instead of hardcode - if not req.request_ids: - logger.warning("Request has no ID, cannot receive KV cache") - return - request_id = req.request_ids[0] - - try: - logger.info(f"Attempting to receive KV cache for request {request_id}") - - # TODO: Key used for transfer (must match sender side) - # key = f"kv_cache_{req.request_id}" - - # Get data from connector - # Determine from_stage and to_stage dynamically - omni_kv_config = self.od_config.omni_kv_config - stage_id = omni_kv_config.get("stage_id") - engine_input_source = omni_kv_config.get("engine_input_source", []) - - to_stage = stage_id - # Default to stage_id - 1 if input source is not explicit - if engine_input_source: - from_stage = engine_input_source[0] - elif isinstance(stage_id, int): - from_stage = stage_id - 1 - else: - raise ValueError("Invalid stage id") - logger.info(f"Wait for KV cache for request {request_id} from stage {from_stage} to {to_stage}...") - - # Check if we should receive KV cache based on config - need_recv_cache = omni_kv_config.get("need_recv_cache", False) - if need_recv_cache: - # Default timeout 30 seconds to prevent infinite hanging - timeout = omni_kv_config.get("recv_timeout", 30.0) - start_time = time.time() - - while True: - get_key = f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}" - result = self.connector.get( - from_stage=from_stage, - to_stage=to_stage, - get_key=get_key, - ) - if result: - break - - if time.time() - start_time > timeout: - logger.error(f"Timeout waiting for KV cache for request {request_id} after {timeout}s") - result = None - break - - time.sleep(0.5) - else: - logger.info(f"Skip receiving KV cache for {request_id} (need_recv_cache=False)") - result = None - - if result: - data, size = result - logger.info(f"Successfully received KV cache for {request_id}") - - # Assume data structure matches KVCacheTransferData.to_dict() - if isinstance(data, dict) and "layer_blocks" in data: - # Get layer blocks and ensure they are on the correct device - layer_blocks = data["layer_blocks"] - - # Move tensors to GPU if needed (OmniSerializer should handle tensor reconstruction) - for cache_list in [layer_blocks["key_cache"], layer_blocks["value_cache"]]: - for i, tensor in enumerate(cache_list): - if isinstance(tensor, torch.Tensor) and tensor.device != self.pipeline.device: - cache_list[i] = tensor.to(self.pipeline.device).contiguous() - from types import SimpleNamespace - - req.sampling_params.past_key_values = SimpleNamespace(**layer_blocks) - - if "metadata" in data: - req.sampling_params.kv_metadata = data["metadata"] - - else: - logger.warning(f"No KV cache received for {request_id} (timeout or empty)") - - except Exception as e: - logger.error(f"Error receiving KV cache for {request_id}: {e}") - import traceback - - traceback.print_exc() - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights into the pipeline.""" return self.pipeline.load_weights(weights) @@ -277,9 +155,8 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: if len(req.prompts) == 0: raise ValueError("Cannot execute model with empty request list") - # [Omni] KV Cache Receiving Logic - if req.sampling_params.need_kv_receive and self.connector is not None: - self._receive_kv_cache_for_request(req) + # The manager handles the check for need_recv_cache internally + self.kv_transfer_manager.receive_kv_cache(req, target_device=getattr(self.pipeline, "device", None)) if req.sampling_params.generator is None and req.sampling_params.seed is not None: req.sampling_params.generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) diff --git a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py new file mode 100644 index 00000000000..82c06fdafee --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unified OmniConnector and KV cache transfer management.""" + +import time +from collections.abc import Callable +from dataclasses import asdict, dataclass +from typing import Any + +import torch +from vllm.logger import init_logger + +from .factory import OmniConnectorFactory +from .utils.config import ConnectorSpec + +logger = init_logger(__name__) + + +@dataclass +class OmniKVCacheConfig: + """Configuration for OmniKVTransferManager.""" + + connector_config: dict[str, Any] | None = None + from_stage: str | None = None + to_stage: str | None = None + stage_id: str | int | None = None + engine_input_source: list[str | int] | None = None + need_recv_cache: bool = False + need_send_cache: bool = False + recv_timeout: float = 30.0 + + +@dataclass +class KVCacheTransferData: + """Container for KV cache transfer data.""" + + request_id: str + layer_blocks: dict[str, Any] + block_ids: list[int] + metadata: dict[str, Any] + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + +class OmniKVTransferManager: + """Unified management for OmniConnector and KV cache transfer. + + This class encapsulates all KV cache related operations: + - Connector initialization and lazy creation + - KV cache extraction from GPU blocks + - KV cache transfer with retry logic + - KV cache receiving with timeout + """ + + def __init__(self, config: OmniKVCacheConfig): + self.config = config + self._connector = None + + # Pre-calculate send stages (from_stage, to_stage) + self.send_stages = ( + (str(config.from_stage), str(config.to_stage)) if config.from_stage and config.to_stage else (None, None) + ) + + # Pre-calculate receive stages (from_stage, to_stage) + recv_from = config.from_stage + if config.engine_input_source: + recv_from = config.engine_input_source[0] + elif isinstance(config.stage_id, int): + recv_from = config.stage_id - 1 + + self.recv_stages = ( + (str(recv_from), str(config.stage_id)) + if recv_from is not None and config.stage_id is not None + else (None, None) + ) + + @classmethod + def _create(cls, cfg: dict | None) -> "OmniKVTransferManager": + """Create manager from raw config dict.""" + if not cfg or not isinstance(cfg, dict): + return cls(OmniKVCacheConfig()) + return cls( + OmniKVCacheConfig( + connector_config=cfg.get("connector_config"), + from_stage=cfg.get("omni_from_stage"), + to_stage=cfg.get("omni_to_stage"), + stage_id=cfg.get("stage_id"), + engine_input_source=cfg.get("engine_input_source", []), + need_recv_cache=cfg.get("need_recv_cache", False), + need_send_cache=cfg.get("need_send_cache", False), + recv_timeout=cfg.get("recv_timeout", 30.0), + ) + ) + + @classmethod + def from_model_config(cls, config: Any) -> "OmniKVTransferManager": + """Create from model config (for AR model runner).""" + return cls._create(getattr(config, "omni_kv_config", None)) + + @classmethod + def from_od_config(cls, config: Any) -> "OmniKVTransferManager": + """Create from OmniDiffusion config (for diffusion runner).""" + return cls._create(getattr(config, "omni_kv_config", None)) + + @classmethod + def from_vllm_config(cls, vllm_config: Any, model_config: Any) -> "OmniKVTransferManager": + """Create from vllm config with fallback to kv_transfer_config.""" + # Primary: omni_kv_config from model_config + omni_kv = getattr(model_config, "omni_kv_config", None) + if isinstance(omni_kv, dict): + return cls._create(omni_kv) + + # Fallback: check kv_transfer_config + kv_cfg = getattr(vllm_config, "kv_transfer_config", None) + if kv_cfg: + direct = getattr(kv_cfg, "omni_connector_config", None) + if isinstance(direct, dict) and direct: + return cls._create({"connector_config": direct}) + extra = getattr(kv_cfg, "kv_connector_extra_config", None) + if isinstance(extra, dict): + omni = extra.get("omni_connector_config") + if isinstance(omni, dict) and omni: + return cls._create({"connector_config": omni}) + + return cls(OmniKVCacheConfig()) + + @property + def connector(self): + """Lazy initialization of connector.""" + # If a previous initialization attempt failed, don't retry on every access. + if self._connector is False: + return None + + if self._connector is None: + cfg = self.config.connector_config + if cfg and (c_type := cfg.get("type")): + try: + logger.info(f"Initializing OmniConnector with config: {cfg}") + c_extra = {k: v for k, v in cfg.items() if k != "type"} + self._connector = OmniConnectorFactory.create_connector(ConnectorSpec(name=c_type, extra=c_extra)) + except Exception as e: + logger.error(f"Failed to initialize OmniConnector: {e}") + import traceback + + traceback.print_exc() + # Cache failure sentinel to avoid repeated initialization attempts in hot paths. + self._connector = False + + return self._connector if self._connector else None + + def get_connector(self): + """Get connector (compatibility wrapper for existing code).""" + return self.connector + + def handle_finished_requests_kv_transfer( + self, + finished_reqs: dict[str, dict[str, Any]], + kv_caches: list[torch.Tensor], + block_size: int, + cache_dtype: str, + request_id_resolver: Callable[[str], str] | None = None, + ) -> list[str]: + """Handle KV cache transfer for finished requests. + + This method extracts KV cache from GPU blocks and transfers them + to the downstream stage via the connector. + + Args: + finished_reqs: Dict mapping request_id to {block_ids, seq_len} + kv_caches: List of KV cache tensors per layer + block_size: Size of each cache block + cache_dtype: Data type of the cache + request_id_resolver: Optional function to resolve global request ID + + Returns: + List of request IDs that were processed + """ + if not finished_reqs: + return [] + + if not self.config.need_send_cache: + return list(finished_reqs.keys()) + + if not self.connector: + logger.warning("No connector available, skipping KV transfer but freeing resources") + return list(finished_reqs.keys()) + + logger.debug(f"Processing KV transfer for {len(finished_reqs)} requests") + + extracted_ids = [] + for req_id, data in finished_reqs.items(): + try: + seq_len = data.get("seq_len", 0) + block_ids = data.get("block_ids", []) + if not block_ids: + logger.warning(f"Request {req_id} has no block IDs, skipping") + continue + + # Extract KV cache from GPU blocks -> CPU tensors + kv_data = self._extract_kv_cache(req_id, block_ids, seq_len, kv_caches, block_size, cache_dtype) + if kv_data: + # Resolve global request ID if available + transfer_req_id = request_id_resolver(req_id) if request_id_resolver else req_id + + # Transfer to downstream stage via connector + self._transfer_kv_cache(kv_data, transfer_req_id) + + except Exception as e: + logger.error(f"Failed KV transfer for {req_id}: {e}") + finally: + extracted_ids.append(req_id) + + return extracted_ids + + def _extract_kv_cache( + self, + req_id: str, + block_ids: list[int], + seq_len: int, + kv_caches: list[torch.Tensor], + block_size: int, + cache_dtype: str, + ) -> KVCacheTransferData | None: + """Extract KV cache from GPU blocks for a single request. + + Args: + req_id: Request identifier + block_ids: List of block IDs to extract + seq_len: Sequence length + kv_caches: List of KV cache tensors per layer + block_size: Size of each cache block + cache_dtype: Data type of the cache + + Returns: + KVCacheTransferData if extraction successful, None otherwise + """ + num_layers = len(kv_caches) + key_cache: list[torch.Tensor | None] = [None] * num_layers + value_cache: list[torch.Tensor | None] = [None] * num_layers + + for layer_idx, kv_tensor in enumerate(kv_caches): + # Validate block IDs - shape: [2, num_blocks, block_size, n_heads, head_dim] + max_block = kv_tensor.shape[1] - 1 + valid_ids = [bid for bid in block_ids if 0 <= bid <= max_block] + if not valid_ids: + continue + + # Extract and reshape: [2, n_blocks, block_size, n_heads, head_dim] + # -> [2, seq_len, n_heads, head_dim] + selected = kv_tensor[:, valid_ids] # [2, n_valid, block_size, n_heads, head_dim] + n_kv, n_blks, blk_sz, n_heads, d_head = selected.shape + flat = selected.reshape(n_kv, n_blks * blk_sz, n_heads, d_head) + if seq_len < flat.shape[1]: + flat = flat[:, :seq_len] + + # Move to CPU + flat_cpu = flat.detach().cpu().contiguous() + key_cache[layer_idx] = flat_cpu[0] + value_cache[layer_idx] = flat_cpu[1] + + if not any(k is not None for k in key_cache): + return None + + return KVCacheTransferData( + request_id=req_id, + layer_blocks={"key_cache": key_cache, "value_cache": value_cache}, + block_ids=block_ids, + metadata={ + "block_size": block_size, + "num_layers": num_layers, + "dtype": str(cache_dtype), + "seq_len": seq_len, + }, + ) + + def _transfer_kv_cache(self, kv_data: KVCacheTransferData, transfer_req_id: str) -> None: + """Transfer KV cache data to downstream stage via OmniConnector. + + Args: + kv_data: The extracted KV cache data + transfer_req_id: The request ID to use for transfer + """ + from_stage, to_stage = self.send_stages + if not from_stage or not to_stage: + raise ValueError("Transfer stages (omni_from_stage, omni_to_stage) not configured") + + # Prepare data and transfer with retry + data_dict = kv_data.to_dict() + data_dict["request_id"] = transfer_req_id + + success, size, _ = self._transfer_with_retry(from_stage, to_stage, f"kv_cache_{transfer_req_id}", data_dict) + + if success: + logger.info(f"KV transfer OK: {transfer_req_id}, {size} bytes") + else: + logger.error(f"KV transfer FAILED: {transfer_req_id}") + + def _transfer_with_retry( + self, + from_stage: str, + to_stage: str, + request_id: str, + data: dict[str, Any], + max_retries: int = 3, + ) -> tuple[bool, int, dict[str, Any] | None]: + """Transfer data with retry and exponential backoff. + + Args: + from_stage: Source stage identifier + to_stage: Target stage identifier + request_id: Request identifier for the key + data: Data to transfer + max_retries: Maximum number of retry attempts + + Returns: + Tuple of (success, size, metadata) + """ + for attempt in range(max_retries): + try: + # Build the full key for connector + full_request_id = f"omni_{from_stage}_to_{to_stage}_{request_id}" + success, size, metadata = self.connector.put( + from_stage=from_stage, to_stage=to_stage, put_key=full_request_id, data=data + ) + if success: + return success, size, metadata + logger.warning(f"Transfer attempt {attempt + 1} failed for {request_id}") + except Exception as e: + logger.warning(f"Transfer attempt {attempt + 1} exception: {e}") + + if attempt < max_retries - 1: + time.sleep(0.1 * (2**attempt)) + + return False, 0, None + + @torch.inference_mode() + def receive_kv_cache_for_request( + self, + request_id: str, + target_device: torch.device | None = None, + ) -> tuple[dict[str, Any] | None, int]: + """Receive KV cache for a specific request. + + This implements the receiving logic from gpu_diffusion_model_runner.py. + + Args: + request_id: The request ID to receive KV cache for + target_device: Optional device to move tensors to + + Returns: + Tuple of (data dict, size) if successful, (None, 0) otherwise + """ + if not self.connector: + logger.warning("No connector available for receiving KV cache") + return None, 0 + + from_stage, to_stage = self.recv_stages + if not from_stage or not to_stage: + logger.warning("Receive stages not configured") + return None, 0 + + # Check if we should receive KV cache based on config + if not self.config.need_recv_cache: + logger.info(f"Skip receiving KV cache for {request_id} (need_recv_cache=False)") + return None, 0 + + timeout = self.config.recv_timeout + start_time = time.time() + + logger.info(f"Wait for KV cache for request {request_id} from stage {from_stage} to {to_stage}...") + + try: + while True: + # Build the full key for connector + full_request_id = f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}" + result = self.connector.get( + from_stage=from_stage, + to_stage=to_stage, + get_key=full_request_id, + ) + if result: + data, size = result + logger.info(f"Successfully received KV cache for {request_id}, {size} bytes") + + # Move tensors to target device if specified + if target_device is not None and isinstance(data, dict) and "layer_blocks" in data: + layer_blocks = data["layer_blocks"] + for cache_list in [ + layer_blocks.get("key_cache", []), + layer_blocks.get("value_cache", []), + ]: + for i, tensor in enumerate(cache_list): + if isinstance(tensor, torch.Tensor) and tensor.device != target_device: + cache_list[i] = tensor.to(target_device).contiguous() + + return data, size + + if time.time() - start_time > timeout: + logger.error(f"Timeout waiting for KV cache for request {request_id} after {timeout}s") + return None, 0 + + time.sleep(0.5) + + except Exception as e: + logger.error(f"Error receiving KV cache for {request_id}: {e}") + import traceback + + traceback.print_exc() + return None, 0 + + def apply_kv_cache_to_request(self, req: Any, data: dict[str, Any]) -> None: + """Apply received KV cache data to a request object. + + Args: + req: The request object to apply KV cache to + data: The received KV cache data dictionary + """ + if isinstance(data, dict) and "layer_blocks" in data: + layer_blocks = data["layer_blocks"] + from types import SimpleNamespace + + kv_obj = SimpleNamespace(**layer_blocks) + req.past_key_values = kv_obj + + # [Omni] Also attach to sampling_params for BagelPipeline compatibility + # BagelPipeline checks req.sampling_params.past_key_values + if hasattr(req, "sampling_params") and req.sampling_params is not None: + req.sampling_params.past_key_values = kv_obj + + if "metadata" in data: + req.kv_metadata = data["metadata"] + + # Legacy compatibility method + def receive_kv_cache(self, req: Any, target_device: torch.device | None = None) -> bool: + """Receive KV cache and populate request object (legacy interface). + + Args: + req: Request object with request_id attribute + target_device: Optional device to move tensors to + + Returns: + True if successful, False otherwise + """ + request_id = getattr(req, "request_id", None) + if not request_id and hasattr(req, "request_ids") and req.request_ids: + # Adaptation for new OmniDiffusionRequest which has list of prompts/ids + request_id = req.request_ids[0] + + if not request_id: + logger.warning("Request has no ID, cannot receive KV cache") + return False + + data, size = self.receive_kv_cache_for_request(request_id, target_device) + if data: + self.apply_kv_cache_to_request(req, data) + return True + return False diff --git a/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml index 86c0330a6b9..e378fe6752d 100644 --- a/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml +++ b/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml @@ -61,13 +61,7 @@ stage_args: final_output_type: image is_comprehension: false default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 seed: 52 - detokenize: True - repetition_penalty: 1.0 input_connectors: from_stage_0: mooncake_connector diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index ecc839fd3a2..8440cb5b8ec 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -35,7 +35,7 @@ from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices from vllm.v1.worker.utils import is_residual_scattered_for_sp -from vllm_omni.core.sched.omni_ar_scheduler import KVCacheTransferData +from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner @@ -70,7 +70,8 @@ def __init__(self, *args, **kwargs): # each model stage has their own hidden size self.hidden_size = self.model_config.hf_text_config.hidden_size self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False) - self.omni_connector = None + # Initialize KV cache manager (preserve vllm_config fallback behavior) + self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config) def _make_buffer(self, *size, dtype, numpy=True): # Prevent ray from pinning the buffer due to large size @@ -95,7 +96,13 @@ def execute_model( raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.") # [Omni] Handle KV transfer BEFORE updating states (which removes finished requests) - self.kv_extracted_req_ids = self._handle_finished_requests_kv_transfer(scheduler_output) + self.kv_extracted_req_ids = self.kv_transfer_manager.handle_finished_requests_kv_transfer( + finished_reqs=getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}), + kv_caches=self.kv_caches, + block_size=self.cache_config.block_size, + cache_dtype=str(self.cache_config.cache_dtype), + request_id_resolver=self._resolve_global_request_id, + ) if self.vllm_config.model_config.enable_return_routed_experts: capturer = RoutedExpertsCapturer.get_instance() @@ -577,161 +584,6 @@ def propose_draft_token_ids(sampled_token_ids): return async_output - def _handle_finished_requests_kv_transfer(self, scheduler_output: SchedulerOutput) -> list[str]: - """Handle KV cache transfer for finished requests. - - Returns list of request IDs that were processed (for Scheduler to free blocks). - """ - finished_reqs = getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}) - if not finished_reqs: - return [] - - logger.debug(f"Processing KV transfer for {len(finished_reqs)} requests") - - extracted_ids = [] - for req_id, data in finished_reqs.items(): - try: - seq_len = data.get("seq_len", 0) - block_ids = data.get("block_ids", []) - if not block_ids: - logger.warning(f"Request {req_id} has no block IDs, skipping") - continue - - # Extract KV cache from GPU blocks -> CPU tensors - kv_data = self._extract_kv_cache(req_id, block_ids, seq_len) - if kv_data: - # Transfer to downstream stage via connector - self._transfer_kv_cache(kv_data) - - except Exception as e: - logger.error(f"Failed KV transfer for {req_id}: {e}") - finally: - extracted_ids.append(req_id) - - return extracted_ids - - def _extract_kv_cache(self, req_id: str, block_ids: list[int], seq_len: int) -> KVCacheTransferData | None: - """Extract KV cache from GPU blocks for a single request.""" - num_layers = len(self.kv_caches) - key_cache = [None] * num_layers - value_cache = [None] * num_layers - - for layer_idx, kv_tensor in enumerate(self.kv_caches): - # Validate block IDs - max_block = kv_tensor.shape[1] - 1 - valid_ids = [bid for bid in block_ids if 0 <= bid <= max_block] - if not valid_ids: - continue - - # Extract and reshape: [2, n_blocks, block_size, n_heads, head_dim] - # -> [2, seq_len, n_heads, head_dim] - selected = kv_tensor[:, valid_ids] # [2, n_valid, block_size, n_heads, head_dim] - n_kv, n_blks, blk_sz, n_heads, d_head = selected.shape - flat = selected.reshape(n_kv, n_blks * blk_sz, n_heads, d_head) - if seq_len < flat.shape[1]: - flat = flat[:, :seq_len] - - # Move to CPU - flat_cpu = flat.detach().cpu().contiguous() - key_cache[layer_idx] = flat_cpu[0] - value_cache[layer_idx] = flat_cpu[1] - - if not any(k is not None for k in key_cache): - return None - - return KVCacheTransferData( - request_id=req_id, - layer_blocks={"key_cache": key_cache, "value_cache": value_cache}, - block_ids=block_ids, - metadata={ - "block_size": self.cache_config.block_size, - "num_layers": num_layers, - "dtype": str(self.cache_config.cache_dtype), - "seq_len": seq_len, - }, - ) - - def _transfer_kv_cache(self, kv_data: KVCacheTransferData) -> None: - """Transfer KV cache data to downstream stage via OmniConnector.""" - connector = self._get_or_create_connector() - if not connector: - return - - # Resolve global request ID if available - transfer_req_id = self._resolve_global_request_id(kv_data.request_id) - from_stage, to_stage = self._detect_transfer_stages() - - # Prepare data and transfer with retry - data_dict = kv_data.to_dict() - data_dict["request_id"] = transfer_req_id - - success, size, _ = self._transfer_with_retry( - connector, from_stage, to_stage, f"kv_cache_{transfer_req_id}", data_dict - ) - - if success: - logger.info(f"KV transfer OK: {transfer_req_id}, {size} bytes") - else: - logger.error(f"KV transfer FAILED: {transfer_req_id}") - - def _get_or_create_connector(self) -> Any | None: - """Get existing connector or create one from config.""" - if self.omni_connector: - return self.omni_connector - - from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory - from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec - - config = self._get_omni_connector_config() - if not config or not isinstance(config, dict): - logger.warning("No valid OmniConnector config found") - return None - - c_type = config.get("type") - if not c_type: - logger.error("OmniConnector config missing 'type' field") - return None - - c_extra = {k: v for k, v in config.items() if k != "type"} - self.omni_connector = OmniConnectorFactory.create_connector(ConnectorSpec(name=c_type, extra=c_extra)) - return self.omni_connector - - def _get_omni_connector_config(self) -> dict[str, Any] | None: - """Get OmniConnector configuration from model config.""" - # Primary: omni_kv_config from YAML - omni_kv = getattr(self.model_config, "omni_kv_config", None) - if isinstance(omni_kv, dict): - cfg = omni_kv.get("connector_config") - if isinstance(cfg, dict) and cfg: - return cfg - - # Fallback: kv_transfer_config - kv_cfg = getattr(self.vllm_config, "kv_transfer_config", None) - if kv_cfg: - direct = getattr(kv_cfg, "omni_connector_config", None) - if isinstance(direct, dict) and direct: - return direct - extra = getattr(kv_cfg, "kv_connector_extra_config", None) - if isinstance(extra, dict): - omni = extra.get("omni_connector_config") - if isinstance(omni, dict) and omni: - return omni - - return None - - def _detect_transfer_stages(self) -> tuple[str, str]: - """Detect source and target stages for KV transfer.""" - omni_kv = getattr(self.model_config, "omni_kv_config", None) - if isinstance(omni_kv, dict): - from_s = omni_kv.get("omni_from_stage") - to_s = omni_kv.get("omni_to_stage") - if from_s and to_s: - return str(from_s), str(to_s) - - raise ValueError( - "KV transfer stages not configured. Please set 'omni_from_stage' and 'omni_to_stage' in omni_kv_config." - ) - def _resolve_global_request_id(self, req_id: str) -> str: """Resolve global request ID from request state.""" req_state = self.requests.get(req_id) @@ -747,32 +599,3 @@ def _resolve_global_request_id(self, req_id: str) -> str: return global_id.decode("utf-8") return str(global_id) return req_id - - def _transfer_with_retry( - self, - connector: Any, - from_stage: str, - to_stage: str, - request_id: str, - data: dict[str, Any], - max_retries: int = 3, - ) -> tuple[bool, int, dict[str, Any] | None]: - """Transfer data with retry and exponential backoff.""" - import time - - for attempt in range(max_retries): - try: - put_key = f"omni_{from_stage}_to_{to_stage}_{request_id}" - success, size, metadata = connector.put( - from_stage=from_stage, to_stage=to_stage, put_key=put_key, data=data - ) - if success: - return success, size, metadata - logger.warning(f"Transfer attempt {attempt + 1} failed for {request_id}") - except Exception as e: - logger.warning(f"Transfer attempt {attempt + 1} exception: {e}") - - if attempt < max_retries - 1: - time.sleep(0.1 * (2**attempt)) - - return False, 0, None