diff --git a/.gitignore b/.gitignore index e3ae3c9d0b8..953f0e21638 100644 --- a/.gitignore +++ b/.gitignore @@ -91,3 +91,11 @@ generated-values.yaml .build/ **/.devcontainer/.env TensorRT-LLM + + +# START Ruler Generated Files +/.cursor/instructions.md +/.cursor/instructions.md.bak +/CLAUDE.md +/CLAUDE.md.bak +# END Ruler Generated Files diff --git a/dynamo.code-workspace b/dynamo.code-workspace index e3b65d3d097..4935096f65b 100644 --- a/dynamo.code-workspace +++ b/dynamo.code-workspace @@ -7,7 +7,6 @@ "settings": { "rust-analyzer.linkedProjects": [ "Cargo.toml", - "launch/dynamo-run/Cargo.toml", "lib/bindings/python/Cargo.toml" ], "rust-analyzer.procMacro.enable": true, diff --git a/lib/bindings/python/Cargo.toml b/lib/bindings/python/Cargo.toml index 9905c8dedec..23a046d0999 100644 --- a/lib/bindings/python/Cargo.toml +++ b/lib/bindings/python/Cargo.toml @@ -34,7 +34,7 @@ name = "_core" crate-type = ["cdylib", "rlib"] [features] -default = [] +default = ["block-manager"] block-manager = ["dynamo-llm/block-manager", "dep:dlpark", "dep:cudarc"] [dependencies] diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 05239efabf4..c3d92e74f6a 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -87,6 +87,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -120,6 +121,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { engine::add_to_module(m)?; parsers::add_to_module(m)?; + llm::scheduler_connector::register_module(m)?; #[cfg(feature = "block-manager")] llm::block_manager::add_to_module(m)?; diff --git a/lib/bindings/python/rust/llm.rs b/lib/bindings/python/rust/llm.rs index ceeed50fd15..a109c5b7731 100644 --- a/lib/bindings/python/rust/llm.rs +++ b/lib/bindings/python/rust/llm.rs @@ -34,6 +34,8 @@ pub mod local_model; pub mod model_card; pub mod nats; pub mod preprocessor; +pub mod scheduler_connector; +pub mod vllm_scheduler; #[cfg(feature = "block-manager")] pub mod block_manager; diff --git a/lib/bindings/python/rust/llm/block_manager/vllm.rs b/lib/bindings/python/rust/llm/block_manager/vllm.rs index 56bd6755585..2fdbe3d2319 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm.rs @@ -31,6 +31,7 @@ use crate::to_pyerr; mod block_list; mod connector; mod request; +// mod scheduler; // TODO: Fix PyO3 bindings mod slot; pub use block_list::{BlockListType, BlockState, BlockStates, KvbmBlockList}; @@ -53,6 +54,15 @@ fn _vllm_integration(m: &Bound<'_, PyModule>) -> PyResult<()> { // TODO: use TRTLLM own integration module m.add_class::()?; m.add_class::()?; + + // Add scheduler recorder and conversion functions + // TODO: Fix PyO3 bindings for these types + // m.add_class::()?; + // m.add_function(wrap_pyfunction!(scheduler::scheduler_types::convert_scheduler_output, m)?)?; + // m.add_function(wrap_pyfunction!(scheduler::scheduler_types::convert_model_runner_output, m)?)?; + // m.add_function(wrap_pyfunction!(scheduler::scheduler_types::convert_engine_core_outputs, m)?)?; + // m.add_function(wrap_pyfunction!(scheduler::recorder_bindings::load_scheduler_trace, m)?)?; + Ok(()) } diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/scheduler/mod.rs b/lib/bindings/python/rust/llm/block_manager/vllm/scheduler/mod.rs new file mode 100644 index 00000000000..9aa1b3e0d44 --- /dev/null +++ 
b/lib/bindings/python/rust/llm/block_manager/vllm/scheduler/mod.rs @@ -0,0 +1,7 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub mod recorder_bindings; +pub mod scheduler_types; + +use pyo3::prelude::*; \ No newline at end of file diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/scheduler/recorder_bindings.rs b/lib/bindings/python/rust/llm/block_manager/vllm/scheduler/recorder_bindings.rs new file mode 100644 index 00000000000..13fb6b0fe68 --- /dev/null +++ b/lib/bindings/python/rust/llm/block_manager/vllm/scheduler/recorder_bindings.rs @@ -0,0 +1,83 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Python bindings for the SchedulerRecorder + +use dynamo_llm::integrations::vllm::recorder::SchedulerRecorder as RustRecorder; +use dynamo_llm::integrations::vllm::types::*; +use pyo3::prelude::*; +use std::path::PathBuf; + +/// Python-accessible SchedulerRecorder +#[pyclass(name = "SchedulerRecorder")] +pub struct PySchedulerRecorder { + inner: RustRecorder, +} + +#[pymethods] +impl PySchedulerRecorder { + /// Create a new SchedulerRecorder + #[new] + #[pyo3(signature = (model, vllm_version))] + fn new(model: String, vllm_version: String) -> Self { + Self { + inner: RustRecorder::new(model, vllm_version), + } + } + + /// Record a scheduler output (already converted to Rust) + fn record_schedule_output(&mut self, output: SchedulerOutput) -> PyResult<()> { + self.inner.record_schedule_output(output); + Ok(()) + } + + /// Record a model runner output (already converted to Rust) + fn record_model_runner_output(&mut self, output: ModelRunnerOutput) -> PyResult<()> { + self.inner.record_model_runner_output(output); + Ok(()) + } + + /// Record engine core outputs (already converted to Rust) + fn record_engine_core_outputs(&mut self, outputs: EngineCoreOutputs) -> PyResult<()> { + self.inner.record_engine_core_outputs(outputs); + Ok(()) + } + + /// Move to the next iteration + fn next_iteration(&mut self) -> PyResult<()> { + self.inner.next_iteration(); + Ok(()) + } + + /// Get the current iteration number + fn current_iteration(&self) -> u64 { + self.inner.current_iteration() + } + + /// Save the recording to a JSON file + fn save_to_file(&mut self, path: String) -> PyResult<()> { + let path = PathBuf::from(path); + self.inner + .save_to_file(&path) + .map_err(|e| PyErr::new::(format!("{}", e))) + } + + /// Clear all recordings + fn clear(&mut self) -> PyResult<()> { + self.inner.clear(); + Ok(()) + } + + /// Get the number of recorded iterations + fn num_iterations(&self) -> usize { + self.inner.get_trace().iterations.len() + } +} + +/// Load a recording from a JSON file +#[pyfunction] +pub fn load_scheduler_trace(path: String) -> PyResult { + let path = PathBuf::from(path); + RustRecorder::load_from_file(&path) + .map_err(|e| PyErr::new::(format!("{}", e))) +} \ No newline at end of file diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/scheduler/scheduler_types.rs b/lib/bindings/python/rust/llm/block_manager/vllm/scheduler/scheduler_types.rs new file mode 100644 index 00000000000..cb5894860f6 --- /dev/null +++ b/lib/bindings/python/rust/llm/block_manager/vllm/scheduler/scheduler_types.rs @@ -0,0 +1,416 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! 
Python-to-Rust converters for vLLM scheduler types + +use dynamo_llm::integrations::vllm::types::*; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList}; +use std::collections::HashMap; + +/// Convert Python SchedulerOutput to Rust +#[pyfunction] +pub fn convert_scheduler_output(py: Python, obj: &Bound<'_, PyAny>) -> PyResult { + // Extract scheduled_new_reqs + let new_reqs_py = obj.getattr("scheduled_new_reqs")?; + let mut scheduled_new_reqs = Vec::new(); + + for item in new_reqs_py.iter()? { + let item = item?; + let req_id = item.getattr("req_id")?.extract::()?; + let prompt_token_ids = item.getattr("prompt_token_ids")?.extract::>()?; + + // Extract block_ids (tuple of lists) + let block_ids_tuple = item.getattr("block_ids")?; + let mut block_ids = Vec::new(); + for block_list in block_ids_tuple.iter()? { + let block_list = block_list?; + let blocks = block_list.extract::>()?; + block_ids.push(blocks); + } + + let num_computed_tokens = item.getattr("num_computed_tokens")?.extract::()?; + + // Extract mm_hashes and mm_positions + let mm_hashes = if let Ok(hashes) = item.getattr("mm_hashes") { + hashes.extract::>().unwrap_or_default() + } else { + Vec::new() + }; + + let mm_positions = if let Ok(positions) = item.getattr("mm_positions") { + let mut ranges = Vec::new(); + for pos in positions.iter()? { + let pos = pos?; + if let Ok(start) = pos.getattr("start") { + let start = start.extract::()?; + let end = pos.getattr("end")?.extract::()?; + ranges.push(PlaceholderRange { start, end }); + } + } + ranges + } else { + Vec::new() + }; + + scheduled_new_reqs.push(NewRequestData { + req_id, + prompt_token_ids, + block_ids, + num_computed_tokens, + mm_hashes, + mm_positions, + }); + } + + // Extract scheduled_cached_reqs + let cached_reqs_py = obj.getattr("scheduled_cached_reqs")?; + let scheduled_cached_reqs = CachedRequestData { + req_ids: cached_reqs_py.getattr("req_ids")?.extract::>()?, + resumed_from_preemption: cached_reqs_py + .getattr("resumed_from_preemption")? + .extract::>()?, + new_token_ids: cached_reqs_py + .getattr("new_token_ids")? + .extract::>>()?, + new_block_ids: { + let new_blocks = cached_reqs_py.getattr("new_block_ids")?; + let mut result = Vec::new(); + for item in new_blocks.iter()? { + let item = item?; + if item.is_none() { + result.push(None); + } else { + let mut block_ids = Vec::new(); + for block_list in item.iter()? { + let block_list = block_list?; + let blocks = block_list.extract::>()?; + block_ids.push(blocks); + } + result.push(Some(block_ids)); + } + } + result + }, + num_computed_tokens: cached_reqs_py + .getattr("num_computed_tokens")? + .extract::>()?, + }; + + // Extract num_scheduled_tokens + let num_scheduled_tokens_py = obj.getattr("num_scheduled_tokens")?; + let num_scheduled_tokens = if let Ok(dict) = num_scheduled_tokens_py.downcast::() { + let mut map = HashMap::new(); + for (key, value) in dict.iter() { + let key = key.extract::()?; + let value = value.extract::()?; + map.insert(key, value); + } + map + } else { + HashMap::new() + }; + + // Extract other fields + let total_num_scheduled_tokens = obj + .getattr("total_num_scheduled_tokens")? 
+ .extract::()?; + + let scheduled_spec_decode_tokens = if let Ok(spec_tokens) = obj.getattr("scheduled_spec_decode_tokens") { + if let Ok(dict) = spec_tokens.downcast::() { + let mut map = HashMap::new(); + for (key, value) in dict.iter() { + let key = key.extract::()?; + let value = value.extract::>()?; + map.insert(key, value); + } + map + } else { + HashMap::new() + } + } else { + HashMap::new() + }; + + let scheduled_encoder_inputs = if let Ok(encoder_inputs) = obj.getattr("scheduled_encoder_inputs") { + if let Ok(dict) = encoder_inputs.downcast::() { + let mut map = HashMap::new(); + for (key, value) in dict.iter() { + let key = key.extract::()?; + let value = value.extract::>()?; + map.insert(key, value); + } + map + } else { + HashMap::new() + } + } else { + HashMap::new() + }; + + let num_common_prefix_blocks = obj + .getattr("num_common_prefix_blocks")? + .extract::>() + .unwrap_or_default(); + + let finished_req_ids = if let Ok(finished) = obj.getattr("finished_req_ids") { + // Convert set to vec + let mut ids = Vec::new(); + for item in finished.iter()? { + let item = item?; + ids.push(item.extract::()?); + } + ids + } else { + Vec::new() + }; + + let free_encoder_mm_hashes = obj + .getattr("free_encoder_mm_hashes")? + .extract::>() + .unwrap_or_default(); + + Ok(SchedulerOutput { + scheduled_new_reqs, + scheduled_cached_reqs, + num_scheduled_tokens, + total_num_scheduled_tokens, + scheduled_spec_decode_tokens, + scheduled_encoder_inputs, + num_common_prefix_blocks, + finished_req_ids, + free_encoder_mm_hashes, + }) +} + +/// Convert Python ModelRunnerOutput to Rust +#[pyfunction] +pub fn convert_model_runner_output(py: Python, obj: &Bound<'_, PyAny>) -> PyResult { + let req_ids = obj.getattr("req_ids")?.extract::>()?; + + let req_id_to_index_py = obj.getattr("req_id_to_index")?; + let req_id_to_index = if let Ok(dict) = req_id_to_index_py.downcast::() { + let mut map = HashMap::new(); + for (key, value) in dict.iter() { + let key = key.extract::()?; + let value = value.extract::()?; + map.insert(key, value); + } + map + } else { + HashMap::new() + }; + + let sampled_token_ids = obj + .getattr("sampled_token_ids")? + .extract::>>()?; + + let logprobs = if let Ok(logprobs_py) = obj.getattr("logprobs") { + if !logprobs_py.is_none() { + Some(LogprobsLists { + logprob_token_ids: logprobs_py + .getattr("logprob_token_ids")? + .extract::>>()?, + logprobs: logprobs_py + .getattr("logprobs")? + .extract::>>()?, + sampled_token_ranks: logprobs_py + .getattr("sampled_token_ranks")? + .extract::>()?, + }) + } else { + None + } + } else { + None + }; + + let prompt_logprobs_dict = if let Ok(prompt_dict) = obj.getattr("prompt_logprobs_dict") { + if let Ok(dict) = prompt_dict.downcast::() { + let mut map = HashMap::new(); + for (key, value) in dict.iter() { + let key = key.extract::()?; + if !value.is_none() { + let logprobs = Some(LogprobsLists { + logprob_token_ids: value + .getattr("logprob_token_ids")? + .extract::>>()?, + logprobs: value.getattr("logprobs")?.extract::>>()?, + sampled_token_ranks: value + .getattr("selected_token_ranks")? 
+ .extract::>()?, + }); + map.insert(key, logprobs); + } else { + map.insert(key, None); + } + } + map + } else { + HashMap::new() + } + } else { + HashMap::new() + }; + + let num_nans_in_logits = if let Ok(nans) = obj.getattr("num_nans_in_logits") { + if !nans.is_none() { + if let Ok(dict) = nans.downcast::() { + let mut map = HashMap::new(); + for (key, value) in dict.iter() { + let key = key.extract::()?; + let value = value.extract::()?; + map.insert(key, value); + } + Some(map) + } else { + None + } + } else { + None + } + } else { + None + }; + + Ok(ModelRunnerOutput { + req_ids, + req_id_to_index, + sampled_token_ids, + logprobs, + prompt_logprobs_dict, + num_nans_in_logits, + }) +} + +/// Convert Python EngineCoreOutputs to Rust +#[pyfunction] +pub fn convert_engine_core_outputs(py: Python, outputs_dict: &Bound<'_, PyDict>) -> PyResult { + let mut all_outputs = Vec::new(); + + // The dict is keyed by engine_index + for (engine_idx, outputs_list) in outputs_dict.iter() { + let engine_index = engine_idx.extract::()?; + + for output in outputs_list.iter()? { + let output = output?; + + let request_id = output.getattr("request_id")?.extract::()?; + let new_token_ids = output.getattr("new_token_ids")?.extract::>()?; + + let new_logprobs = if let Ok(logprobs_py) = output.getattr("new_logprobs") { + if !logprobs_py.is_none() { + Some(LogprobsLists { + logprob_token_ids: logprobs_py + .getattr("logprob_token_ids")? + .extract::>>()?, + logprobs: logprobs_py + .getattr("logprobs")? + .extract::>>()?, + sampled_token_ranks: logprobs_py + .getattr("sampled_token_ranks")? + .extract::>()?, + }) + } else { + None + } + } else { + None + }; + + let finish_reason = if let Ok(reason) = output.getattr("finish_reason") { + if !reason.is_none() { + let val = reason.extract::()?; + Some(match val { + 0 => FinishReason::Stop, + 1 => FinishReason::Length, + 2 => FinishReason::Abort, + _ => FinishReason::Abort, + }) + } else { + None + } + } else { + None + }; + + let stop_reason = if let Ok(reason) = output.getattr("stop_reason") { + if !reason.is_none() { + if let Ok(s) = reason.extract::() { + Some(StopReason::String(s)) + } else if let Ok(i) = reason.extract::() { + Some(StopReason::Int(i)) + } else { + None + } + } else { + None + } + } else { + None + }; + + let events = if let Ok(events_py) = output.getattr("events") { + if !events_py.is_none() { + let mut events = Vec::new(); + for event in events_py.iter()? { + let event = event?; + let event_type = event.getattr("type")?.extract::()?; + let timestamp = event.getattr("timestamp")?.extract::()?; + + let event_type = match event_type { + 1 => EngineCoreEventType::Queued, + 2 => EngineCoreEventType::Scheduled, + 3 => EngineCoreEventType::Preempted, + _ => EngineCoreEventType::Queued, + }; + + events.push(EngineCoreEvent { + event_type, + timestamp, + }); + } + Some(events) + } else { + None + } + } else { + None + }; + + let num_cached_tokens = output + .getattr("num_cached_tokens")? 
+ .extract::() + .unwrap_or(0); + + all_outputs.push(EngineCoreOutput { + request_id, + new_token_ids, + new_logprobs, + finish_reason, + stop_reason, + events, + num_cached_tokens, + }); + } + + // Use first engine index (typically 0) + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs_f64(); + + return Ok(EngineCoreOutputs { + engine_index, + outputs: all_outputs, + timestamp, + }); + } + + // Return empty if no outputs + Ok(EngineCoreOutputs { + engine_index: 0, + outputs: Vec::new(), + timestamp: 0.0, + }) +} \ No newline at end of file diff --git a/lib/bindings/python/rust/llm/scheduler_connector.rs b/lib/bindings/python/rust/llm/scheduler_connector.rs new file mode 100644 index 00000000000..6b199ad7580 --- /dev/null +++ b/lib/bindings/python/rust/llm/scheduler_connector.rs @@ -0,0 +1,202 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Python bindings for scheduler worker device blocks. + +use std::sync::Arc; + +use pyo3::prelude::*; + +use dynamo_llm::integrations::vllm::scheduler::worker::WorkerDeviceBlocks as RustWorkerDeviceBlocks; +use dynamo_llm::block_manager::storage::torch::{TorchDevice, TorchTensor}; + +use crate::to_pyerr; + +/// A wrapper around a Torch tensor for scheduler connector. +/// We hold onto the py object to ensure it doesn't get GCed. +#[derive(Clone, Debug)] +pub struct SchedulerTensor { + _py_tensor: Py, + device: TorchDevice, + data_ptr: u64, + size_bytes: usize, + shape: Vec, + stride: Vec, +} + +impl SchedulerTensor { + pub fn new(py_tensor: Py) -> anyhow::Result { + Python::with_gil(|py| { + let device = py_tensor.getattr(py, "device")?; + let device_type = device.getattr(py, "type")?.extract::(py)?; + + let device = if device_type == "cuda" { + TorchDevice::Cuda(device.getattr(py, "index")?.extract::(py)?) + } else { + TorchDevice::Other(device_type) + }; + + let data_ptr = py_tensor.call_method0(py, "data_ptr")?.extract::(py)?; + let size_bytes = py_tensor.getattr(py, "nbytes")?.extract::(py)?; + let shape = py_tensor.getattr(py, "shape")?.extract::>(py)?; + let stride = py_tensor + .call_method0(py, "stride")? + .extract::>(py)?; + + Ok(Self { + _py_tensor: py_tensor, + device, + data_ptr, + size_bytes, + shape, + stride, + }) + }) + } +} + +impl TorchTensor for SchedulerTensor { + fn device(&self) -> TorchDevice { + self.device.clone() + } + + fn data_ptr(&self) -> u64 { + self.data_ptr + } + + fn size_bytes(&self) -> usize { + self.size_bytes + } + + fn shape(&self) -> Vec { + self.shape.clone() + } + + fn stride(&self) -> Vec { + self.stride.clone() + } +} + +/// Python wrapper for WorkerDeviceBlocks. +/// +/// This class provides worker device block construction for the scheduler +/// without requiring leader/worker synchronization. +#[pyclass] +pub struct WorkerDeviceBlocks { + inner: Arc, +} + +#[pymethods] +impl WorkerDeviceBlocks { + /// Create local blocks from KV cache tensors. 
+ /// + /// Args: + /// tensors: List of torch tensors (one per layer) + /// num_device_blocks: Number of device blocks + /// page_size: Page size (typically 16) + /// device_id: CUDA device ID + /// dtype_width_bytes: Bytes per dtype element (e.g., 2 for fp16) + /// is_fully_contiguous: Whether layout is fully contiguous + #[new] + #[pyo3(signature = (tensors, num_device_blocks, page_size, device_id=0, dtype_width_bytes=2, is_fully_contiguous=false))] + fn new( + tensors: Vec>, + num_device_blocks: usize, + page_size: usize, + device_id: usize, + dtype_width_bytes: usize, + is_fully_contiguous: bool, + ) -> PyResult { + // Convert Python tensors to Rust tensors + let mut rust_tensors: Vec> = Vec::with_capacity(tensors.len()); + + for tensor in tensors { + let scheduler_tensor = SchedulerTensor::new(tensor).map_err(to_pyerr)?; + rust_tensors.push(Arc::new(scheduler_tensor)); + } + + // Build worker device blocks + let worker_blocks = RustWorkerDeviceBlocks::from_tensors( + rust_tensors, + num_device_blocks, + page_size, + device_id, + dtype_width_bytes, + is_fully_contiguous, + ) + .map_err(to_pyerr)?; + + Ok(Self { + inner: Arc::new(worker_blocks), + }) + } + + /// Get the number of device blocks. + #[getter] + fn num_device_blocks(&self) -> usize { + self.inner.num_device_blocks + } + + /// Get the number of layers. + #[getter] + fn num_layers(&self) -> usize { + self.inner.num_layers + } + + /// Get the outer dimension. + #[getter] + fn outer_dim(&self) -> usize { + self.inner.outer_dim + } + + /// Get the page size. + #[getter] + fn page_size(&self) -> usize { + self.inner.page_size + } + + /// Get the inner dimension. + #[getter] + fn inner_dim(&self) -> usize { + self.inner.inner_dim + } + + /// Get the dtype width in bytes. + #[getter] + fn dtype_width_bytes(&self) -> usize { + self.inner.dtype_width_bytes + } + + /// Get the total bytes per block. + #[getter] + fn bytes_per_block(&self) -> usize { + self.inner.bytes_per_block + } + + /// Get the number of blocks that were created. + fn num_blocks(&self) -> usize { + self.inner.device_blocks.len() + } + + /// String representation for debugging. + fn __repr__(&self) -> String { + format!( + "WorkerDeviceBlocks(num_blocks={}, num_layers={}, outer_dim={}, page_size={}, inner_dim={}, dtype_width_bytes={}, bytes_per_block={})", + self.inner.device_blocks.len(), + self.inner.num_layers, + self.inner.outer_dim, + self.inner.page_size, + self.inner.inner_dim, + self.inner.dtype_width_bytes, + self.inner.bytes_per_block + ) + } +} + +/// Register the module with Python. +pub fn register_module(parent_module: &Bound<'_, PyModule>) -> PyResult<()> { + let m = PyModule::new(parent_module.py(), "scheduler_connector")?; + m.add_class::()?; + parent_module.add_submodule(&m)?; + Ok(()) +} \ No newline at end of file diff --git a/lib/bindings/python/rust/llm/vllm_scheduler.rs b/lib/bindings/python/rust/llm/vllm_scheduler.rs new file mode 100644 index 00000000000..8268e5e470e --- /dev/null +++ b/lib/bindings/python/rust/llm/vllm_scheduler.rs @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Python bindings for the Rust vLLM scheduler integration. +//! +//! This module provides a thin PyO3 wrapper around the core scheduler +//! implementation in dynamo_llm::integrations::vllm::scheduler. + +use pyo3::prelude::*; +use dynamo_llm::integrations::vllm::scheduler; + +/// PyO3 wrapper around the core RustSchedulerState. 
+/// This is a thin wrapper that forwards all calls to the core implementation. +#[pyclass] +pub struct RustSchedulerState { + /// The actual scheduler state from the core crate + inner: scheduler::RustSchedulerState, +} + +#[pymethods] +impl RustSchedulerState { + #[new] + pub fn new() -> Self { + Self { + inner: scheduler::RustSchedulerState::new(), + } + } + + /// Add a new request to the Rust scheduler state. + /// Called from Python when DynamoScheduler.add_request() is invoked. + /// Note: cache_salt is now passed as a string and converted to hash in Rust. + #[pyo3(signature = (request_id, prompt_token_ids, cache_salt=None, lora_int_id=None, priority=0, arrival_time=0.0))] + pub fn add_request( + &self, + request_id: String, + prompt_token_ids: Vec, + cache_salt: Option, + lora_int_id: Option, + priority: i32, + arrival_time: f64, + ) -> PyResult<()> { + self.inner + .add_request( + request_id, + prompt_token_ids, + cache_salt, + lora_int_id, + priority, + arrival_time, + ) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e)) + } + + /// Mark requests as finished without removing them. + /// Called from Python when finish_requests is invoked externally. + pub fn mark_as_finished(&self, request_ids: Vec) -> PyResult<()> { + self.inner + .mark_as_finished(request_ids) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e)) + } + + /// Remove finished requests from the Rust scheduler state. + /// Called from Python when processing finished_req_ids in update_from_output. + pub fn remove_finished_requests(&self, request_ids: Vec) -> PyResult<()> { + self.inner + .remove_finished_requests(request_ids) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e)) + } + + /// Get the current number of tracked requests. + pub fn num_requests(&self) -> usize { + self.inner.num_requests() + } + + /// Check if a request is being tracked. + pub fn has_request(&self, request_id: &str) -> bool { + self.inner.has_request(request_id) + } + + /// Get all currently tracked request IDs (for debugging). 
+ pub fn get_request_ids(&self) -> Vec { + self.inner.get_request_ids() + } +} \ No newline at end of file diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/__init__.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/__init__.py index 0a6ded9a9ce..83b3429e9a4 100644 --- a/lib/bindings/python/src/dynamo/llm/vllm_integration/__init__.py +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/__init__.py @@ -3,8 +3,14 @@ # Import connector classes to make them available at the expected paths for vLLM from .connector.dynamo_connector import DynamoConnector, DynamoConnectorMetadata +from .scheduler import DynamoScheduler # Create module-level alias for backward compatibility dynamo_connector = DynamoConnector -__all__ = ["DynamoConnector", "DynamoConnectorMetadata", "dynamo_connector"] +__all__ = [ + "DynamoConnector", + "DynamoConnectorMetadata", + "dynamo_connector", + "DynamoScheduler", +] diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/connector/__init__.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector/__init__.py index 419f3011c30..769289db9c9 100644 --- a/lib/bindings/python/src/dynamo/llm/vllm_integration/connector/__init__.py +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector/__init__.py @@ -2,5 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 from .dynamo_connector import DynamoConnector, DynamoConnectorMetadata +from .dynamo_scheduler_connector import ( + DynamoSchedulerConnector, + DynamoSchedulerConnectorMetadata, +) -__all__ = ["DynamoConnector", "DynamoConnectorMetadata"] +__all__ = [ + "DynamoConnector", + "DynamoConnectorMetadata", + "DynamoSchedulerConnector", + "DynamoSchedulerConnectorMetadata", +] diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/connector/dynamo_scheduler_connector.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector/dynamo_scheduler_connector.py new file mode 100644 index 00000000000..483fceae99f --- /dev/null +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector/dynamo_scheduler_connector.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Dynamo Scheduler Connector implementation for vLLM. + +This connector uses minimal scheduler-specific implementations that provide +no-op responses, used for scheduler integration testing without KV transfer. 
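Illustrative wiring (editor's sketch, not part of this patch; the exact
KVTransferConfig fields are assumptions and may differ across vLLM versions):

    from vllm.config import KVTransferConfig

    kv_transfer_config = KVTransferConfig(
        kv_connector="DynamoSchedulerConnector",
        kv_role="kv_both",
        # Assumed field name: tells vLLM where to import the connector from.
        kv_connector_module_path=(
            "dynamo.llm.vllm_integration.connector.dynamo_scheduler_connector"
        ),
    )
    # The connector asserts that kv_transfer_config.engine_id is set, so an
    # engine_id must be present on the config before the engine is built.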
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +import torch +from typing_extensions import override +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +# Import our minimal scheduler connector implementations +from dynamo.llm.vllm_integration.scheduler_conn_leader import SchedulerConnectorLeader +from dynamo.llm.vllm_integration.scheduler_conn_worker import SchedulerConnectorWorker + +EngineId = str + + +class DynamoSchedulerConnectorMetadata(KVConnectorMetadata): + """Minimal metadata container for scheduler connector.""" + + def __init__(self, metadata: bytes): + assert isinstance(metadata, bytes) + self.metadata = metadata + + +class DynamoSchedulerConnector(KVConnectorBase_V1): + """ + Dynamo Scheduler Connector that uses minimal no-op implementations. + + This connector is specifically for scheduler integration testing and + provides no actual KV transfer functionality. + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + + assert vllm_config.kv_transfer_config is not None + assert vllm_config.kv_transfer_config.engine_id is not None + self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self._scheduler = SchedulerConnectorLeader( + vllm_config=vllm_config, engine_id=self.engine_id + ) + self._worker = None + elif role == KVConnectorRole.WORKER: + self._worker = SchedulerConnectorWorker( + vllm_config=vllm_config, engine_id=self.engine_id + ) + self._scheduler = None + else: + # KV_BOTH role - create both scheduler and worker + self._scheduler = SchedulerConnectorLeader( + vllm_config=vllm_config, engine_id=self.engine_id + ) + self._worker = SchedulerConnectorWorker( + vllm_config=vllm_config, engine_id=self.engine_id + ) + + # Scheduler/Leader methods + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """Always returns (0, False) - no external tokens available.""" + return self._scheduler.get_num_new_matched_tokens(request, num_computed_tokens) + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + """No-op since we never have external tokens.""" + self._scheduler.update_state_after_alloc(request, blocks, num_external_tokens) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + """Build minimal connector metadata (empty bytes).""" + data = self._scheduler.build_connector_meta(scheduler_output) + return DynamoSchedulerConnectorMetadata(data) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """Never delays block freeing - returns (False, None).""" + return self._scheduler.request_finished(request, block_ids) + + # Worker methods + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register KV caches - no-op for scheduler connector.""" + if self._worker: + self._worker.register_kv_caches(kv_caches) + + def 
bind_connector_metadata( + self, connector_metadata: DynamoSchedulerConnectorMetadata + ) -> None: + """Bind connector metadata - no-op.""" + if self._worker: + assert isinstance(connector_metadata.metadata, bytes) + self._worker.bind_connector_metadata(connector_metadata.metadata) + + def clear_connector_metadata(self) -> None: + """Clear connector metadata - no-op.""" + if self._worker: + self._worker.clear_connector_metadata() + + @override + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + """Start loading KV cache - no-op for scheduler connector.""" + if self._worker: + self._worker.start_load_kv(forward_context, **kwargs) + + @override + def wait_for_layer_load(self, layer_name: str) -> None: + """Wait for layer load - no-op.""" + pass + + @override + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + """Save KV layer - no-op for scheduler connector.""" + if self._worker: + self._worker.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) + + @override + def wait_for_save(self): + """Wait for save - no-op.""" + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """Get finished request IDs - always returns (None, None).""" + if self._worker: + return self._worker.get_finished(finished_req_ids) + return (None, None) diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/recording_scheduler.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/recording_scheduler.py new file mode 100644 index 00000000000..6005b69a294 --- /dev/null +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/recording_scheduler.py @@ -0,0 +1,282 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Recording scheduler that captures vLLM scheduler behavior for Rust implementation. + +This scheduler wraps the DynamoScheduler and records all inputs/outputs +in a format suitable for replay by a Rust scheduler implementation. +""" + +import json +import os +import time +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + +from vllm.v1.core.sched.interface import SchedulerInterface + + +@dataclass +class RecordedIteration: + """A single recorded iteration of the scheduler""" + + iteration: int + schedule_output: Dict[str, Any] + model_runner_output: Dict[str, Any] + engine_core_outputs: Dict[str, Any] + timestamp: float + + +class RecordingScheduler(SchedulerInterface): + """ + Scheduler that records all operations for later replay. + + This scheduler forwards all operations to the underlying vLLM scheduler + while recording the inputs and outputs for analysis and replay. + """ + + def __init__( + self, + *args, + enable_recording: bool = True, + recording_path: Optional[Path] = None, + **kwargs, + ): + """ + Initialize the recording scheduler. 
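Illustrative usage (editor's sketch; values are examples only):

    import os

    # Choose which scheduler the recorder wraps: "dynamo" (default, wraps
    # DynamoScheduler) or "vllm" (vLLM's stock V1 Scheduler).
    os.environ["DYN_VLLM_RECORD_SCHEDULER_CLS"] = "vllm"

    # Traces are written as JSON under ./.sandbox/recordings/ unless a
    # recording_path is supplied; an incremental snapshot is saved every
    # 10 iterations and a final trace on shutdown().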
+ + Args: + enable_recording: Whether to enable recording + recording_path: Path to save the recording (defaults to .sandbox/recordings/) + *args, **kwargs: Passed to the underlying scheduler + """ + # Determine which scheduler to use based on environment variable + scheduler_type = os.getenv("DYN_VLLM_RECORD_SCHEDULER_CLS", "dynamo").lower() + + if scheduler_type == "vllm": + # Use vLLM's default scheduler directly + from vllm.v1.core.sched.scheduler import Scheduler + + self._wrapped_scheduler = Scheduler(*args, **kwargs) + print("RecordingScheduler: Using vLLM Scheduler") + elif scheduler_type == "dynamo": + # Use DynamoScheduler (which wraps vLLM scheduler) + from .scheduler import DynamoScheduler + + self._wrapped_scheduler = DynamoScheduler(*args, **kwargs) + print("RecordingScheduler: Using DynamoScheduler") + else: + raise ValueError( + f"Invalid scheduler type: {scheduler_type}. " + f"DYN_VLLM_RECORD_SCHEDULER_CLS must be 'dynamo' or 'vllm'" + ) + + self.enable_recording = enable_recording + self.iteration = 0 + self.recordings: List[RecordedIteration] = [] + self.current_schedule_output = None + + if recording_path: + self.recording_path = Path(recording_path) + else: + # Default to .sandbox/recordings/ in current working directory + self.recording_path = Path.cwd() / ".sandbox" / "recordings" + + # Create recordings directory if it doesn't exist + if self.enable_recording: + self.recording_path.mkdir(parents=True, exist_ok=True) + print(f"Recording enabled. Will save to: {self.recording_path}") + + def schedule(self): + """Schedule requests and record the output.""" + output = self._wrapped_scheduler.schedule() + + if self.enable_recording: + # Convert SchedulerOutput to dict + self.current_schedule_output = self._scheduler_output_to_dict(output) + + return output + + def update_from_output(self, scheduler_output, model_runner_output): + """Update from model output and record.""" + result = self._wrapped_scheduler.update_from_output( + scheduler_output, model_runner_output + ) + + if self.enable_recording and self.current_schedule_output: + # Record the complete iteration + iteration = RecordedIteration( + iteration=self.iteration, + schedule_output=self.current_schedule_output, + model_runner_output=self._model_runner_output_to_dict( + model_runner_output + ), + engine_core_outputs=self._engine_core_outputs_to_dict(result), + timestamp=time.time(), + ) + self.recordings.append(iteration) + + # Increment iteration counter + self.iteration += 1 + self.current_schedule_output = None + + # Optionally save incrementally (every 10 iterations) + if self.iteration % 10 == 0: + self.save_recording(incremental=True) + + return result + + def _scheduler_output_to_dict(self, output) -> Dict[str, Any]: + """Convert SchedulerOutput to a dictionary.""" + try: + return { + "scheduled_new_reqs": [ + { + "req_id": req.req_id, + "prompt_token_ids": req.prompt_token_ids, + "block_ids": [list(blocks) for blocks in req.block_ids] + if req.block_ids + else [], + "num_computed_tokens": req.num_computed_tokens, + "mm_hashes": req.mm_hashes if hasattr(req, "mm_hashes") else [], + } + for req in output.scheduled_new_reqs + ], + "scheduled_cached_reqs": { + "req_ids": output.scheduled_cached_reqs.req_ids, + "resumed_from_preemption": output.scheduled_cached_reqs.resumed_from_preemption, + "new_token_ids": output.scheduled_cached_reqs.new_token_ids, + "new_block_ids": [ + [list(blocks) for blocks in block_ids] if block_ids else None + for block_ids in output.scheduled_cached_reqs.new_block_ids + ], + 
"num_computed_tokens": output.scheduled_cached_reqs.num_computed_tokens, + }, + "num_scheduled_tokens": dict(output.num_scheduled_tokens), + "total_num_scheduled_tokens": output.total_num_scheduled_tokens, + "scheduled_spec_decode_tokens": dict( + output.scheduled_spec_decode_tokens + ), + "scheduled_encoder_inputs": dict(output.scheduled_encoder_inputs), + "num_common_prefix_blocks": list(output.num_common_prefix_blocks), + "finished_req_ids": list(output.finished_req_ids), + "free_encoder_mm_hashes": list(output.free_encoder_mm_hashes), + } + except Exception as e: + print(f"Error converting SchedulerOutput: {e}") + return {} + + def _model_runner_output_to_dict(self, output) -> Dict[str, Any]: + """Convert ModelRunnerOutput to a dictionary.""" + try: + result = { + "req_ids": output.req_ids, + "req_id_to_index": dict(output.req_id_to_index), + "sampled_token_ids": output.sampled_token_ids, + } + + if output.logprobs: + result["logprobs"] = { + "logprob_token_ids": output.logprobs.logprob_token_ids, + "logprobs": output.logprobs.logprobs, + "sampled_token_ranks": output.logprobs.sampled_token_ranks, + } + + if hasattr(output, "num_nans_in_logits") and output.num_nans_in_logits: + result["num_nans_in_logits"] = dict(output.num_nans_in_logits) + + return result + except Exception as e: + print(f"Error converting ModelRunnerOutput: {e}") + return {} + + def _engine_core_outputs_to_dict(self, outputs) -> Dict[str, Any]: + """Convert EngineCoreOutputs (dict) to a serializable format.""" + try: + result = {} + for engine_idx, engine_outputs in outputs.items(): + result[str(engine_idx)] = [ + { + "request_id": output.request_id, + "new_token_ids": output.new_token_ids, + "finish_reason": output.finish_reason.value + if output.finish_reason + else None, + "num_cached_tokens": getattr(output, "num_cached_tokens", 0), + } + for output in engine_outputs + ] + return result + except Exception as e: + print(f"Error converting EngineCoreOutputs: {e}") + return {} + + def save_recording(self, incremental: bool = False): + """Save the recording to a JSON file.""" + if not self.recordings: + print("No recordings to save") + return + + # Create filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + suffix = "_incremental" if incremental else "" + filename = f"scheduler_trace_{timestamp}{suffix}.json" + filepath = self.recording_path / filename + + # Create the trace structure + trace = { + "metadata": { + "vllm_version": "0.10.2", # Could get this dynamically + "model": "gpt2", # Should be passed in + "timestamp": datetime.now().isoformat(), + "total_iterations": len(self.recordings), + }, + "iterations": [asdict(rec) for rec in self.recordings], + } + + # Save to file + with open(filepath, "w") as f: + json.dump(trace, f, indent=2) + + print(f"Saved {len(self.recordings)} iterations to {filepath}") + + def shutdown(self): + """Save recording and shutdown.""" + if self.enable_recording and self.recordings: + self.save_recording() + self._wrapped_scheduler.shutdown() + + def add_request(self, request) -> None: + """Add a new request to the scheduler.""" + self._wrapped_scheduler.add_request(request) + + def finish_requests(self, request_ids, finished_status) -> None: + """Mark requests as finished.""" + self._wrapped_scheduler.finish_requests(request_ids, finished_status) + + def get_num_unfinished_requests(self) -> int: + """Get the number of unfinished requests.""" + return self._wrapped_scheduler.get_num_unfinished_requests() + + def has_finished_requests(self) -> bool: + 
"""Check if there are any finished requests.""" + return self._wrapped_scheduler.has_finished_requests() + + def reset_prefix_cache(self) -> bool: + """Reset the prefix cache.""" + return self._wrapped_scheduler.reset_prefix_cache() + + def get_request_counts(self): + """Get counts of requests in different states.""" + return self._wrapped_scheduler.get_request_counts() + + def make_stats(self): + """Generate statistics about the scheduler's current state.""" + return self._wrapped_scheduler.make_stats() + + def update_draft_token_ids(self, draft_token_ids) -> None: + """Update draft token IDs for scheduled requests.""" + return self._wrapped_scheduler.update_draft_token_ids(draft_token_ids) diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/scheduler.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/scheduler.py new file mode 100644 index 00000000000..c34e0ca35db --- /dev/null +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/scheduler.py @@ -0,0 +1,269 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Dynamo Scheduler implementation that forwards to vLLM's default scheduler. + +This module provides a custom scheduler that acts as a springboard to vLLM's +default scheduler implementation, allowing for future customization while +maintaining compatibility with vLLM's scheduling interface. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Iterable, Optional, Tuple, Union + +from vllm.v1.core.sched.interface import SchedulerInterface +from vllm.v1.core.sched.scheduler import Scheduler + +try: + from dynamo._core import RustSchedulerState +except ImportError: + RustSchedulerState = None + print("Warning: Could not import RustSchedulerState from dynamo._core") + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.multimodal import MultiModalRegistry + from vllm.transformers_utils.structured_outputs import StructuredOutputManager + from vllm.v1.core.kv_cache_manager import KVCacheConfig + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.core.scheduler import DraftTokenIds, ModelRunnerOutput, SchedulerStats + from vllm.v1.outputs import EngineCoreOutputs + from vllm.v1.request import Request, RequestStatus + + +class DynamoScheduler(SchedulerInterface): + """ + Custom scheduler that forwards all operations to vLLM's default Scheduler. + + This scheduler acts as a transparent proxy, allowing for future customization + of scheduling behavior while maintaining full compatibility with vLLM. + """ + + def __init__( + self, + vllm_config: "VllmConfig", + kv_cache_config: "KVCacheConfig", + structured_output_manager: "StructuredOutputManager", + mm_registry: Optional["MultiModalRegistry"] = None, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + """ + Initialize the DynamoScheduler with a wrapped vLLM Scheduler. 
+ + Args: + vllm_config: vLLM configuration object + kv_cache_config: KV cache configuration + structured_output_manager: Manager for structured outputs + mm_registry: Multi-modal registry (optional, will use default if None) + include_finished_set: Whether to include finished requests + log_stats: Whether to log statistics + """ + # Import here to handle optional mm_registry parameter + from vllm.multimodal import MULTIMODAL_REGISTRY + + # Use provided registry or default + if mm_registry is None: + mm_registry = MULTIMODAL_REGISTRY + + # Create the underlying vLLM scheduler + self._scheduler = Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + structured_output_manager=structured_output_manager, + mm_registry=mm_registry, + include_finished_set=include_finished_set, + log_stats=log_stats, + ) + + # Initialize Rust scheduler state if available + if RustSchedulerState is not None: + self._rust_scheduler = RustSchedulerState() + print("DynamoScheduler: Rust scheduler state initialized") + else: + self._rust_scheduler = None + + def schedule(self) -> "SchedulerOutput": + """ + Schedule requests for the next model forward pass. + + Returns: + SchedulerOutput containing scheduling decisions + """ + return self._scheduler.schedule() + + def update_from_output( + self, + scheduler_output: "SchedulerOutput", + model_runner_output: "ModelRunnerOutput", + ) -> Dict[int, "EngineCoreOutputs"]: + """ + Update scheduler state after model processing. + + Args: + scheduler_output: Output from the schedule() method + model_runner_output: Output from the model runner + + Returns: + Dictionary mapping request IDs to engine core outputs + """ + result = self._scheduler.update_from_output( + scheduler_output, model_runner_output + ) + + # Remove finished requests from Rust scheduler + if self._rust_scheduler is not None and hasattr( + scheduler_output, "finished_req_ids" + ): + try: + finished_ids = list(scheduler_output.finished_req_ids) + if finished_ids: + self._rust_scheduler.remove_finished_requests(finished_ids) + print( + f"DynamoScheduler: Removed {len(finished_ids)} finished requests from Rust scheduler" + ) + except Exception as e: + print( + f"DynamoScheduler: Error removing finished requests from Rust scheduler: {e}" + ) + + return result + + def update_draft_token_ids( + self, + draft_token_ids: "DraftTokenIds", + ) -> None: + """ + Update draft token IDs for scheduled requests. + + Args: + draft_token_ids: Draft token IDs to update + """ + self._scheduler.update_draft_token_ids(draft_token_ids) + + def add_request(self, request: "Request") -> None: + """ + Add a new request to the scheduler. 
+ + Args: + request: Request object to add to the scheduler + """ + # Pass request to Rust scheduler if available + if self._rust_scheduler is not None: + try: + # Extract data available at add_request time + request_id = request.request_id + prompt_token_ids = request.prompt_token_ids + + # Pass cache_salt as string - Rust will handle the hashing + cache_salt = getattr(request, "cache_salt", None) + + # Extract LoRA ID if present + lora_int_id = None + if hasattr(request, "lora_request") and request.lora_request: + lora_int_id = request.lora_request.lora_int_id + + # Get priority and arrival time + priority = getattr(request, "priority", 0) + arrival_time = getattr(request, "arrival_time", 0.0) + + # Add to Rust scheduler (cache_salt is now passed as string) + self._rust_scheduler.add_request( + request_id=request_id, + prompt_token_ids=list(prompt_token_ids), # Convert to list + cache_salt=cache_salt, # Pass as string, Rust converts to u64 + lora_int_id=lora_int_id, + priority=priority, + arrival_time=arrival_time, + ) + except Exception as e: + print(f"DynamoScheduler: Error adding request to Rust scheduler: {e}") + + # Always add to vLLM scheduler + self._scheduler.add_request(request) + + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: "RequestStatus", + ) -> None: + """ + Mark requests as finished. + + Args: + request_ids: Request ID(s) to mark as finished + finished_status: The finish status for the requests + """ + # Mark as finished in Rust scheduler (doesn't remove them yet) + if self._rust_scheduler is not None: + try: + # Ensure request_ids is a list + if isinstance(request_ids, str): + ids_list = [request_ids] + else: + ids_list = list(request_ids) + + self._rust_scheduler.mark_as_finished(ids_list) + print( + f"DynamoScheduler: Marked {len(ids_list)} requests as finished in Rust scheduler" + ) + except Exception as e: + print( + f"DynamoScheduler: Error marking requests as finished in Rust scheduler: {e}" + ) + + # Always call vLLM scheduler to handle the actual state transitions + self._scheduler.finish_requests(request_ids, finished_status) + + def get_num_unfinished_requests(self) -> int: + """ + Get the number of unfinished requests. + + Returns: + Number of unfinished requests in the scheduler + """ + return self._scheduler.get_num_unfinished_requests() + + def has_finished_requests(self) -> bool: + """ + Check if there are any finished requests. + + Returns: + True if there are finished requests, False otherwise + """ + return self._scheduler.has_finished_requests() + + def reset_prefix_cache(self) -> bool: + """ + Reset the prefix cache. + + Returns: + True if cache was reset successfully + """ + return self._scheduler.reset_prefix_cache() + + def get_request_counts(self) -> Tuple[int, int]: + """ + Get counts of requests in different states. + + Returns: + Tuple of (waiting_count, running_count) + """ + return self._scheduler.get_request_counts() + + def make_stats(self) -> Optional["SchedulerStats"]: + """ + Generate statistics about the scheduler's current state. + + Returns: + SchedulerStats object or None + """ + return self._scheduler.make_stats() + + def shutdown(self) -> None: + """ + Shutdown the scheduler and clean up resources. 
+ """ + self._scheduler.shutdown() diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/scheduler_conn_leader.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/scheduler_conn_leader.py new file mode 100644 index 00000000000..282c23bb553 --- /dev/null +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/scheduler_conn_leader.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Minimal scheduler connector leader implementation for testing. + +This is a barebones implementation that returns minimal/no-op responses, +used specifically for scheduler integration testing without actual KV transfer. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.request import Request + + +class SchedulerConnectorLeader: + """ + Minimal scheduler connector leader that returns no-op responses. + + This connector is used for scheduler integration where no actual + KV transfer is needed. All methods return minimal valid responses. + """ + + def __init__(self, vllm_config: "VllmConfig", engine_id: str, **kwargs): + """Initialize the scheduler connector leader.""" + self.vllm_config = vllm_config + self.engine_id = engine_id + print(f"SchedulerConnectorLeader initialized with engine_id: {engine_id}") + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Always returns (0, False) indicating no external tokens available. + + Returns: + (0, False): No external tokens, no async loading + """ + return (0, False) + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ) -> None: + """ + No-op since we never have external tokens. + + This should never be called with num_external_tokens > 0. + """ + if num_external_tokens > 0: + print( + f"Warning: update_state_after_alloc called with {num_external_tokens} " + f"external tokens for request {request.request_id}, but scheduler " + "connector always returns 0 external tokens" + ) + + def build_connector_meta(self, scheduler_output: SchedulerOutput) -> bytes: + """ + Build minimal connector metadata. + + Returns: + Empty bytes object + """ + # Return empty bytes - minimal valid metadata + return bytes() + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Never delays block freeing. + + Returns: + (False, None): Don't delay block freeing, no KV transfer params + """ + return (False, None) diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/scheduler_conn_worker.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/scheduler_conn_worker.py new file mode 100644 index 00000000000..b22874a74b9 --- /dev/null +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/scheduler_conn_worker.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Minimal scheduler connector worker implementation for testing. + +This is a barebones implementation that provides no-op responses, +used specifically for scheduler integration testing without actual KV transfer. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch +from vllm.model_executor.models.utils import extract_layer_index +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE + +# Import our local block builder +from dynamo._core import scheduler_connector + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + + +class SchedulerConnectorWorker: + """ + Minimal scheduler connector worker that provides no-op implementations. + + This connector is used for scheduler integration where no actual + KV transfer is needed. All methods are no-ops or return minimal responses. + """ + + def __init__(self, vllm_config: "VllmConfig", engine_id: str, **kwargs): + """Initialize the scheduler connector worker.""" + self.vllm_config = vllm_config + self.engine_id = engine_id + self.local_blocks = None + print(f"SchedulerConnectorWorker initialized with engine_id: {engine_id}") + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None: + """ + Register KV caches - builds local blocks without leader sync. + + This creates device blocks locally from the provided tensors + without requiring any network setup or synchronization. + """ + if not kv_caches: + print("Warning: register_kv_caches called with empty kv_caches") + return + + print( + f"SchedulerConnectorWorker.register_kv_caches called with {len(kv_caches)} layers" + ) + + # Extract configuration from vLLM config + cache_config = self.vllm_config.cache_config + + # Sort tensors by layer index to ensure correct ordering + ordered_kv_caches = sorted( + kv_caches.items(), key=lambda item: extract_layer_index(item[0]) + ) + + # Extract tensors in order + tensors = [tensor for _, tensor in ordered_kv_caches] + + # Get first tensor to extract common properties + first_tensor = tensors[0] + shape = first_tensor.shape + + # Validate all tensors have same shape + if not all(t.shape == shape for t in tensors): + raise NotImplementedError( + "Hybrid models with different KV cache shapes are not supported yet." + ) + + # Extract parameters + # TODO: Assume the block dimension is within the first 2. 
This will break if you're doing something weird + num_device_blocks = max(shape[0], shape[1]) + page_size = cache_config.block_size + device_id = ( + first_tensor.device.index if first_tensor.device.type == "cuda" else 0 + ) + + # Determine cache dtype + if cache_config.cache_dtype == "auto": + kv_cache_dtype = self.vllm_config.model_config.dtype + else: + kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + dtype_width_bytes = kv_cache_dtype.itemsize + + # Build worker device blocks + try: + self.local_blocks = scheduler_connector.WorkerDeviceBlocks( + tensors=tensors, + num_device_blocks=num_device_blocks, + page_size=page_size, + device_id=device_id, + dtype_width_bytes=dtype_width_bytes, + is_fully_contiguous=False, # Default to layer-separate layout + ) + + print(f"Successfully built worker device blocks: {self.local_blocks}") + print(f" - Blocks created: {self.local_blocks.num_blocks()}") + print(f" - Layers: {self.local_blocks.num_layers}") + print(f" - Outer dim: {self.local_blocks.outer_dim}") + print(f" - Page size: {self.local_blocks.page_size}") + print(f" - Inner dim: {self.local_blocks.inner_dim}") + print(f" - Bytes per block: {self.local_blocks.bytes_per_block}") + + except Exception as e: + print(f"Failed to build worker device blocks: {e}") + raise + + def bind_connector_metadata(self, data: bytes) -> None: + """ + Bind connector metadata - no-op. + + Since our leader returns empty bytes, this is always a no-op. + """ + pass + + def clear_connector_metadata(self) -> None: + """ + Clear connector metadata - no-op. + """ + pass + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + """ + Start loading KV cache - no-op. + + No KV loading needed for scheduler connector. + """ + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + """ + Save KV layer - no-op. + + No KV saving needed for scheduler connector. + """ + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Get finished request IDs. + + Since request_finished() always returns False (never delays block freeing), + we just acknowledge the finished requests but don't return any as finished + for KV transfer purposes. + + Returns: + (None, None): No finished sends/receives + """ + # Just acknowledge the finished requests + # Since our leader's request_finished() always returns False, + # these requests have already had their blocks freed + if len(finished_req_ids) > 0: + print( + f"SchedulerConnectorWorker.get_finished() acknowledging {len(finished_req_ids)} finished requests" + ) + + return (None, None) diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 3174b9523d0..00c401496ed 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -25,7 +25,7 @@ readme.workspace = true description = "Dynamo LLM Library" [features] -default = [] +default = ["block-manager"] # todo(ops): get this working in CI as a default. 
# default = ["block-manager", "testing-full"] @@ -47,6 +47,11 @@ harness = false name = "transfer_context_v2" harness = false required-features = ["block-manager", "testing-cuda"] + +[[bench]] +name = "block_manager_v2" +harness = false +required-features = ["block-manager"] [dependencies] # repo dynamo-runtime = { workspace = true } @@ -102,6 +107,7 @@ offset-allocator = "0.2" regex = "1" rayon = "1" dashmap = { version = "5.5.3" } +lru = "0.16" # input/text dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] } diff --git a/lib/llm/benches/block_manager_v2.rs b/lib/llm/benches/block_manager_v2.rs new file mode 100644 index 00000000000..505028cede6 --- /dev/null +++ b/lib/llm/benches/block_manager_v2.rs @@ -0,0 +1,414 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Block Manager V2 Performance Benchmarks +//! +//! Benchmarks for performance-sensitive operations in the block manager v2: +//! - Block registration and deregistration +//! - Block lookup and matching +//! - Drop implementations and cleanup +//! - Block allocation and reuse + +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; + +use dynamo_llm::block_manager::v2::manager::BlockManager; +use dynamo_llm::tokens::{SequenceHash, TokenBlockSequence}; + +/// Test metadata for benchmarks +#[derive(Debug, Clone, PartialEq)] +struct BenchData { + value: u64, +} + +/// Create a token block for benchmarking +fn create_bench_token_block(start: u32, size: usize) -> dynamo_llm::tokens::TokenBlock { + let tokens: Vec = (start..start + size as u32).collect(); + let token_sequence = TokenBlockSequence::from_slice(&tokens, size as u32, Some(42)); + token_sequence + .blocks() + .first() + .cloned() + .expect("Should have at least one block for the given token sequence") +} + +/// Setup a manager for benchmarking +fn create_bench_manager(block_count: usize, block_size: usize) -> BlockManager { + BlockManager::::builder() + .block_count(block_count) + .block_size(block_size) + .with_lru_backend() + .build() + .expect("Should build manager") +} + +/// Generate sequence hashes for lookup benchmarks +fn generate_sequence_hashes(count: usize, block_size: usize) -> Vec { + (0..count) + .map(|i| create_bench_token_block(i as u32 * 100, block_size).sequence_hash()) + .collect() +} + +// ============================================================================= +// REGISTRATION BENCHMARKS +// ============================================================================= + +fn registration_benchmarks(c: &mut Criterion) { + let mut group = c.benchmark_group("registration"); + + // Single block registration + group.bench_function("register_single_block", |b| { + b.iter_batched( + || { + let manager = create_bench_manager(1000, 4); + let token_block = create_bench_token_block(100, 4); + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete"); + (manager, complete_block) + }, + |(manager, complete_block)| { + black_box(manager.register_blocks(vec![complete_block])); + }, + criterion::BatchSize::SmallInput, + ); + }); + + // Bulk registration + for size in &[10, 100, 1000] { + group.throughput(Throughput::Elements(*size as u64)); + group.bench_with_input( + BenchmarkId::new("register_bulk_blocks", size), + 
size, + |b, &size| { + b.iter_batched( + || { + let manager = create_bench_manager(size + 100, 4); + let mut complete_blocks = Vec::new(); + + for i in 0..size { + let token_block = create_bench_token_block(i as u32 * 10, 4); + let mutable_blocks = + manager.allocate_blocks(1).expect("Should allocate"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete"); + complete_blocks.push(complete_block); + } + + (manager, complete_blocks) + }, + |(manager, complete_blocks)| { + black_box(manager.register_blocks(complete_blocks)); + }, + criterion::BatchSize::SmallInput, + ); + }, + ); + } + + // Registration with duplicates (deduplication overhead) + group.bench_function("register_with_duplicates", |b| { + b.iter_batched( + || { + let manager = create_bench_manager(100, 4); + let token_block = create_bench_token_block(500, 4); // Same block + let mut complete_blocks = Vec::new(); + + // Create 10 identical blocks + for _ in 0..10 { + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block.clone()) + .expect("Should complete"); + complete_blocks.push(complete_block); + } + + (manager, complete_blocks) + }, + |(manager, complete_blocks)| { + black_box(manager.register_blocks(complete_blocks)); + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.finish(); +} + +// ============================================================================= +// LOOKUP/MATCHING BENCHMARKS +// ============================================================================= + +fn lookup_benchmarks(c: &mut Criterion) { + let mut group = c.benchmark_group("lookup"); + + // Single block lookup + group.bench_function("find_single_block", |b| { + b.iter_batched( + || { + let manager = create_bench_manager(1000, 4); + let token_block = create_bench_token_block(100, 4); + let seq_hash = token_block.sequence_hash(); + + // Register the block + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete"); + manager.register_blocks(vec![complete_block]); + + (manager, vec![seq_hash]) + }, + |(manager, hashes)| { + black_box(manager.match_blocks(&hashes)); + }, + criterion::BatchSize::SmallInput, + ); + }); + + // Multiple block lookup + for size in &[10, 100, 1000] { + group.throughput(Throughput::Elements(*size as u64)); + group.bench_with_input( + BenchmarkId::new("find_multiple_blocks", size), + size, + |b, &size| { + b.iter_batched( + || { + let manager = create_bench_manager(size + 100, 4); + let mut hashes = Vec::new(); + let mut complete_blocks = Vec::new(); + + // Create and register blocks + for i in 0..size { + let token_block = create_bench_token_block(i as u32 * 10, 4); + hashes.push(token_block.sequence_hash()); + + let mutable_blocks = + manager.allocate_blocks(1).expect("Should allocate"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete"); + complete_blocks.push(complete_block); + } + + manager.register_blocks(complete_blocks); + (manager, hashes) + }, + |(manager, hashes)| { + black_box(manager.match_blocks(&hashes)); + }, + criterion::BatchSize::SmallInput, + ); + }, + ); + } + + // Lookup miss (blocks not found) + group.bench_function("find_miss", |b| { + b.iter_batched( + || { + let manager = 
create_bench_manager(100, 4); + let nonexistent_hashes = generate_sequence_hashes(10, 4); + (manager, nonexistent_hashes) + }, + |(manager, hashes)| { + black_box(manager.match_blocks(&hashes)); + }, + criterion::BatchSize::SmallInput, + ); + }); + + // Partial match (stops on first miss) + group.bench_function("find_partial_match", |b| { + b.iter_batched( + || { + let manager = create_bench_manager(100, 4); + + // Register first 5 blocks + let mut complete_blocks = Vec::new(); + let mut hashes = Vec::new(); + + for i in 0..5 { + let token_block = create_bench_token_block(i as u32 * 10, 4); + hashes.push(token_block.sequence_hash()); + + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete"); + complete_blocks.push(complete_block); + } + + manager.register_blocks(complete_blocks); + + // Add nonexistent hash in the middle to trigger partial match + hashes.insert(2, 99999); // This hash won't exist + (manager, hashes) + }, + |(manager, hashes)| { + black_box(manager.match_blocks(&hashes)); + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.finish(); +} + +// ============================================================================= +// DROP/CLEANUP BENCHMARKS +// ============================================================================= + +fn drop_benchmarks(c: &mut Criterion) { + let mut group = c.benchmark_group("drop"); + + // MutableBlock drop (return to ResetPool) + group.bench_function("mutable_block_drop", |b| { + b.iter_batched( + || { + let manager = create_bench_manager(100, 4); + manager.allocate_blocks(10).expect("Should allocate") + }, + |mutable_blocks| { + // Drop all blocks, measuring cleanup time + black_box(drop(mutable_blocks)); + }, + criterion::BatchSize::SmallInput, + ); + }); + + // RegisteredBlock drop (return to InactivePool) + group.bench_function("registered_block_drop", |b| { + b.iter_batched( + || { + let manager = create_bench_manager(100, 4); + let mut complete_blocks = Vec::new(); + + for i in 0..10 { + let token_block = create_bench_token_block(i as u32 * 10, 4); + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete"); + complete_blocks.push(complete_block); + } + + manager.register_blocks(complete_blocks) + }, + |registered_blocks| { + // Drop all registered blocks, measuring cleanup time + black_box(drop(registered_blocks)); + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.finish(); +} + +// ============================================================================= +// ALLOCATION BENCHMARKS +// ============================================================================= + +fn allocation_benchmarks(c: &mut Criterion) { + let mut group = c.benchmark_group("allocation"); + + // Allocate from ResetPool + for size in &[1, 10, 100] { + group.throughput(Throughput::Elements(*size as u64)); + group.bench_with_input( + BenchmarkId::new("allocate_from_reset", size), + size, + |b, &size| { + b.iter_batched( + || create_bench_manager(size + 10, 4), + |manager| { + black_box(manager.allocate_blocks(size)); + }, + criterion::BatchSize::SmallInput, + ); + }, + ); + } + + // Allocate from InactivePool (reuse) + for size in &[1, 10, 100] { + group.throughput(Throughput::Elements(*size as u64)); + group.bench_with_input( + 
BenchmarkId::new("allocate_from_inactive", size), + size, + |b, &size| { + b.iter_batched( + || { + let manager = create_bench_manager(size + 10, 4); + + // Pre-populate inactive pool by registering and dropping blocks + let mut complete_blocks = Vec::new(); + for i in 0..size { + let token_block = create_bench_token_block(i as u32 * 10, 4); + let mutable_blocks = + manager.allocate_blocks(1).expect("Should allocate"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete"); + complete_blocks.push(complete_block); + } + + let registered_blocks = manager.register_blocks(complete_blocks); + drop(registered_blocks); // Puts blocks in inactive pool + + manager + }, + |manager| { + // Try to allocate from inactive pool (should reuse blocks) + black_box(manager.allocate_blocks(size)); + }, + criterion::BatchSize::SmallInput, + ); + }, + ); + } + + group.finish(); +} + +// ============================================================================= +// BENCHMARK GROUPS +// ============================================================================= + +criterion_group!( + benches, + registration_benchmarks, + lookup_benchmarks, + drop_benchmarks, + allocation_benchmarks +); +criterion_main!(benches); diff --git a/lib/llm/src/block_manager.rs b/lib/llm/src/block_manager.rs index b4a797de6fb..93f98df1565 100644 --- a/lib/llm/src/block_manager.rs +++ b/lib/llm/src/block_manager.rs @@ -19,6 +19,8 @@ //! mechanisms. It handles storage allocation, block management, and safe access //! patterns for both system memory and remote (NIXL) storage. +pub mod v2; + pub mod config; mod state; diff --git a/lib/llm/src/block_manager/v2/guards/complete.rs b/lib/llm/src/block_manager/v2/guards/complete.rs new file mode 100644 index 00000000000..ab8fcda30ef --- /dev/null +++ b/lib/llm/src/block_manager/v2/guards/complete.rs @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! 
RAII guard for complete blocks + +use std::sync::Arc; + +use super::{ + super::pools::{ + BlockMetadata, + block::{Block, BlockId, Complete, Reset}, + }, + MutableBlock, +}; +use crate::tokens::{SequenceHash, TokenBlock}; + +/// RAII guard for [`Block`] that automatically returns to ResetPool on drop +pub struct CompleteBlock { + pub(crate) block: Option>, + pub(crate) return_fn: Arc) + Send + Sync>, +} + +impl CompleteBlock { + /// Create a new CompleteBlock + pub(crate) fn new( + block: Block, + return_fn: Arc) + Send + Sync>, + ) -> Self { + Self { + block: Some(block), + return_fn, + } + } + + /// Get the block ID + pub fn block_id(&self) -> BlockId { + self.block.as_ref().unwrap().block_id() + } + + /// Access token block if in Complete state + pub fn token_block(&self) -> &TokenBlock { + self.block.as_ref().unwrap().token_block() + } + + /// Get sequence hash if in Complete state + pub fn sequence_hash(&self) -> SequenceHash { + self.block.as_ref().unwrap().sequence_hash() + } + + /// Reset the block back to mutable state + pub fn reset(mut self) -> MutableBlock { + let block = self.block.take().unwrap().reset(); + + MutableBlock::new(block, self.return_fn.clone()) + } +} + +impl Drop for CompleteBlock { + #[inline] + fn drop(&mut self) { + if let Some(block) = self.block.take() { + (self.return_fn)(block.reset()); + } + } +} diff --git a/lib/llm/src/block_manager/v2/guards/immutable.rs b/lib/llm/src/block_manager/v2/guards/immutable.rs new file mode 100644 index 00000000000..0f3a7bed1c9 --- /dev/null +++ b/lib/llm/src/block_manager/v2/guards/immutable.rs @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! RAII guards for immutable and weak block references + +use std::{ + ops::Deref, + sync::{Arc, Weak}, +}; + +use super::{super::pools::BlockMetadata, RegisteredBlock}; +use crate::tokens::SequenceHash; + +/// RAII guard for registered blocks with upgrade capability +pub struct ImmutableBlock { + block: Arc>, + upgrade_fn: Arc Option>> + Send + Sync>, +} + +/// Weak reference to a registered block with upgrade capability +pub struct WeakBlock { + sequence_hash: SequenceHash, + block: Weak>, + upgrade_fn: Arc Option>> + Send + Sync>, +} + +impl ImmutableBlock { + /// Create a new ImmutableBlock with an upgrade function + pub fn new( + block: Arc>, + upgrade_fn: Arc Option>> + Send + Sync>, + ) -> Self { + Self { block, upgrade_fn } + } + + /// Downgrade to a WeakBlock + pub fn downgrade(&self) -> WeakBlock { + WeakBlock { + sequence_hash: self.sequence_hash(), + block: Arc::downgrade(&self.block), + upgrade_fn: self.upgrade_fn.clone(), + } + } +} + +impl WeakBlock { + /// Try to upgrade this WeakBlock back to an ImmutableBlock + pub fn upgrade(&self) -> Option> { + // First try to upgrade the weak reference directly + if let Some(block) = self.block.upgrade() { + return Some(ImmutableBlock::new(block, self.upgrade_fn.clone())); + } + + // If that fails, use the upgrade function to search for the block + if let Some(block) = (self.upgrade_fn)(self.sequence_hash) { + return Some(ImmutableBlock::new(block, self.upgrade_fn.clone())); + } + + None + } + + /// Get the sequence hash + pub fn sequence_hash(&self) -> SequenceHash { + self.sequence_hash + } +} + +impl Deref for ImmutableBlock { + type Target = dyn RegisteredBlock; + + fn deref(&self) -> &Self::Target { + self.block.as_ref() + } +} diff --git a/lib/llm/src/block_manager/v2/guards/mod.rs 
b/lib/llm/src/block_manager/v2/guards/mod.rs new file mode 100644 index 00000000000..9cd1faf3db1 --- /dev/null +++ b/lib/llm/src/block_manager/v2/guards/mod.rs @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! RAII guard types for type-safe block management +//! +//! This module provides type-safe RAII guards that ensure automatic resource cleanup: +//! - `MutableBlock`: Guards blocks in Reset state +//! - `CompleteBlock`: Guards blocks in Complete state +//! - `ImmutableBlock`: Guards registered blocks with upgrade capability +//! - `WeakBlock`: Weak references to registered blocks +//! - `PrimaryBlock`, `DuplicateBlock`: Internal registered block types + +use super::pools::{block::BlockId, registry::BlockRegistrationHandle}; +use crate::tokens::SequenceHash; + +pub mod complete; +pub mod immutable; +pub mod mutable; +pub mod registered; + +pub use complete::CompleteBlock; +pub use immutable::{ImmutableBlock, WeakBlock}; +pub use mutable::MutableBlock; +pub(crate) use registered::{DuplicateBlock, PrimaryBlock}; + +/// Trait for types that can be registered and provide block information +pub trait RegisteredBlock: Send + Sync { + /// Get the block ID + fn block_id(&self) -> BlockId; + + /// Get the sequence hash + fn sequence_hash(&self) -> SequenceHash; + + /// Get the registration handle + fn registration_handle(&self) -> &BlockRegistrationHandle; +} diff --git a/lib/llm/src/block_manager/v2/guards/mutable.rs b/lib/llm/src/block_manager/v2/guards/mutable.rs new file mode 100644 index 00000000000..56fd1c8c990 --- /dev/null +++ b/lib/llm/src/block_manager/v2/guards/mutable.rs @@ -0,0 +1,85 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! 
RAII guard for mutable blocks in Reset state + +use std::sync::Arc; + +use super::{ + super::pools::{ + BlockMetadata, + block::{Block, BlockError, BlockId, Reset}, + }, + CompleteBlock, +}; +use crate::tokens::TokenBlock; + +/// RAII guard for [`Block`] that automatically returns to ResetPool on drop +pub struct MutableBlock { + block: Option>, + return_fn: Arc) + Send + Sync>, +} + +impl MutableBlock { + /// Create a new MutableBlock in Reset state + pub(crate) fn new( + block: Block, + return_fn: Arc) + Send + Sync>, + ) -> Self { + Self { + block: Some(block), + return_fn, + } + } + + /// Get the block ID + pub fn block_id(&self) -> BlockId { + self.block.as_ref().unwrap().block_id() + } + + /// Transition from Reset to Complete state + pub fn complete( + mut self, + token_block: TokenBlock, + ) -> Result, BlockError>> { + let block = self.block.take().unwrap(); + match block.complete(token_block) { + Ok(complete_block) => Ok(CompleteBlock::new(complete_block, self.return_fn.clone())), + Err(block_error) => { + // Extract the block from the error and put it back in self + match block_error { + BlockError::BlockSizeMismatch { + expected, + actual, + block, + } => { + self.block = Some(block); + Err(BlockError::BlockSizeMismatch { + expected, + actual, + block: self, + }) + } + } + } + } + } +} + +impl Drop for MutableBlock { + #[inline] + fn drop(&mut self) { + if let Some(block) = self.block.take() { + (self.return_fn)(block); + } + } +} + +impl std::fmt::Debug for MutableBlock { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MutableBlock") + .field("block", &self.block.as_ref().map(|b| b.block_id())) + .field("return_fn", &"") + .finish() + } +} diff --git a/lib/llm/src/block_manager/v2/guards/registered.rs b/lib/llm/src/block_manager/v2/guards/registered.rs new file mode 100644 index 00000000000..2de06132682 --- /dev/null +++ b/lib/llm/src/block_manager/v2/guards/registered.rs @@ -0,0 +1,109 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! 
RAII guards for registered blocks (primary and duplicate) + +use std::sync::Arc; + +use super::{ + super::pools::{ + BlockMetadata, + block::{Block, BlockId, Registered, Reset}, + registry::BlockRegistrationHandle, + }, + RegisteredBlock, +}; +use crate::tokens::SequenceHash; + +/// RAII guard for [`Block`] that automatically returns to RegisteredPool on drop +pub(crate) struct PrimaryBlock { + pub(crate) block: Option>>, + pub(crate) return_fn: Arc>) + Send + Sync>, +} + +/// RAII guard for duplicate blocks that share the same sequence hash as a primary block +pub(crate) struct DuplicateBlock { + pub(crate) block: Option>, + pub(crate) return_fn: Arc) + Send + Sync>, + pub(crate) _primary: Arc>, +} + +impl PrimaryBlock { + /// Create a new PrimaryBlock + pub(crate) fn new( + block: Arc>, + return_fn: Arc>) + Send + Sync>, + ) -> Self { + Self { + block: Some(block), + return_fn, + } + } + + /// Register this block and get an Arc to the RegisteredBlock trait object + pub(crate) fn register(self) -> Arc> { + let block = self.block.clone().unwrap(); + block.registration_handle().attach_block(self) + } +} + +impl DuplicateBlock { + /// Create a new DuplicateBlock + pub(crate) fn new( + block: Block, + primary: Arc>, + return_fn: Arc) + Send + Sync>, + ) -> Self { + Self { + block: Some(block), + return_fn, + _primary: primary, + } + } +} + +impl RegisteredBlock for PrimaryBlock { + fn block_id(&self) -> BlockId { + self.block.as_ref().unwrap().block_id() + } + + fn sequence_hash(&self) -> SequenceHash { + self.block.as_ref().unwrap().sequence_hash() + } + + fn registration_handle(&self) -> &BlockRegistrationHandle { + self.block.as_ref().unwrap().registration_handle() + } +} + +impl RegisteredBlock for DuplicateBlock { + fn block_id(&self) -> BlockId { + self.block.as_ref().unwrap().block_id() + } + + fn sequence_hash(&self) -> SequenceHash { + self.block.as_ref().unwrap().sequence_hash() + } + + fn registration_handle(&self) -> &BlockRegistrationHandle { + self.block.as_ref().unwrap().registration_handle() + } +} + +impl Drop for PrimaryBlock { + #[inline] + fn drop(&mut self) { + if let Some(block) = self.block.take() { + (self.return_fn)(block); + } + } +} + +impl Drop for DuplicateBlock { + #[inline] + fn drop(&mut self) { + if let Some(block) = self.block.take() { + (self.return_fn)(block.reset()); + } + } +} diff --git a/lib/llm/src/block_manager/v2/manager/builder.rs b/lib/llm/src/block_manager/v2/manager/builder.rs new file mode 100644 index 00000000000..83a4f1694fc --- /dev/null +++ b/lib/llm/src/block_manager/v2/manager/builder.rs @@ -0,0 +1,275 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Builder pattern for BlockManager with ergonomic backend configuration. 
+ +use std::num::NonZeroUsize; +use std::sync::Arc; + + +use crate::block_manager::v2::pools::{ + ActivePool, BlockDuplicationPolicy, BlockMetadata, + block::{Block, Reset}, + frequency_sketch::TinyLFUTracker, + inactive::{ + InactivePool, + backends::{ + hashmap_backend::HashMapBackend, + lru_backend::LruBackend, + multi_lru_backend::MultiLruBackend, + reuse::fifo::FifoReusePolicy, + InactivePoolBackend, + ReusePolicy, + }, + }, + registry::BlockRegistry, + reset::ResetPool, +}; + +use super::BlockManager; + +/// Configuration for different inactive pool backends +pub enum InactiveBackendConfig { + /// HashMap with configurable reuse policy + HashMap { + reuse_policy: Box, + }, + /// Simple LRU - capacity automatically set to block_count + Lru, + /// Multi-level LRU with 4 fixed levels - capacity automatically set to block_count + MultiLru { + /// Frequency thresholds: [cold->warm, warm->hot, hot->very_hot] + /// Default: [3, 8, 15] + frequency_thresholds: [u8; 3], + }, +} + +/// Builder for BlockManager configuration +pub struct BlockManagerConfigBuilder { + /// Number of blocks in the pool + block_count: Option, + + /// Frequency tracker size for TinyLFU (must be power of 2) + /// Default: 2^21, Min: 2^18, Max: 2^24 + frequency_tracker_size: Option, + + /// Inactive pool backend configuration + inactive_backend: Option, + + /// Policy for handling duplicate sequence hashes + duplication_policy: Option, + + /// Phantom data for type parameter + _phantom: std::marker::PhantomData, +} + +/// Error types for BlockManager builder +#[derive(Debug, thiserror::Error)] +pub enum BlockManagerBuilderError { + #[error("Block count must be greater than 0")] + InvalidBlockCount, + #[error("Invalid backend configuration: {0}")] + InvalidBackend(String), + #[error("Builder validation failed: {0}")] + ValidationError(String), +} + +impl Default for BlockManagerConfigBuilder { + fn default() -> Self { + Self { + block_count: None, + frequency_tracker_size: None, + inactive_backend: None, + duplication_policy: None, + _phantom: std::marker::PhantomData, + } + } +} + +impl BlockManagerConfigBuilder { + /// Create a new builder + pub fn new() -> Self { + Self::default() + } + + /// Set the number of blocks in the pool + pub fn block_count(mut self, count: usize) -> Self { + self.block_count = Some(count); + self + } + + /// Set the duplication policy + pub fn duplication_policy(mut self, policy: BlockDuplicationPolicy) -> Self { + self.duplication_policy = Some(policy); + self + } + /// Set frequency tracker size with validation + /// Must be a power of 2 between 2^18 and 2^24 + pub fn frequency_tracker_size(mut self, size: usize) -> Self { + assert!(size >= (1 << 18) && size <= (1 << 24), + "Frequency tracker size must be between 2^18 and 2^24, got: {}", size); + assert!(size.is_power_of_two(), + "Frequency tracker size must be a power of 2, got: {}", size); + self.frequency_tracker_size = Some(size); + self + } + + /// Use simple LRU backend (capacity automatically set to block_count) + pub fn with_lru_backend(mut self) -> Self { + self.inactive_backend = Some(InactiveBackendConfig::Lru); + self + } + + /// Use multi-level LRU backend with 4 fixed priority levels + /// Default thresholds: [3, 8, 15] for transitions between: + /// - Cold (0-2 hits) -> Warm (3-7 hits) -> Hot (8-14 hits) -> Very Hot (15 hits) + pub fn with_multi_lru_backend(mut self) -> Self { + self.inactive_backend = Some(InactiveBackendConfig::MultiLru { + frequency_thresholds: [3, 8, 15], + }); + self + } + + /// Use multi-level LRU with 
custom frequency thresholds + /// + /// # Requirements + /// - Thresholds must be in ascending order: cold_to_warm < warm_to_hot < hot_to_very_hot + /// - hot_to_very_hot must be <= 15 (4-bit counter maximum) + /// - cold_to_warm must be >= 1 (to distinguish from never-accessed blocks) + /// + /// # Arguments + /// * `cold_to_warm` - Minimum frequency to move from Cold to Warm level + /// * `warm_to_hot` - Minimum frequency to move from Warm to Hot level + /// * `hot_to_very_hot` - Minimum frequency to move from Hot to Very Hot level + /// + /// # Panics + /// Panics if thresholds don't meet the requirements above + pub fn with_multi_lru_backend_custom_thresholds( + mut self, + cold_to_warm: u8, + warm_to_hot: u8, + hot_to_very_hot: u8, + ) -> Self { + // Validate ascending order + assert!( + cold_to_warm < warm_to_hot && warm_to_hot < hot_to_very_hot, + "Thresholds must be in ascending order: {} < {} < {} failed", + cold_to_warm, warm_to_hot, hot_to_very_hot + ); + + // Validate maximum value (4-bit counter limit) + assert!( + hot_to_very_hot <= 15, + "hot_to_very_hot threshold ({}) must be <= 15 (4-bit counter maximum)", + hot_to_very_hot + ); + + // Additional validation: ensure reasonable gaps between levels + assert!( + cold_to_warm >= 1, + "cold_to_warm threshold must be >= 1 to distinguish from zero-access blocks" + ); + + self.inactive_backend = Some(InactiveBackendConfig::MultiLru { + frequency_thresholds: [cold_to_warm, warm_to_hot, hot_to_very_hot], + }); + self + } + + /// Use HashMap backend with custom reuse policy + pub fn with_hashmap_backend(mut self, reuse_policy: Box) -> Self { + self.inactive_backend = Some(InactiveBackendConfig::HashMap { reuse_policy }); + self + } + + /// Validate the configuration + fn validate(&self) -> Result<(), String> { + let block_count = self.block_count + .ok_or("block_count is required")?; + + if block_count == 0 { + return Err("block_count must be greater than 0".to_string()); + } + + // Additional validation for MultiLRU thresholds at build time + if let Some(InactiveBackendConfig::MultiLru { frequency_thresholds }) = &self.inactive_backend { + let [t1, t2, t3] = frequency_thresholds; + if !(*t1 < *t2 && *t2 < *t3) { + return Err(format!( + "Invalid thresholds [{}, {}, {}]: must be in ascending order", + t1, t2, t3 + )); + } + if *t3 > 15 { + return Err(format!( + "Invalid threshold {}: maximum frequency is 15 (4-bit counter)", + t3 + )); + } + } + + Ok(()) + } + + /// Build the BlockManager + pub fn build(mut self) -> Result, BlockManagerBuilderError> { + // First validate the configuration + self.validate() + .map_err(BlockManagerBuilderError::ValidationError)?; + + let block_count = self.block_count.unwrap(); + + // Create registry with frequency tracking + let freq_size = self.frequency_tracker_size.unwrap_or(2_097_152); + let frequency_tracker = Arc::new(TinyLFUTracker::new(freq_size)); + let registry = BlockRegistry::with_frequency_tracker(frequency_tracker.clone()); + + // Create reset pool + let blocks: Vec> = (0..block_count as u64) + .map(|id| Block::new(id)) + .collect(); + let reset_pool = ResetPool::new(blocks); + + // Create backend based on configuration + let backend: Box> = match self.inactive_backend.take() { + Some(InactiveBackendConfig::HashMap { reuse_policy }) => { + Box::new(HashMapBackend::new(reuse_policy)) + } + Some(InactiveBackendConfig::Lru) => { + // Capacity automatically set to block_count + let capacity = NonZeroUsize::new(block_count) + .expect("block_count must be > 0"); + 
Box::new(LruBackend::new(capacity)) + } + Some(InactiveBackendConfig::MultiLru { frequency_thresholds }) => { + // Total capacity = block_count, distributed across 4 levels + let capacity_per_level = (block_count + 3) / 4; // Round up division + let level_capacity = NonZeroUsize::new(capacity_per_level) + .expect("capacity per level must be > 0"); + + Box::new(MultiLruBackend::new_with_thresholds( + level_capacity, + &frequency_thresholds, + frequency_tracker, + )) + } + None => { + // Default to HashMap with FIFO + Box::new(HashMapBackend::new(Box::new(FifoReusePolicy::new()))) + } + }; + + // Create pools + let inactive_pool = InactivePool::new(backend, &reset_pool); + let active_pool = ActivePool::new(registry.clone(), inactive_pool.return_fn()); + + Ok(BlockManager { + reset_pool, + active_pool, + inactive_pool, + block_registry: registry, + duplication_policy: self.duplication_policy.unwrap_or(BlockDuplicationPolicy::Allow), + }) + } +} + diff --git a/lib/llm/src/block_manager/v2/manager/builder_tests.rs b/lib/llm/src/block_manager/v2/manager/builder_tests.rs new file mode 100644 index 00000000000..f649404999b --- /dev/null +++ b/lib/llm/src/block_manager/v2/manager/builder_tests.rs @@ -0,0 +1,115 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Tests for BlockManager builder pattern + +#[cfg(test)] +mod tests { + use super::super::BlockManager; + + #[derive(Debug, Clone, PartialEq)] + struct TestBlockData { + value: u32, + } + + #[test] + fn test_builder_default() { + let manager = BlockManager::::builder() + .block_count(100) + .build() + .expect("Should build with defaults"); + + // Verify we can allocate blocks + let blocks = manager.allocate_blocks(5); + assert!(blocks.is_some()); + assert_eq!(blocks.unwrap().len(), 5); + } + + #[test] + fn test_builder_with_lru_backend() { + let manager = BlockManager::::builder() + .block_count(100) + .with_lru_backend() + .build() + .expect("Should build with LRU backend"); + + // Verify we can allocate blocks + let blocks = manager.allocate_blocks(10); + assert!(blocks.is_some()); + assert_eq!(blocks.unwrap().len(), 10); + } + + #[test] + fn test_builder_with_multi_lru_backend() { + let manager = BlockManager::::builder() + .block_count(100) + .frequency_tracker_size(1 << 20) // 2^20 + .with_multi_lru_backend() + .build() + .expect("Should build with MultiLRU backend"); + + // Verify we can allocate blocks + let blocks = manager.allocate_blocks(8); + assert!(blocks.is_some()); + assert_eq!(blocks.unwrap().len(), 8); + } + + #[test] + fn test_builder_with_custom_multi_lru_thresholds() { + let manager = BlockManager::::builder() + .block_count(100) + .frequency_tracker_size(1 << 21) // 2^21 (default) + .with_multi_lru_backend_custom_thresholds(2, 6, 12) + .build() + .expect("Should build with custom thresholds"); + + // Verify we can allocate blocks + let blocks = manager.allocate_blocks(4); + assert!(blocks.is_some()); + assert_eq!(blocks.unwrap().len(), 4); + } + + #[test] + fn test_builder_validation_zero_blocks() { + let result = BlockManager::::builder() + .block_count(0) + .build(); + + assert!(result.is_err()); + if let Err(err) = result { + assert!(err.to_string().contains("block_count must be greater than 0")); + } + } + + #[test] + #[should_panic(expected = "must be <= 15")] + fn test_builder_invalid_threshold_too_high() { + BlockManager::::builder() + .block_count(100) + .with_multi_lru_backend_custom_thresholds(2, 6, 20); // 20 
> 15, should panic + } + + #[test] + #[should_panic(expected = "must be in ascending order")] + fn test_builder_invalid_threshold_order() { + BlockManager::::builder() + .block_count(100) + .with_multi_lru_backend_custom_thresholds(6, 2, 10); // Not ascending, should panic + } + + #[test] + #[should_panic(expected = "must be between 2^18 and 2^24")] + fn test_builder_invalid_frequency_tracker_size() { + BlockManager::::builder() + .block_count(100) + .frequency_tracker_size(1000); // Not a valid size, should panic + } + + #[test] + #[should_panic(expected = "must be a power of 2")] + fn test_builder_non_power_of_two_frequency_tracker() { + BlockManager::::builder() + .block_count(100) + .frequency_tracker_size((1 << 20) + 1); // Not power of 2, should panic + } +} \ No newline at end of file diff --git a/lib/llm/src/block_manager/v2/manager/mod.rs b/lib/llm/src/block_manager/v2/manager/mod.rs new file mode 100644 index 00000000000..8e820523dc3 --- /dev/null +++ b/lib/llm/src/block_manager/v2/manager/mod.rs @@ -0,0 +1,1881 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Block Manager v2 + +use std::num::NonZeroUsize; +use std::sync::Arc; + +use parking_lot::Mutex; + +use super::{ + policies::{BlockDuplicationPolicy, ReusePolicy, reuse::fifo::FifoReusePolicy}, + pools::{ + ActivePool, BlockMetadata, RegisteredBlock, SequenceHash, + block::{Block, Reset}, + frequency_sketch::TinyLFUTracker, + inactive::InactivePool, + inactive::backends::{ + InactivePoolBackend, hashmap_backend::HashMapBackend, lru_backend::LruBackend, + multi_lru_backend::MultiLruBackend, + }, + registry::BlockRegistry, + reset::ResetPool, + *, + }, +}; + +/// Configuration for different inactive pool backends +pub enum InactiveBackendConfig { + /// HashMap with configurable reuse policy + HashMap { reuse_policy: Box }, + /// Simple LRU - capacity automatically set to block_count + Lru, + /// Multi-level LRU with 4 fixed levels - capacity automatically set to block_count + MultiLru { + /// Frequency thresholds: [cold->warm, warm->hot, hot->very_hot] + /// Default: [3, 8, 15] + frequency_thresholds: [u8; 3], + }, +} + +/// Builder for BlockManager configuration +pub struct BlockManagerConfigBuilder { + /// Number of blocks in the pool + block_count: Option, + + /// Size of each block in tokens (must be power of 2, 1-1024) + /// Default: 16 + block_size: Option, + + /// Frequency tracker size for TinyLFU (must be power of 2) + /// Default: 2^21, Min: 2^18, Max: 2^24 + frequency_tracker_size: Option, + + /// Inactive pool backend configuration + inactive_backend: Option, + + /// Policy for handling duplicate sequence hashes + duplication_policy: Option, + + /// Phantom data for type parameter + _phantom: std::marker::PhantomData, +} + +/// Error types for BlockManager builder +#[derive(Debug, thiserror::Error)] +pub enum BlockManagerBuilderError { + #[error("Block count must be greater than 0")] + InvalidBlockCount, + #[error("Block size mismatch: expected {expected} tokens, got {actual}")] + BlockSizeMismatch { expected: usize, actual: usize }, + #[error("Invalid backend configuration: {0}")] + InvalidBackend(String), + #[error("Builder validation failed: {0}")] + ValidationError(String), +} + +/// BlockManager v2 with pluggable inactive pool backends +pub struct BlockManager { + reset_pool: ResetPool, + active_pool: ActivePool, + inactive_pool: InactivePool, + block_registry: BlockRegistry, + duplication_policy: 
BlockDuplicationPolicy, + upgrade_fn: Arc Option>> + Send + Sync>, + allocate_mutex: Mutex<()>, + total_blocks: usize, + block_size: usize, +} + +impl BlockManager { + /// Create a new builder for BlockManager + /// + /// # Example + /// ```ignore + /// let manager = BlockManager::builder() + /// .block_count(1000) + /// .with_multi_lru_backend() + /// .build()?; + /// ``` + pub fn builder() -> BlockManagerConfigBuilder { + BlockManagerConfigBuilder::default() + } + + pub fn allocate_blocks(&self, count: usize) -> Option>> { + let _guard = self.allocate_mutex.lock(); + let mut blocks = self.reset_pool.try_allocate_blocks(count); + match self.inactive_pool.allocate_blocks(count - blocks.len()) { + Some(remaining) => { + blocks.extend(remaining); + Some(blocks) + } + None => return None, + } + } + + pub fn register_blocks(&self, blocks: Vec>) -> Vec> { + let pool_return_fn = self.inactive_pool.return_fn(); + blocks + .into_iter() + .map(|block| { + let handle = self + .block_registry + .register_sequence_hash(block.sequence_hash()); + let registered_block = + handle.register_block(block, self.duplication_policy, pool_return_fn.clone()); + ImmutableBlock::new(registered_block, self.upgrade_fn.clone()) + }) + .collect() + } + + pub fn match_blocks(&self, seq_hash: &[SequenceHash]) -> Vec> { + // First try to match against active blocks + let mut matched: Vec> = Vec::with_capacity(seq_hash.len()); + matched.extend( + self.active_pool + .find_matches(seq_hash, true) + .into_iter() + .map(|block| ImmutableBlock::new(block, self.upgrade_fn.clone())), + ); + + // If we didn't match all hashes, try inactive blocks for the remaining ones + let remaining_hashes = &seq_hash[matched.len()..]; + if !remaining_hashes.is_empty() { + matched.extend( + self.inactive_pool + .find_blocks(remaining_hashes, true) + .into_iter() + .map(|block| ImmutableBlock::new(block, self.upgrade_fn.clone())), + ); + } + + matched + } + + pub fn total_blocks(&self) -> usize { + self.total_blocks + } + + pub fn available_blocks(&self) -> usize { + self.reset_pool.len() + self.inactive_pool.len() + } + + pub fn block_size(&self) -> usize { + self.block_size + } +} + +impl Default for BlockManagerConfigBuilder { + fn default() -> Self { + Self { + block_count: None, + block_size: Some(16), // Default to 16 tokens per block + frequency_tracker_size: None, + inactive_backend: None, + duplication_policy: None, + _phantom: std::marker::PhantomData, + } + } +} + +impl BlockManagerConfigBuilder { + /// Create a new builder + pub fn new() -> Self { + Self::default() + } + + /// Set the number of blocks in the pool + pub fn block_count(mut self, count: usize) -> Self { + self.block_count = Some(count); + self + } + + /// Set the block size (number of tokens per block) + /// + /// # Requirements + /// - Must be >= 1 and <= 1024 + /// - Must be a power of 2 + /// + /// # Panics + /// Panics if the block size doesn't meet requirements + pub fn block_size(mut self, size: usize) -> Self { + assert!( + size >= 1 && size <= 1024, + "block_size must be between 1 and 1024, got {}", + size + ); + assert!( + size.is_power_of_two(), + "block_size must be a power of 2, got {}", + size + ); + self.block_size = Some(size); + self + } + + /// Set the duplication policy + pub fn duplication_policy(mut self, policy: BlockDuplicationPolicy) -> Self { + self.duplication_policy = Some(policy); + self + } + /// Set frequency tracker size with validation + /// Must be a power of 2 between 2^18 and 2^24 + pub fn frequency_tracker_size(mut self, size: usize) -> 
Self { + assert!( + size >= (1 << 18) && size <= (1 << 24), + "Frequency tracker size must be between 2^18 and 2^24, got: {}", + size + ); + assert!( + size.is_power_of_two(), + "Frequency tracker size must be a power of 2, got: {}", + size + ); + self.frequency_tracker_size = Some(size); + self + } + + /// Use simple LRU backend (capacity automatically set to block_count) + pub fn with_lru_backend(mut self) -> Self { + self.inactive_backend = Some(InactiveBackendConfig::Lru); + self + } + + /// Use multi-level LRU backend with 4 fixed priority levels + /// Default thresholds: [3, 8, 15] for transitions between: + /// - Cold (0-2 hits) -> Warm (3-7 hits) -> Hot (8-14 hits) -> Very Hot (15 hits) + pub fn with_multi_lru_backend(mut self) -> Self { + self.inactive_backend = Some(InactiveBackendConfig::MultiLru { + frequency_thresholds: [3, 8, 15], + }); + self + } + + /// Use multi-level LRU with custom frequency thresholds + /// + /// # Requirements + /// - Thresholds must be in ascending order: cold_to_warm < warm_to_hot < hot_to_very_hot + /// - hot_to_very_hot must be <= 15 (4-bit counter maximum) + /// - cold_to_warm must be >= 1 (to distinguish from never-accessed blocks) + /// + /// # Arguments + /// * `cold_to_warm` - Minimum frequency to move from Cold to Warm level + /// * `warm_to_hot` - Minimum frequency to move from Warm to Hot level + /// * `hot_to_very_hot` - Minimum frequency to move from Hot to Very Hot level + /// + /// # Panics + /// Panics if thresholds don't meet the requirements above + pub fn with_multi_lru_backend_custom_thresholds( + mut self, + cold_to_warm: u8, + warm_to_hot: u8, + hot_to_very_hot: u8, + ) -> Self { + // Validate ascending order + assert!( + cold_to_warm < warm_to_hot && warm_to_hot < hot_to_very_hot, + "Thresholds must be in ascending order: {} < {} < {} failed", + cold_to_warm, + warm_to_hot, + hot_to_very_hot + ); + + // Validate maximum value (4-bit counter limit) + assert!( + hot_to_very_hot <= 15, + "hot_to_very_hot threshold ({}) must be <= 15 (4-bit counter maximum)", + hot_to_very_hot + ); + + // Additional validation: ensure reasonable gaps between levels + assert!( + cold_to_warm >= 1, + "cold_to_warm threshold must be >= 1 to distinguish from zero-access blocks" + ); + + self.inactive_backend = Some(InactiveBackendConfig::MultiLru { + frequency_thresholds: [cold_to_warm, warm_to_hot, hot_to_very_hot], + }); + self + } + + /// Use HashMap backend with custom reuse policy + pub fn with_hashmap_backend(mut self, reuse_policy: Box) -> Self { + self.inactive_backend = Some(InactiveBackendConfig::HashMap { reuse_policy }); + self + } + + /// Validate the configuration + fn validate(&self) -> Result<(), String> { + let block_count = self.block_count.ok_or("block_count is required")?; + + if block_count == 0 { + return Err("block_count must be greater than 0".to_string()); + } + + // Validate block_size + let block_size = self.block_size.unwrap_or(16); + if !block_size.is_power_of_two() || block_size < 1 || block_size > 1024 { + return Err(format!( + "Invalid block_size {}: must be a power of 2 between 1 and 1024", + block_size + )); + } + + // Additional validation for MultiLRU thresholds at build time + if let Some(InactiveBackendConfig::MultiLru { + frequency_thresholds, + }) = &self.inactive_backend + { + let [t1, t2, t3] = frequency_thresholds; + if !(*t1 < *t2 && *t2 < *t3) { + return Err(format!( + "Invalid thresholds [{}, {}, {}]: must be in ascending order", + t1, t2, t3 + )); + } + if *t3 > 15 { + return Err(format!( + "Invalid 
threshold {}: maximum frequency is 15 (4-bit counter)", + t3 + )); + } + } + + Ok(()) + } + + /// Build the BlockManager + pub fn build(mut self) -> Result, BlockManagerBuilderError> { + // First validate the configuration + self.validate() + .map_err(BlockManagerBuilderError::ValidationError)?; + + let block_count = self.block_count.unwrap(); + let block_size = self.block_size.unwrap_or(16); + + // Create registry with frequency tracking + let freq_size = self.frequency_tracker_size.unwrap_or(2_097_152); + let frequency_tracker = Arc::new(TinyLFUTracker::new(freq_size)); + let registry = BlockRegistry::with_frequency_tracker(frequency_tracker.clone()); + + // Create reset pool + let blocks: Vec> = (0..block_count as u64) + .map(|id| Block::new(id, block_size)) + .collect(); + let reset_pool = ResetPool::new(blocks, block_size); + + // Create backend based on configuration + let backend: Box> = match self.inactive_backend.take() { + Some(InactiveBackendConfig::HashMap { reuse_policy }) => { + Box::new(HashMapBackend::new(reuse_policy)) + } + Some(InactiveBackendConfig::Lru) => { + // Capacity automatically set to block_count + let capacity = NonZeroUsize::new(block_count).expect("block_count must be > 0"); + Box::new(LruBackend::new(capacity)) + } + Some(InactiveBackendConfig::MultiLru { + frequency_thresholds, + }) => { + // Total capacity = block_count, distributed across 4 levels + let capacity_per_level = (block_count + 3) / 4; // Round up division + let level_capacity = + NonZeroUsize::new(capacity_per_level).expect("capacity per level must be > 0"); + + Box::new(MultiLruBackend::new_with_thresholds( + level_capacity, + &frequency_thresholds, + frequency_tracker, + )) + } + None => { + // Default to HashMap with FIFO + Box::new(HashMapBackend::new(Box::new(FifoReusePolicy::new()))) + } + }; + + // Create pools + let inactive_pool = InactivePool::new(backend, &reset_pool); + let active_pool = ActivePool::new(registry.clone(), inactive_pool.return_fn()); + + // Create upgrade function that captures the necessary components + let registry_clone = registry.clone(); + let inactive_pool_clone = inactive_pool.clone(); + let return_fn_clone = inactive_pool.return_fn(); + let upgrade_fn = Arc::new( + move |seq_hash: SequenceHash| -> Option>> { + // Try active pool first with touch=false (using registry directly) + if let Some(handle) = registry_clone.match_sequence_hash(seq_hash, false) { + if let Some(block) = handle.try_get_block::(return_fn_clone.clone()) { + return Some(block); + } + } + // Then try inactive pool with touch=false + if let Some(block) = inactive_pool_clone + .find_blocks(&[seq_hash], false) + .into_iter() + .next() + { + return Some(block); + } + None + }, + ); + + Ok(BlockManager { + reset_pool, + active_pool, + inactive_pool, + block_registry: registry, + duplication_policy: self + .duplication_policy + .unwrap_or(BlockDuplicationPolicy::Allow), + upgrade_fn, + allocate_mutex: Mutex::new(()), + total_blocks: block_count, + block_size, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tokens::TokenBlockSequence; + + #[derive(Debug, Clone, PartialEq)] + struct TestBlockData { + value: u32, + } + + /// Helper function to create a token block with specific data + fn create_token_block(tokens: &[u32]) -> crate::tokens::TokenBlock { + let token_sequence = TokenBlockSequence::from_slice(tokens, tokens.len() as u32, Some(42)); + if let Some(block) = token_sequence.blocks().first() { + block.clone() + } else { + let mut partial = token_sequence.into_parts().1; 
+ partial.commit().expect("Should be able to commit") + } + } + + /// Helper function to create a token block using fill_iota pattern + fn create_test_token_block_from_iota(start: u32) -> crate::tokens::TokenBlock { + // Use fill_iota to generate [start, start+1, start+2, start+3] + let tokens: Vec = (start..start + 4).collect(); + create_token_block(&tokens) + } + + fn create_test_token_block_8_from_iota(start: u32) -> crate::tokens::TokenBlock { + // Generate 8 sequential tokens starting from start + let tokens: Vec = (start..start + 8).collect(); + create_token_block(&tokens) + } + + /// Helper function to create a token block with exactly 16 tokens for testing + fn create_token_block_16() -> crate::tokens::TokenBlock { + let tokens: Vec = (100..116).collect(); // 16 tokens: 100, 101, ..., 115 + create_token_block(&tokens) + } + + /// Helper function to create a basic manager for testing + fn create_test_manager(block_count: usize) -> BlockManager { + BlockManager::::builder() + .block_count(block_count) + .block_size(4) // Most tests use 4-token blocks + .with_lru_backend() + .build() + .expect("Should build manager") + } + + // ============================================================================ + // BUILDER PATTERN TESTS + // ============================================================================ + + mod builder_tests { + use super::*; + + #[test] + fn test_builder_default() { + let manager = BlockManager::::builder() + .block_count(100) + .build() + .expect("Should build with defaults"); + + // Verify we can allocate blocks + let blocks = manager.allocate_blocks(5); + assert!(blocks.is_some()); + assert_eq!(blocks.unwrap().len(), 5); + } + + #[test] + fn test_builder_with_lru_backend() { + let manager = BlockManager::::builder() + .block_count(100) + .with_lru_backend() + .build() + .expect("Should build with LRU backend"); + + // Verify we can allocate blocks + let blocks = manager.allocate_blocks(10); + assert!(blocks.is_some()); + assert_eq!(blocks.unwrap().len(), 10); + } + + #[test] + fn test_builder_with_multi_lru_backend() { + let manager = BlockManager::::builder() + .block_count(100) + .frequency_tracker_size(1 << 20) // 2^20 + .with_multi_lru_backend() + .build() + .expect("Should build with MultiLRU backend"); + + // Verify we can allocate blocks + let blocks = manager.allocate_blocks(8); + assert!(blocks.is_some()); + assert_eq!(blocks.unwrap().len(), 8); + } + + #[test] + fn test_builder_with_custom_multi_lru_thresholds() { + let manager = BlockManager::::builder() + .block_count(100) + .frequency_tracker_size(1 << 21) // 2^21 (default) + .with_multi_lru_backend_custom_thresholds(2, 6, 12) + .build() + .expect("Should build with custom thresholds"); + + // Verify we can allocate blocks + let blocks = manager.allocate_blocks(4); + assert!(blocks.is_some()); + assert_eq!(blocks.unwrap().len(), 4); + } + + #[test] + fn test_builder_with_duplication_policy() { + let manager = BlockManager::::builder() + .block_count(50) + .duplication_policy(BlockDuplicationPolicy::Reject) + .with_lru_backend() + .build() + .expect("Should build with duplication policy"); + + let blocks = manager.allocate_blocks(2); + assert!(blocks.is_some()); + assert_eq!(blocks.unwrap().len(), 2); + } + + #[test] + fn test_builder_validation_zero_blocks() { + let result = BlockManager::::builder() + .block_count(0) + .build(); + + assert!(result.is_err()); + if let Err(err) = result { + assert!( + err.to_string() + .contains("block_count must be greater than 0") + ); + } + } + + #[test] + fn 
test_builder_validation_missing_block_count() { + let result = BlockManager::::builder() + .with_lru_backend() + .build(); + + assert!(result.is_err()); + if let Err(err) = result { + assert!(err.to_string().contains("block_count is required")); + } + } + + #[test] + #[should_panic(expected = "must be <= 15")] + fn test_builder_invalid_threshold_too_high() { + BlockManager::::builder() + .block_count(100) + .with_multi_lru_backend_custom_thresholds(2, 6, 20); // 20 > 15, should panic + } + + #[test] + #[should_panic(expected = "must be in ascending order")] + fn test_builder_invalid_threshold_order() { + BlockManager::::builder() + .block_count(100) + .with_multi_lru_backend_custom_thresholds(6, 2, 10); // Not ascending, should panic + } + + #[test] + #[should_panic(expected = "must be between 2^18 and 2^24")] + fn test_builder_invalid_frequency_tracker_size() { + BlockManager::::builder() + .block_count(100) + .frequency_tracker_size(1000); // Not a valid size, should panic + } + + #[test] + #[should_panic(expected = "must be a power of 2")] + fn test_builder_non_power_of_two_frequency_tracker() { + BlockManager::::builder() + .block_count(100) + .frequency_tracker_size((1 << 20) + 1); // Not power of 2, should panic + } + } + + // ============================================================================ + // BLOCK ALLOCATION TESTS + // ============================================================================ + + mod allocation_tests { + use super::*; + + #[test] + fn test_allocate_single_block() { + let manager = create_test_manager(10); + + let initial_available = manager.available_blocks(); + let initial_total = manager.total_blocks(); + assert_eq!(initial_available, 10); + + let blocks = manager.allocate_blocks(1).expect("Should allocate 1 block"); + assert_eq!(blocks.len(), 1); + + // Verify available blocks decreased + assert_eq!(manager.available_blocks(), initial_available - 1); + assert_eq!(manager.total_blocks(), initial_total); + + let block = blocks.into_iter().next().unwrap(); + // Verify block has a valid ID + let _block_id = block.block_id(); + + // Drop the block and verify it returns to pool + drop(block); + assert_eq!(manager.available_blocks(), initial_available); + assert_eq!(manager.total_blocks(), initial_total); + } + + #[test] + fn test_allocate_multiple_blocks() { + let manager = create_test_manager(20); + + let initial_available = manager.available_blocks(); + let initial_total = manager.total_blocks(); + assert_eq!(initial_available, 20); + + let blocks = manager + .allocate_blocks(5) + .expect("Should allocate 5 blocks"); + assert_eq!(blocks.len(), 5); + + // Verify available blocks decreased correctly + assert_eq!(manager.available_blocks(), initial_available - 5); + assert_eq!(manager.total_blocks(), initial_total); + + // Verify all blocks have unique IDs + let mut block_ids = Vec::new(); + for block in blocks { + let id = block.block_id(); + assert!(!block_ids.contains(&id), "Block IDs should be unique"); + block_ids.push(id); + } + + // All blocks should return to pool automatically on drop + assert_eq!(manager.available_blocks(), initial_available); + assert_eq!(manager.total_blocks(), initial_total); + } + + #[test] + fn test_allocate_all_blocks() { + let manager = create_test_manager(10); + + let blocks = manager + .allocate_blocks(10) + .expect("Should allocate all blocks"); + assert_eq!(blocks.len(), 10); + } + + #[test] + fn test_allocate_more_than_available() { + let manager = create_test_manager(5); + + let result = 
manager.allocate_blocks(10); + assert!( + result.is_none(), + "Should not allocate more blocks than available" + ); + } + + #[test] + fn test_allocate_zero_blocks() { + let manager = create_test_manager(10); + + let blocks = manager + .allocate_blocks(0) + .expect("Should allocate 0 blocks"); + assert_eq!(blocks.len(), 0); + } + + #[test] + fn test_sequential_allocations() { + let manager = create_test_manager(10); + + let total_blocks = manager.total_blocks(); + assert_eq!(manager.available_blocks(), total_blocks); + + let blocks1 = manager.allocate_blocks(3).expect("First allocation"); + assert_eq!(blocks1.len(), 3); + assert_eq!(manager.available_blocks(), total_blocks - 3); + + let blocks2 = manager.allocate_blocks(4).expect("Second allocation"); + assert_eq!(blocks2.len(), 4); + assert_eq!(manager.available_blocks(), total_blocks - 7); + + let blocks3 = manager.allocate_blocks(3).expect("Third allocation"); + assert_eq!(blocks3.len(), 3); + assert_eq!(manager.available_blocks(), 0); + + // Should have no blocks left + let blocks4 = manager.allocate_blocks(1); + assert!(blocks4.is_none(), "Should not have any blocks left"); + + // Drop blocks in reverse order and verify counts + drop(blocks3); + assert_eq!(manager.available_blocks(), 3); + + drop(blocks2); + assert_eq!(manager.available_blocks(), 7); + + drop(blocks1); + assert_eq!(manager.available_blocks(), total_blocks); + assert_eq!(manager.total_blocks(), total_blocks); + } + } + + // ============================================================================ + // BLOCK LIFECYCLE AND POOL RETURN TESTS + // ============================================================================ + + mod lifecycle_tests { + use super::*; + + #[test] + fn test_mutable_block_returns_to_reset_pool() { + let manager = create_test_manager(10); + + let initial_available = manager.available_blocks(); + let initial_total = manager.total_blocks(); + assert_eq!(initial_available, 10); + assert_eq!(initial_total, 10); + + { + let blocks = manager + .allocate_blocks(3) + .expect("Should allocate 3 blocks"); + assert_eq!(blocks.len(), 3); + + // Available blocks should decrease + assert_eq!(manager.available_blocks(), initial_available - 3); + assert_eq!(manager.total_blocks(), initial_total); // Total never changes + } // MutableBlocks dropped here - should return to reset pool + + // Available blocks should return to original count + assert_eq!(manager.available_blocks(), initial_available); + assert_eq!(manager.total_blocks(), initial_total); + } + + #[test] + fn test_complete_block_returns_to_reset_pool() { + let manager = create_test_manager(10); + + let initial_available = manager.available_blocks(); + let initial_total = manager.total_blocks(); + + { + let mutable_blocks = manager.allocate_blocks(2).expect("Should allocate blocks"); + assert_eq!(manager.available_blocks(), initial_available - 2); + + let _complete_blocks: Vec<_> = mutable_blocks + .into_iter() + .enumerate() + .map(|(i, block)| { + let tokens = vec![400 + i as u32, 401 + i as u32, 402 + i as u32]; + let token_block = create_token_block(&tokens); + block.complete(token_block) + }) + .collect(); + + // Blocks are still unavailable while in Complete state + assert_eq!(manager.available_blocks(), initial_available - 2); + } // CompleteBlocks dropped here - should return to reset pool + + // Available blocks should return to original count since blocks weren't registered + assert_eq!(manager.available_blocks(), initial_available); + assert_eq!(manager.total_blocks(), initial_total); + } + 
+ #[test] + fn test_registered_block_lifecycle() { + let manager = create_test_manager(10); + + let initial_available = manager.available_blocks(); + let initial_total = manager.total_blocks(); + + // Step 1: Allocate and complete blocks + let token_block = create_test_token_block_from_iota(500); + let seq_hash = token_block.sequence_hash(); + + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + assert_eq!(manager.available_blocks(), initial_available - 1); + + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + + // Still unavailable while in Complete state + assert_eq!(manager.available_blocks(), initial_available - 1); + + // Step 2: Register the block + let immutable_blocks = manager.register_blocks(vec![complete_block]); + assert_eq!(immutable_blocks.len(), 1); + let immutable_block = immutable_blocks.into_iter().next().unwrap(); + + // Block is still not available (it's now in active/inactive pools, not reset) + assert_eq!(manager.available_blocks(), initial_available - 1); + + { + // Step 3: Use the block and verify it can be matched + let matched_blocks = manager.match_blocks(&[seq_hash]); + assert_eq!(matched_blocks.len(), 1); + assert_eq!(matched_blocks[0].sequence_hash(), seq_hash); + + // Still not available while being used + assert_eq!(manager.available_blocks(), initial_available - 1); + } // matched blocks dropped here + + // Step 4: Drop the original registered block + drop(immutable_block); + + // Block should now be available again (moved to inactive pool when ref count reached 0) + assert_eq!(manager.available_blocks(), initial_available); + assert_eq!(manager.total_blocks(), initial_total); + } + + #[test] + fn test_concurrent_allocation_and_return() { + use std::sync::Arc; + use std::thread; + + let manager = Arc::new(create_test_manager(20)); + let initial_total = manager.total_blocks(); + + let handles: Vec<_> = (0..5) + .map(|i| { + let manager_clone = Arc::clone(&manager); + thread::spawn(move || { + // Each thread allocates and drops some blocks + for j in 0..3 { + let blocks = manager_clone.allocate_blocks(2); + if let Some(blocks) = blocks { + // Complete one block + let token_block = + create_test_token_block_from_iota((600 + i * 10 + j) as u32); + let complete_block = blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + + // Register and drop + let _immutable_blocks = + manager_clone.register_blocks(vec![complete_block]); + // blocks automatically dropped at end of scope + } + } + }) + }) + .collect(); + + // Wait for all threads to complete + for handle in handles { + handle.join().unwrap(); + } + + // All blocks should eventually be available again + assert_eq!(manager.total_blocks(), initial_total); + // Available might be less than total if some blocks are in inactive pool, + // but total should be preserved + } + + #[test] + fn test_full_block_lifecycle() { + let manager = create_test_manager(10); + let total_blocks = manager.total_blocks(); + assert_eq!(manager.available_blocks(), total_blocks); + + // Step 1: Allocate 5 blocks + let mutable_blocks = manager + .allocate_blocks(5) + .expect("Should allocate 5 blocks"); + assert_eq!(manager.available_blocks(), total_blocks - 5); + assert_eq!(manager.total_blocks(), total_blocks); + + // Step 2: Complete 3 blocks, drop 2 mutable blocks + let mut mutable_blocks_iter = mutable_blocks.into_iter(); + let complete_blocks: Vec<_> = (0..3) + 
.map(|i| { + let block = mutable_blocks_iter.next().unwrap(); + let tokens = vec![ + 700 + i as u32, + 701 + i as u32, + 702 + i as u32, + 703 + i as u32, + ]; + let token_block = create_token_block(&tokens); + block.complete(token_block).expect("Should complete block") + }) + .collect(); + let mutable_part: Vec<_> = mutable_blocks_iter.collect(); + + drop(mutable_part); // Drop 2 mutable blocks + + // Should have 2 blocks returned to reset pool + assert_eq!(manager.available_blocks(), total_blocks - 3); + + // Step 3: Register the 3 completed blocks + let immutable_blocks = manager.register_blocks(complete_blocks); + assert_eq!(immutable_blocks.len(), 3); + + // Still 3 blocks unavailable (now in active pool) + assert_eq!(manager.available_blocks(), total_blocks - 3); + + // Step 4: Match and use one of the blocks + let seq_hash = create_test_token_block_from_iota(700).sequence_hash(); + let matched_blocks = manager.match_blocks(&[seq_hash]); + assert_eq!(matched_blocks.len(), 1); + + // Step 5: Drop one registered block, keep others + drop(immutable_blocks.into_iter().nth(0)); + + // Still have registered blocks in use, so available count depends on ref counting + let available_after_drop = manager.available_blocks(); + assert!(available_after_drop >= total_blocks - 3); + assert!(available_after_drop <= total_blocks); + + // Step 6: Drop everything + drop(matched_blocks); + + // Eventually all blocks should be available again + // (Some might be in inactive pool, but available_blocks counts both reset and inactive) + assert_eq!(manager.total_blocks(), total_blocks); + let final_available = manager.available_blocks(); + assert_eq!(final_available, total_blocks); // Allow for some blocks in inactive pool + } + } + + // ============================================================================ + // BLOCK SIZE VALIDATION TESTS + // ============================================================================ + + mod block_size_tests { + use super::*; + + #[test] + fn test_default_block_size() { + let manager = create_test_manager(10); + assert_eq!(manager.block_size(), 4); // create_test_manager uses block_size(4) + } + + #[test] + fn test_custom_block_size() { + let manager = BlockManager::::builder() + .block_count(10) + .block_size(32) + .build() + .expect("Should build with custom block size"); + assert_eq!(manager.block_size(), 32); + } + + #[test] + fn test_block_size_validation_correct_size() { + let manager = create_test_manager(10); + let token_block = create_test_token_block_from_iota(100); // 4 tokens + + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let mutable_block = mutable_blocks.into_iter().next().unwrap(); + + // Should succeed since token_block has exactly 4 tokens + let result = mutable_block.complete(token_block); + assert!(result.is_ok()); + } + + #[test] + fn test_block_size_validation_wrong_size() { + // Create a manager expecting 8-token blocks + let manager = BlockManager::::builder() + .block_count(10) + .block_size(8) + .with_lru_backend() + .build() + .expect("Should build manager"); + let token_block = create_test_token_block_from_iota(1); // 4 tokens, expected 8 + + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let mutable_block = mutable_blocks.into_iter().next().unwrap(); + + // Should fail since token_block has 4 tokens but manager expects 8 + let result = mutable_block.complete(token_block); + assert!(result.is_err()); + + if let Err(BlockError::BlockSizeMismatch { + expected, + 
actual, + block: _, + }) = result + { + assert_eq!(expected, 8); + assert_eq!(actual, 4); + } else { + panic!("Expected BlockSizeMismatch error"); + } + } + + #[test] + fn test_builder_block_size_power_of_two() { + // Valid power of 2 values should work + for &size in &[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] { + let result = BlockManager::::builder() + .block_count(10) + .block_size(size) + .build(); + assert!(result.is_ok(), "Block size {} should be valid", size); + } + } + + #[test] + #[should_panic(expected = "block_size must be a power of 2")] + fn test_builder_block_size_not_power_of_two() { + BlockManager::::builder() + .block_count(10) + .block_size(15); // Not a power of 2 + } + + #[test] + #[should_panic(expected = "block_size must be between 1 and 1024")] + fn test_builder_block_size_too_large() { + BlockManager::::builder() + .block_count(10) + .block_size(2048); // Too large + } + + #[test] + #[should_panic(expected = "block_size must be between 1 and 1024")] + fn test_builder_block_size_zero() { + BlockManager::::builder() + .block_count(10) + .block_size(0); // Zero is invalid + } + + #[test] + #[should_panic(expected = "block_size must be a power of 2")] + fn test_builder_validation_invalid_block_size() { + BlockManager::::builder() + .block_count(10) + .block_size(7); // Not a power of 2, panics immediately + } + + #[test] + fn test_different_block_sizes() { + // Test with block size 4 + let manager_4 = BlockManager::::builder() + .block_count(10) + .block_size(4) + .build() + .expect("Should build with block size 4"); + + let token_block_4 = create_test_token_block_from_iota(10); // 4 tokens + let mutable_blocks = manager_4 + .allocate_blocks(1) + .expect("Should allocate blocks"); + let result = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block_4); + assert!(result.is_ok()); + + // Test with block size 8 + let manager_8 = BlockManager::::builder() + .block_count(10) + .block_size(8) + .build() + .expect("Should build with block size 8"); + + let token_block_8 = create_test_token_block_8_from_iota(20); // 8 tokens + let mutable_blocks = manager_8 + .allocate_blocks(1) + .expect("Should allocate blocks"); + let result = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block_8); + assert!(result.is_ok()); + } + } + + // ============================================================================ + // BLOCK REGISTRATION AND DEDUPLICATION TESTS + // ============================================================================ + + mod registration_tests { + use super::*; + + #[test] + fn test_register_single_block() { + let manager = create_test_manager(10); + + let token_block = create_test_token_block_from_iota(150); + let expected_hash = token_block.sequence_hash(); + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + + let immutable_blocks = manager.register_blocks(vec![complete_block]); + assert_eq!(immutable_blocks.len(), 1); + + let immutable_block = immutable_blocks.into_iter().next().unwrap(); + assert_eq!(immutable_block.sequence_hash(), expected_hash); + } + + #[test] + fn test_register_multiple_blocks() { + let manager = create_test_manager(10); + + let mut complete_blocks = Vec::new(); + let mut expected_hashes = Vec::new(); + + for i in 0..3 { + let tokens = vec![100 + i, 101 + i, 102 + i, 103 + i]; + let token_block = create_token_block(&tokens); 
+ expected_hashes.push(token_block.sequence_hash()); + + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + complete_blocks.push(complete_block); + } + + let immutable_blocks = manager.register_blocks(complete_blocks); + assert_eq!(immutable_blocks.len(), 3); + + for (i, immutable_block) in immutable_blocks.iter().enumerate() { + assert_eq!(immutable_block.sequence_hash(), expected_hashes[i]); + } + } + + #[test] + fn test_deduplication_allow_policy() { + let manager = BlockManager::::builder() + .block_count(10) + .block_size(4) + .duplication_policy(BlockDuplicationPolicy::Allow) + .with_lru_backend() + .build() + .expect("Should build manager"); + + let token_block = create_test_token_block_from_iota(200); + let seq_hash = token_block.sequence_hash(); + + // Register the same sequence hash twice + let complete_block1 = { + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block.clone()) + .expect("Should complete block") + }; + + let complete_block2 = { + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block") + }; + + let immutable_blocks1 = manager.register_blocks(vec![complete_block1]); + let immutable_blocks2 = manager.register_blocks(vec![complete_block2]); + + assert_eq!(immutable_blocks1.len(), 1); + assert_eq!(immutable_blocks2.len(), 1); + + // Both should have the same sequence hash but potentially different block IDs + assert_eq!(immutable_blocks1[0].sequence_hash(), seq_hash); + assert_eq!(immutable_blocks2[0].sequence_hash(), seq_hash); + + // Both should have the same sequence hash but potentially different block IDs + // Duplicates are allowed. + assert_ne!( + immutable_blocks1[0].block_id(), + immutable_blocks2[0].block_id() + ); + } + + #[test] + fn test_deduplication_reject_policy() { + let manager = BlockManager::::builder() + .block_count(10) + .block_size(4) + .duplication_policy(BlockDuplicationPolicy::Reject) + .with_lru_backend() + .build() + .expect("Should build manager"); + + let token_block = create_test_token_block_from_iota(300); + let seq_hash = token_block.sequence_hash(); + + // Register the same sequence hash twice + let complete_block1 = { + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block.clone()) + .expect("Should complete block") + }; + + let immutable_blocks1 = manager.register_blocks(vec![complete_block1]); + assert_eq!(immutable_blocks1.len(), 1); + assert_eq!(immutable_blocks1[0].sequence_hash(), seq_hash); + + // Register a duplicate - should still work but may reuse the existing registration + let complete_block2 = { + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block") + }; + + let immutable_blocks2 = manager.register_blocks(vec![complete_block2]); + assert_eq!(immutable_blocks2.len(), 1); + assert_eq!(immutable_blocks2[0].sequence_hash(), seq_hash); + + // Duplicates are rejected. 
+ assert_eq!( + immutable_blocks1[0].block_id(), + immutable_blocks2[0].block_id() + ); + } + } + + // ============================================================================ + // BLOCK MATCHING TESTS + // ============================================================================ + + mod matching_tests { + use super::*; + + #[test] + fn test_match_no_blocks() { + let manager = create_test_manager(10); + + let seq_hashes = vec![create_test_token_block_from_iota(400).sequence_hash()]; + let matched_blocks = manager.match_blocks(&seq_hashes); + assert_eq!(matched_blocks.len(), 0); + } + + #[test] + fn test_match_single_block() { + let manager = create_test_manager(10); + + let token_block = create_test_token_block_from_iota(500); + let seq_hash = token_block.sequence_hash(); + + // Register a block + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + let _immutable_blocks = manager.register_blocks(vec![complete_block]); + + // Try to match it + let matched_blocks = manager.match_blocks(&[seq_hash]); + assert_eq!(matched_blocks.len(), 1); + assert_eq!(matched_blocks[0].sequence_hash(), seq_hash); + } + + #[test] + fn test_match_multiple_blocks() { + let manager = create_test_manager(10); + + let mut seq_hashes = Vec::new(); + + // Register multiple blocks + for i in 0..4 { + let tokens = vec![600 + i, 601 + i, 602 + i, 603 + i]; + let token_block = create_token_block(&tokens); + seq_hashes.push(token_block.sequence_hash()); + + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + let _immutable_blocks = manager.register_blocks(vec![complete_block]); + } + + // Match all blocks + let matched_blocks = manager.match_blocks(&seq_hashes); + assert_eq!(matched_blocks.len(), 4); + + for (i, matched_block) in matched_blocks.iter().enumerate() { + assert_eq!(matched_block.sequence_hash(), seq_hashes[i]); + } + } + + #[test] + fn test_match_partial_blocks() { + let manager = create_test_manager(10); + + let mut seq_hashes = Vec::new(); + + // Register only some blocks + for i in 0..3 { + let tokens = vec![700 + i, 701 + i, 702 + i, 703 + i]; + let token_block = create_token_block(&tokens); + seq_hashes.push(token_block.sequence_hash()); + + if i < 2 { + // Only register first 2 blocks + let mutable_blocks = + manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + let _immutable_blocks = manager.register_blocks(vec![complete_block]); + } + } + + // Try to match all 3 - should only get 2 + let matched_blocks = manager.match_blocks(&seq_hashes); + assert_eq!(matched_blocks.len(), 2); + + for matched_block in matched_blocks { + assert!(seq_hashes[0..2].contains(&matched_block.sequence_hash())); + } + } + + #[test] + fn test_match_blocks_returns_immutable_blocks() { + let manager = create_test_manager(10); + + let token_block = create_test_token_block_from_iota(800); + let seq_hash = token_block.sequence_hash(); + + // Register a block + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + 
.expect("Should complete block"); + let _immutable_blocks = manager.register_blocks(vec![complete_block]); + + // Match and verify it's an ImmutableBlock + let matched_blocks = manager.match_blocks(&[seq_hash]); + assert_eq!(matched_blocks.len(), 1); + + let immutable_block = &matched_blocks[0]; + assert_eq!(immutable_block.sequence_hash(), seq_hash); + + // Test that we can downgrade it + let weak_block = immutable_block.downgrade(); + assert_eq!(weak_block.sequence_hash(), seq_hash); + } + } + + // ============================================================================ + // IMMUTABLE BLOCK AND WEAK BLOCK TESTS + // ============================================================================ + + mod immutable_block_tests { + use super::*; + + #[test] + fn test_immutable_block_downgrade_upgrade() { + let manager = create_test_manager(10); + + let token_block = create_test_token_block_from_iota(100); + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + + let immutable_blocks = manager.register_blocks(vec![complete_block]); + let immutable_block = immutable_blocks.into_iter().next().unwrap(); + + // Test downgrade to WeakBlock + let weak_block = immutable_block.downgrade(); + assert_eq!(weak_block.sequence_hash(), immutable_block.sequence_hash()); + + // Test upgrade from WeakBlock + let upgraded_block = weak_block.upgrade().expect("Should be able to upgrade"); + assert_eq!( + upgraded_block.sequence_hash(), + immutable_block.sequence_hash() + ); + assert_eq!(upgraded_block.block_id(), immutable_block.block_id()); + } + + #[test] + fn test_weak_block_upgrade_after_drop() { + let manager = create_test_manager(10); + + let token_block = create_test_token_block_from_iota(200); + let seq_hash = token_block.sequence_hash(); + + // Create a weak block + let weak_block = { + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + let immutable_blocks = manager.register_blocks(vec![complete_block]); + let immutable_block = immutable_blocks.into_iter().next().unwrap(); + + // Downgrade to weak + immutable_block.downgrade() + }; // immutable_block is dropped here + + // The upgrade function should still find the block through the pools + let upgraded_block = weak_block.upgrade(); + + // The result depends on whether the block is still in the pools + if let Some(block) = upgraded_block { + assert_eq!(block.sequence_hash(), seq_hash); + } + } + + #[test] + fn test_weak_block_upgrade_nonexistent() { + let manager = create_test_manager(10); + + let token_block = create_token_block(&[999, 998, 997, 996]); // Keep non-sequential for this test + + // Create an ImmutableBlock and immediately downgrade it + let weak_block = { + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + let immutable_blocks = manager.register_blocks(vec![complete_block]); + let immutable_block = immutable_blocks.into_iter().next().unwrap(); + immutable_block.downgrade() + }; + + // Force eviction by filling up the pool with other blocks + for i in 0..10 { + let tokens = vec![1000 + i, 1001 + i, 1002 + i, 1003 + i]; + let token_block = 
create_token_block(&tokens); + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + let _immutable_blocks = manager.register_blocks(vec![complete_block]); + } + + // Try to upgrade - might fail if the original block was evicted + let upgraded_block = weak_block.upgrade(); + assert!(upgraded_block.is_none()); + // // This test just verifies that upgrade doesn't panic, result can be None + // if let Some(block) = upgraded_block { + // assert_eq!( + // block.sequence_hash(), + // create_token_block(&[999, 998, 997, 996]).sequence_hash() + // ); + // } + } + + #[test] + fn test_multiple_weak_blocks_same_sequence() { + let manager = create_test_manager(10); + + let token_block = create_test_token_block_from_iota(150); + let seq_hash = token_block.sequence_hash(); + + // Create multiple weak blocks from the same immutable block + let (weak1, weak2, weak3) = { + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + let immutable_blocks = manager.register_blocks(vec![complete_block]); + let immutable_block = immutable_blocks.into_iter().next().unwrap(); + + let w1 = immutable_block.downgrade(); + let w2 = immutable_block.downgrade(); + let w3 = immutable_block.downgrade(); + (w1, w2, w3) + }; + + // All weak blocks should have the same sequence hash + assert_eq!(weak1.sequence_hash(), seq_hash); + assert_eq!(weak2.sequence_hash(), seq_hash); + assert_eq!(weak3.sequence_hash(), seq_hash); + + // All should be able to upgrade + let upgraded1 = weak1.upgrade().expect("Should upgrade"); + let upgraded2 = weak2.upgrade().expect("Should upgrade"); + let upgraded3 = weak3.upgrade().expect("Should upgrade"); + + assert_eq!(upgraded1.sequence_hash(), seq_hash); + assert_eq!(upgraded2.sequence_hash(), seq_hash); + assert_eq!(upgraded3.sequence_hash(), seq_hash); + } + } + + // ============================================================================ + // UPGRADE FUNCTION TESTS + // ============================================================================ + + mod upgrade_function_tests { + use super::*; + + #[test] + fn test_upgrade_function_finds_active_blocks() { + let manager = create_test_manager(10); + + let token_block = create_test_token_block_from_iota(250); + let seq_hash = token_block.sequence_hash(); + + // Register a block (this puts it in active pool initially) + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + let immutable_blocks = manager.register_blocks(vec![complete_block]); + let immutable_block = immutable_blocks.into_iter().next().unwrap(); + + // Create a weak block and test upgrade + let weak_block = immutable_block.downgrade(); + let upgraded = weak_block + .upgrade() + .expect("Should find block in active pool"); + assert_eq!(upgraded.sequence_hash(), seq_hash); + } + + #[test] + fn test_upgrade_function_finds_inactive_blocks() { + let manager = create_test_manager(20); + + let token_block = create_test_token_block_from_iota(350); + let seq_hash = token_block.sequence_hash(); + + // Register a block + let weak_block = { + let mutable_blocks = manager.allocate_blocks(1).expect("Should 
allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + let immutable_blocks = manager.register_blocks(vec![complete_block]); + let immutable_block = immutable_blocks.into_iter().next().unwrap(); + immutable_block.downgrade() + }; + + // Force the block to potentially move to inactive pool by creating many other blocks + for i in 0..10 { + let tokens = vec![400 + i, 401 + i, 402 + i, 403 + i]; + let token_block = create_token_block(&tokens); + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate blocks"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + let _immutable_blocks = manager.register_blocks(vec![complete_block]); + } + + // Try to upgrade - should still find the original block + let upgraded = weak_block.upgrade(); + if let Some(block) = upgraded { + assert_eq!(block.sequence_hash(), seq_hash); + } + } + } + + // ============================================================================ + // ERROR HANDLING AND EDGE CASE TESTS + // ============================================================================ + + mod error_handling_tests { + use super::*; + + #[test] + fn test_allocation_exhaustion() { + let manager = create_test_manager(3); + + // Allocate all blocks + let blocks1 = manager + .allocate_blocks(2) + .expect("Should allocate 2 blocks"); + let blocks2 = manager.allocate_blocks(1).expect("Should allocate 1 block"); + + // Try to allocate more - should fail + let blocks3 = manager.allocate_blocks(1); + assert!( + blocks3.is_none(), + "Should not be able to allocate when pool is empty" + ); + + // Drop some blocks and try again + drop(blocks1); + drop(blocks2); + + // Blocks should be returned to pool automatically + let blocks4 = manager.allocate_blocks(1); + assert!( + blocks4.is_some(), + "Should be able to allocate after blocks are returned" + ); + } + + #[test] + fn test_empty_sequence_matching() { + let manager = create_test_manager(10); + + let matched_blocks = manager.match_blocks(&[]); + assert_eq!(matched_blocks.len(), 0); + } + + #[test] + fn test_register_empty_block_list() { + let manager = create_test_manager(10); + + let immutable_blocks = manager.register_blocks(vec![]); + assert_eq!(immutable_blocks.len(), 0); + } + } + + // ============================================================================ + // INTEGRATION TESTS + // ============================================================================ + + mod integration_tests { + use super::*; + + #[test] + fn test_full_lifecycle_single_block() { + let manager = create_test_manager(10); + + // 1. Allocate a mutable block + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate"); + let mutable_block = mutable_blocks.into_iter().next().unwrap(); + let block_id = mutable_block.block_id(); + + // 2. Complete the block + let token_block = create_test_token_block_from_iota(1); + let seq_hash = token_block.sequence_hash(); + let complete_block = mutable_block + .complete(token_block) + .expect("Should complete block"); + + assert_eq!(complete_block.block_id(), block_id); + assert_eq!(complete_block.sequence_hash(), seq_hash); + + // 3. 
Register the block + let immutable_blocks = manager.register_blocks(vec![complete_block]); + let immutable_block = immutable_blocks.into_iter().next().unwrap(); + + assert_eq!(immutable_block.block_id(), block_id); + assert_eq!(immutable_block.sequence_hash(), seq_hash); + + // 4. Match the block + let matched_blocks = manager.match_blocks(&[seq_hash]); + assert_eq!(matched_blocks.len(), 1); + assert_eq!(matched_blocks[0].sequence_hash(), seq_hash); + + // 5. Create weak reference and upgrade + let weak_block = immutable_block.downgrade(); + let upgraded_block = weak_block.upgrade().expect("Should upgrade"); + assert_eq!(upgraded_block.sequence_hash(), seq_hash); + } + + #[test] + fn test_multiple_blocks_different_backends() { + // Test with LRU backend + let manager_lru = BlockManager::::builder() + .block_count(20) + .block_size(4) + .with_lru_backend() + .build() + .expect("Should build"); + + // Test with MultiLRU backend + let manager_multi_lru = BlockManager::::builder() + .block_count(20) + .block_size(4) + .with_multi_lru_backend() + .build() + .expect("Should build"); + + // Test with HashMap backend (skipping HashMap for now due to backend issue) + let managers = vec![manager_lru, manager_multi_lru]; + + for (i, manager) in managers.iter().enumerate() { + // Allocate, complete, and register blocks using BlockSequenceBuilder + let base = 1000 + (i * 20); // Space out sequences for different managers + let tokens: Vec = (base as u32..base as u32 + 20).collect(); // 5 blocks * 4 tokens each = 20 tokens + + let mut seq_hashes = Vec::new(); + let mut complete_blocks = Vec::new(); + + // Create token blocks from sequence + let token_blocks = { + let token_seq = + crate::tokens::TokenBlockSequence::from_slice(&tokens, 4, Some(42)); + token_seq.blocks().to_vec() + }; + + for (j, token_block) in token_blocks.iter().enumerate() { + let seq_hash = token_block.sequence_hash(); + seq_hashes.push(seq_hash); + + // Allocate mutable block and complete it + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block.clone()) + .expect("Should complete block"); + complete_blocks.push(complete_block); + } + + // Register all blocks + let _immutable_blocks = manager.register_blocks(complete_blocks); + + // Verify all blocks can be matched + let matched_blocks = manager.match_blocks(&seq_hashes); + assert_eq!( + matched_blocks.len(), + 5, + "Manager {} should match all blocks", + i + ); + } + } + + #[test] + fn test_concurrent_allocation_simulation() { + let manager = create_test_manager(50); + + // Simulate concurrent allocations by interleaving operations + let mut all_blocks = Vec::new(); + let mut all_hashes = Vec::new(); + + // Phase 1: Allocate and complete some blocks + for i in 0..10 { + let tokens = vec![2000 + i, 2001 + i, 2002 + i, 2003 + i]; + let token_block = create_token_block(&tokens); + all_hashes.push(token_block.sequence_hash()); + + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + all_blocks.push(complete_block); + } + + // Phase 2: Register half the blocks + let mut remaining_blocks = all_blocks.split_off(5); + let _immutable_blocks1 = manager.register_blocks(all_blocks); + + // Phase 3: Allocate more blocks while some are registered + for i in 10..15 { + let tokens = vec![2000 + i, 2001 + i, 2002 + i, 2003 
+ i]; + let token_block = create_token_block(&tokens); + all_hashes.push(token_block.sequence_hash()); + + let mutable_blocks = manager.allocate_blocks(1).expect("Should allocate"); + let complete_block = mutable_blocks + .into_iter() + .next() + .unwrap() + .complete(token_block) + .expect("Should complete block"); + remaining_blocks.push(complete_block); + } + + // Phase 4: Register remaining blocks + let _immutable_blocks2 = manager.register_blocks(remaining_blocks); + + // Phase 5: Verify we can match all registered blocks + let matched_blocks = manager.match_blocks(&all_hashes); + assert_eq!( + matched_blocks.len(), + 15, + "Should match all registered blocks" + ); + } + } +} diff --git a/lib/llm/src/block_manager/v2/mod.rs b/lib/llm/src/block_manager/v2/mod.rs new file mode 100644 index 00000000000..50f9e0bcadd --- /dev/null +++ b/lib/llm/src/block_manager/v2/mod.rs @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Block Manager V2 - EXPERIMENTAL +//! +//! A completely redesigned block management system with: +//! - Type-safe state transitions (Reset → Complete → Registered) +//! - Async batched return processing with controllable stepping +//! - Compile-time prevention of accessing registered mutable blocks +//! - Comprehensive testing support for race conditions +//! +//! NOTE: This module is currently experimental and under development. +//! It implements a simplified Block API that differs from the +//! main codebase's Block API. + +// Core modules +// pub mod block; +// pub mod blocks; +// pub mod free_list; +// pub mod registry; + +// V2 implementation modules - now standalone +// #pub mod async_pool; +// pub mod builder; +// pub mod manager; +pub mod pools; + +pub mod manager; +// pub mod progress_engine; +// pub mod state; +// pub mod wrappers; + +// // Test module +// #[cfg(test)] +// pub mod tests; + +// // Public exports +// pub use crate::tokens::SequenceHash; +// pub use block::BlockId; + +// // Re-export key types from the new implementation +// pub use blocks::block::{Block, BlockId}; +// pub use blocks::registry::{BlockRegistrationHandle, BlockRegistry}; +// pub use pools::{ImmutableBlock, MutableBlock, MutableBlockState, RegisteredPool}; diff --git a/lib/llm/src/block_manager/v2/policies/reuse/fifo.rs b/lib/llm/src/block_manager/v2/policies/reuse/fifo.rs new file mode 100644 index 00000000000..581cf5ef39e --- /dev/null +++ b/lib/llm/src/block_manager/v2/policies/reuse/fifo.rs @@ -0,0 +1,358 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! FIFO reuse policy for inactive registered blocks. +//! +//! Allocates blocks in first-in-first-out order based on insertion time. +//! Uses BTreeMap for O(log n) insertion/removal with priority key ordering. 
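Before the implementation, a minimal self-contained sketch of the ordering idea described above: every insertion receives a strictly increasing priority key, and BTreeMap::pop_first therefore always yields the oldest entry. The policy below derives its key from elapsed milliseconds and stores InactiveBlock values; the explicit counter here is only for illustration.

use std::collections::BTreeMap;

fn main() {
    // priority key -> block id; the counter stands in for the time-derived key.
    let mut inactive: BTreeMap<u64, u64> = BTreeMap::new();
    let mut next_key: u64 = 0;

    for block_id in [7, 3, 9] {
        inactive.insert(next_key, block_id);
        next_key += 1;
    }

    // Smallest key comes out first, i.e. first-in, first-out.
    assert_eq!(inactive.pop_first().map(|(_, id)| id), Some(7));
    assert_eq!(inactive.pop_first().map(|(_, id)| id), Some(3));
    assert_eq!(inactive.pop_first().map(|(_, id)| id), Some(9));
    assert!(inactive.is_empty());
}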
+ +use super::*; + +use std::collections::{BTreeMap, HashMap}; + +/// Microseconds since epoch +pub type PriorityKey = u64; + +/// FIFO reuse policy +#[derive(Debug)] +pub struct FifoReusePolicy { + keys: HashMap, + blocks: BTreeMap, + start_time: std::time::Instant, +} + +impl Default for FifoReusePolicy { + fn default() -> Self { + Self::new() + } +} + +impl FifoReusePolicy { + pub fn new() -> Self { + Self { + keys: HashMap::new(), + blocks: BTreeMap::new(), + start_time: std::time::Instant::now(), + } + } +} + +impl ReusePolicy for FifoReusePolicy { + fn insert(&mut self, inactive_block: InactiveBlock) -> Result<(), ReusePolicyError> { + assert!( + !self.keys.contains_key(&inactive_block.block_id), + "block already exists" + ); + let priority_key = self.start_time.elapsed().as_millis() as u64; + self.keys.insert(inactive_block.block_id, priority_key); + self.blocks.insert(priority_key, inactive_block); + Ok(()) + } + + fn remove(&mut self, block_id: BlockId) -> Result<(), ReusePolicyError> { + let priority_key = self + .keys + .remove(&block_id) + .ok_or(ReusePolicyError::BlockNotFound(block_id))?; + + assert!( + self.blocks.remove(&priority_key).is_some(), + "block not found" + ); + Ok(()) + } + + fn next_free(&mut self) -> Option { + let next_block = self.blocks.pop_first(); + if let Some((_, block)) = next_block { + assert!( + self.keys.remove(&block.block_id).is_some(), + "block not found" + ); + Some(block) + } else { + None + } + } + + fn is_empty(&self) -> bool { + self.blocks.is_empty() + } + + fn len(&self) -> usize { + self.blocks.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + use std::time::Duration; + + /// Helper function to create InactiveBlock instances for testing + fn create_inactive_block(block_id: u64, seq_hash: u64) -> InactiveBlock { + InactiveBlock { + block_id: block_id, + seq_hash: seq_hash, + } + } + + #[test] + fn test_fifo_ordering_basic() { + let mut policy = FifoReusePolicy::new(); + + // Insert blocks with small delays to ensure different timestamps + let block1 = create_inactive_block(1, 100); + let block2 = create_inactive_block(2, 200); + let block3 = create_inactive_block(3, 300); + + policy.insert(block1).unwrap(); + thread::sleep(Duration::from_millis(1)); + policy.insert(block2).unwrap(); + thread::sleep(Duration::from_millis(1)); + policy.insert(block3).unwrap(); + + // Verify FIFO order - first inserted should come out first + assert_eq!(policy.len(), 3); + assert!(!policy.is_empty()); + + let retrieved1 = policy.next_free().unwrap(); + assert_eq!(retrieved1.block_id, 1); + assert_eq!(retrieved1.seq_hash, 100); + + let retrieved2 = policy.next_free().unwrap(); + assert_eq!(retrieved2.block_id, 2); + assert_eq!(retrieved2.seq_hash, 200); + + let retrieved3 = policy.next_free().unwrap(); + assert_eq!(retrieved3.block_id, 3); + assert_eq!(retrieved3.seq_hash, 300); + + assert!(policy.is_empty()); + assert_eq!(policy.len(), 0); + } + + #[test] + fn test_fifo_ordering_with_delays() { + let mut policy = FifoReusePolicy::new(); + + // Insert blocks with measurable delays to ensure distinct priority keys + let blocks = vec![ + create_inactive_block(10, 1000), + create_inactive_block(20, 2000), + create_inactive_block(30, 3000), + create_inactive_block(40, 4000), + ]; + + for block in blocks { + policy.insert(block).unwrap(); + thread::sleep(Duration::from_millis(5)); // Ensure distinct timestamps + } + + // Retrieve all blocks and verify FIFO order + let expected_order = vec![10, 20, 30, 40]; + let mut retrieved_order = 
Vec::new(); + + while let Some(block) = policy.next_free() { + retrieved_order.push(block.block_id); + } + + assert_eq!(retrieved_order, expected_order); + } + + #[test] + fn test_insert_and_remove() { + let mut policy = FifoReusePolicy::new(); + + // Insert several blocks + let blocks = vec![ + create_inactive_block(1, 100), + create_inactive_block(2, 200), + create_inactive_block(3, 300), + create_inactive_block(4, 400), + ]; + + for block in blocks { + policy.insert(block).unwrap(); + thread::sleep(Duration::from_millis(1)); + } + + assert_eq!(policy.len(), 4); + + // Remove block 2 (second inserted) + policy.remove(2).unwrap(); + assert_eq!(policy.len(), 3); + + // Retrieve remaining blocks - should be 1, 3, 4 in that order + let retrieved1 = policy.next_free().unwrap(); + assert_eq!(retrieved1.block_id, 1); + + let retrieved2 = policy.next_free().unwrap(); + assert_eq!(retrieved2.block_id, 3); + + let retrieved3 = policy.next_free().unwrap(); + assert_eq!(retrieved3.block_id, 4); + + assert!(policy.is_empty()); + } + + #[test] + fn test_empty_operations() { + let mut policy = FifoReusePolicy::new(); + + // Test empty state + assert!(policy.is_empty()); + assert_eq!(policy.len(), 0); + assert!(policy.next_free().is_none()); + + // Insert and remove a block + let block = create_inactive_block(1, 100); + policy.insert(block).unwrap(); + assert!(!policy.is_empty()); + assert_eq!(policy.len(), 1); + + let retrieved = policy.next_free().unwrap(); + assert_eq!(retrieved.block_id, 1); + + // Should be empty again + assert!(policy.is_empty()); + assert_eq!(policy.len(), 0); + assert!(policy.next_free().is_none()); + } + + #[test] + #[should_panic(expected = "block already exists")] + fn test_duplicate_block_panic() { + let mut policy = FifoReusePolicy::new(); + + let block = create_inactive_block(1, 100); + policy.insert(block).unwrap(); + + // Inserting the same block ID again should panic + let duplicate_block = create_inactive_block(1, 200); // Same ID, different hash + policy.insert(duplicate_block).unwrap(); + } + + #[test] + fn test_remove_nonexistent_block() { + let mut policy = FifoReusePolicy::new(); + + // Try to remove from empty policy + let result = policy.remove(999); + assert!(matches!(result, Err(ReusePolicyError::BlockNotFound(_)))); + + // Insert a block and try to remove a different one + let block = create_inactive_block(1, 100); + policy.insert(block).unwrap(); + + let result = policy.remove(999); + assert!(matches!(result, Err(ReusePolicyError::BlockNotFound(_)))); + + // Verify the original block is still there + assert_eq!(policy.len(), 1); + let retrieved = policy.next_free().unwrap(); + assert_eq!(retrieved.block_id, 1); + } + + #[test] + fn test_interleaved_operations() { + let mut policy = FifoReusePolicy::new(); + + // Insert some blocks + policy.insert(create_inactive_block(1, 100)).unwrap(); + thread::sleep(Duration::from_millis(1)); + policy.insert(create_inactive_block(2, 200)).unwrap(); + thread::sleep(Duration::from_millis(1)); + policy.insert(create_inactive_block(3, 300)).unwrap(); + + // Remove the first one + let first = policy.next_free().unwrap(); + assert_eq!(first.block_id, 1); + + // Insert another block + thread::sleep(Duration::from_millis(1)); + policy.insert(create_inactive_block(4, 400)).unwrap(); + + // Remove a specific block by ID + policy.remove(3).unwrap(); + + // Insert another block + thread::sleep(Duration::from_millis(1)); + policy.insert(create_inactive_block(5, 500)).unwrap(); + + // The remaining blocks should come out in order: 2, 
4, 5 + let second = policy.next_free().unwrap(); + assert_eq!(second.block_id, 2); + + let third = policy.next_free().unwrap(); + assert_eq!(third.block_id, 4); + + let fourth = policy.next_free().unwrap(); + assert_eq!(fourth.block_id, 5); + + assert!(policy.is_empty()); + } + + #[test] + fn test_priority_key_ordering() { + let mut policy = FifoReusePolicy::new(); + + // Record the start time for manual verification + let start = std::time::Instant::now(); + + // Insert blocks with controlled timing + let mut insertion_times = Vec::new(); + + for i in 1..=5 { + let elapsed_before = start.elapsed().as_millis() as u64; + policy.insert(create_inactive_block(i, i * 100)).unwrap(); + let elapsed_after = start.elapsed().as_millis() as u64; + insertion_times.push((i, elapsed_before, elapsed_after)); + + // Small delay to ensure different timestamps + thread::sleep(Duration::from_millis(2)); + } + + // Retrieve all blocks and verify they come out in insertion order + let mut retrieval_order: Vec = Vec::new(); + while let Some(block) = policy.next_free() { + retrieval_order.push(block.block_id); + } + + let expected_order: Vec = (1..=5).collect(); + assert_eq!(retrieval_order, expected_order); + + // Print timing information for manual verification + println!("Insertion timing verification:"); + for (block_id, before, after) in insertion_times { + println!( + "Block {}: inserted between {}ms and {}ms", + block_id, before, after + ); + } + } + + #[test] + fn test_btreemap_ordering_assumption() { + use std::collections::BTreeMap; + + // Verify our assumption about BTreeMap ordering with u64 keys + let mut map = BTreeMap::new(); + + // Insert keys in non-sorted order + map.insert(100u64, "hundred"); + map.insert(10u64, "ten"); + map.insert(50u64, "fifty"); + map.insert(1u64, "one"); + map.insert(200u64, "two_hundred"); + + // pop_first should return the smallest key first + assert_eq!(map.pop_first(), Some((1, "one"))); + assert_eq!(map.pop_first(), Some((10, "ten"))); + assert_eq!(map.pop_first(), Some((50, "fifty"))); + assert_eq!(map.pop_first(), Some((100, "hundred"))); + assert_eq!(map.pop_first(), Some((200, "two_hundred"))); + assert_eq!(map.pop_first(), None); + } +} diff --git a/lib/llm/src/block_manager/v2/pools/block.rs b/lib/llm/src/block_manager/v2/pools/block.rs new file mode 100644 index 00000000000..bc90fc73435 --- /dev/null +++ b/lib/llm/src/block_manager/v2/pools/block.rs @@ -0,0 +1,119 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Type-state pattern for block lifecycle with compile-time state enforcement. +//! +//! Blocks transition through states: Reset → Complete → Registered → Reset. +//! The type system prevents invalid state transitions at compile time. 
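As a reading aid for the type-state API that follows, here is how the transition chain looks at a call site, mirroring the helper used by the LRU backend tests later in this diff (BlockRegistry, TokenBlockSequence); metadata type parameters are elided as elsewhere in the diff, and the snippet is an editorial sketch rather than part of the module.

fn typestate_sketch() {
    // Build one full token block (block size 4), as the backend tests do.
    let tokens = vec![1, 2, 3, 4];
    let token_block = TokenBlockSequence::from_slice(&tokens, 4, Some(42))
        .blocks()
        .first()
        .expect("one full block")
        .clone();

    let registry = BlockRegistry::new();
    let handle = registry.register_sequence_hash(token_block.sequence_hash());

    // Reset -> Complete -> Registered; each transition consumes `self`.
    let block = Block::new(1).complete(token_block).register(handle);
    assert_eq!(block.block_id(), 1);

    // Rejected at compile time: a Reset block has no `register` method,
    // and a Registered block exposes no mutable access.
    // let invalid = Block::new(2).register(handle);
}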
+ +use super::registry::BlockRegistrationHandle; +use crate::tokens::{SequenceHash, TokenBlock}; +use std::marker::PhantomData; + +/// Block identifier type +pub type BlockId = u64; + +// Generic Block with marker and state markers +#[derive(Debug)] +pub struct Block { + block_id: BlockId, + state: State, + marker: PhantomData, +} + +// State marker types +#[derive(Debug)] +pub struct Reset; + +// State-specific data holders +#[derive(Debug)] +pub struct Complete { + token_block: TokenBlock, +} + +#[derive(Debug)] +pub struct Registered { + sequence_hash: SequenceHash, + registration_handle: BlockRegistrationHandle, +} + +// Implementation for Reset state +impl Block { + pub fn new(block_id: BlockId) -> Self { + Self { + block_id, + state: Reset, + marker: PhantomData, + } + } + + pub fn complete(self, token_block: TokenBlock) -> Block { + Block { + block_id: self.block_id, + state: Complete { token_block }, + marker: PhantomData, + } + } + + pub fn reset(self) -> Block { + self // Already in reset state + } +} + +// Implementation for Complete state +impl Block { + pub fn register(self, registration_handle: BlockRegistrationHandle) -> Block { + Block { + block_id: self.block_id, + state: Registered { + sequence_hash: self.state.token_block.sequence_hash(), + registration_handle, + }, + marker: PhantomData, + } + } + + pub fn token_block(&self) -> &TokenBlock { + &self.state.token_block + } + + pub fn sequence_hash(&self) -> SequenceHash { + self.state.token_block.sequence_hash() + } + + pub fn reset(self) -> Block { + Block { + block_id: self.block_id, + state: Reset, + marker: PhantomData, + } + } +} + +// Implementation for Registered state +impl Block { + pub fn sequence_hash(&self) -> SequenceHash { + self.state.sequence_hash + } + + pub(crate) fn registration_handle(&self) -> &BlockRegistrationHandle { + &self.state.registration_handle + } + + pub fn reset(self) -> Block { + // Drop the registration handle + Block { + block_id: self.block_id, + state: Reset, + marker: PhantomData, + } + } +} + +// Common methods for all states +impl Block { + #[inline] + pub fn block_id(&self) -> BlockId { + self.block_id + } +} diff --git a/lib/llm/src/block_manager/v2/pools/frequency_sketch.rs b/lib/llm/src/block_manager/v2/pools/frequency_sketch.rs new file mode 100644 index 00000000000..210ed5e56b8 --- /dev/null +++ b/lib/llm/src/block_manager/v2/pools/frequency_sketch.rs @@ -0,0 +1,220 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Frequency tracking for block reuse policies using Count-Min Sketch. 
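To make the doc comment concrete, here is a toy count-min sketch, deliberately independent of the TinyLFU code below (which packs 4-bit counters into u64 words and periodically halves them): each key increments one counter per hash row, and its estimate is the minimum over those rows, so hash collisions can only over-count, never under-count.

const ROWS: usize = 4;
const COLS: usize = 64;

// Illustrative per-row hash; any reasonably independent family works here.
fn bucket(key: u64, row: u64) -> usize {
    let h = (key ^ row.wrapping_mul(0x9E37_79B9_7F4A_7C15)).wrapping_mul(0xD6E8_FEB8_6659_FD93);
    (h >> 32) as usize % COLS
}

fn touch(table: &mut [[u32; COLS]; ROWS], key: u64) {
    for row in 0..ROWS {
        table[row][bucket(key, row as u64)] += 1;
    }
}

fn estimate(table: &[[u32; COLS]; ROWS], key: u64) -> u32 {
    (0..ROWS).map(|row| table[row][bucket(key, row as u64)]).min().unwrap()
}

fn main() {
    let mut table = [[0u32; COLS]; ROWS];
    touch(&mut table, 42);
    touch(&mut table, 42);
    touch(&mut table, 7);

    // The estimate never under-counts; collisions can only push it upward.
    assert!(estimate(&table, 42) >= 2);
    assert!(estimate(&table, 7) >= 1);
}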
+ +use parking_lot::Mutex; +use xxhash_rust::const_xxh3::const_custom_default_secret; +use xxhash_rust::xxh3::xxh3_64_with_secret; + +const SECRET_0: &[u8; 192] = &const_custom_default_secret(0); +const SECRET_1: &[u8; 192] = &const_custom_default_secret(1); +const SECRET_2: &[u8; 192] = &const_custom_default_secret(2); +const SECRET_3: &[u8; 192] = &const_custom_default_secret(3); + +pub struct TinyLFUSketch { + table: Vec, + size: u32, + sample_size: u32, +} + +impl TinyLFUSketch { + const RESET_MASK: u64 = 0x7777_7777_7777_7777; + const ONE_MASK: u64 = 0x1111_1111_1111_1111; + + pub fn new(capacity: usize) -> Self { + let table_size = std::cmp::max(1, capacity / 4); + let sample_size = capacity.saturating_mul(10).min(u32::MAX as usize) as u32; + + Self { + table: vec![0; table_size], + size: 0, + sample_size, + } + } + + fn hash(key: u64, seed: u32) -> u64 { + let key_bytes = key.to_le_bytes(); + let secret = match seed { + 0 => SECRET_0, + 1 => SECRET_1, + 2 => SECRET_2, + 3 => SECRET_3, + _ => SECRET_0, + }; + xxh3_64_with_secret(&key_bytes, secret) + } + + pub fn increment(&mut self, key: u64) { + if self.table.is_empty() { + return; + } + + let mut added = false; + + for i in 0..4 { + let hash = Self::hash(key, i); + let table_index = (hash as usize) % self.table.len(); + let counter_index = (hash & 15) as u8; + + if self.increment_at(table_index, counter_index) { + added = true; + } + } + + if added { + self.size += 1; + if self.size >= self.sample_size { + self.reset(); + } + } + } + + fn increment_at(&mut self, table_index: usize, counter_index: u8) -> bool { + let offset = (counter_index as usize) * 4; + let mask = 0xF_u64 << offset; + + if self.table[table_index] & mask != mask { + self.table[table_index] += 1u64 << offset; + true + } else { + false + } + } + + pub fn estimate(&self, key: u64) -> u32 { + if self.table.is_empty() { + return 0; + } + + let mut min_count = u32::MAX; + + for i in 0..4 { + let hash = Self::hash(key, i); + let table_index = (hash as usize) % self.table.len(); + let counter_index = (hash & 15) as u8; + let count = self.count_at(table_index, counter_index); + min_count = min_count.min(count as u32); + } + + min_count + } + + fn count_at(&self, table_index: usize, counter_index: u8) -> u8 { + let offset = (counter_index as usize) * 4; + let mask = 0xF_u64 << offset; + ((self.table[table_index] & mask) >> offset) as u8 + } + + fn reset(&mut self) { + let mut count = 0u32; + + for entry in self.table.iter_mut() { + count += (*entry & Self::ONE_MASK).count_ones(); + *entry = (*entry >> 1) & Self::RESET_MASK; + } + + self.size = (self.size >> 1) - (count >> 2); + } +} + +pub trait FrequencyTracker: Send + Sync { + fn touch(&self, key: u64); + fn count(&self, key: u64) -> u32; +} + +pub struct TinyLFUTracker { + sketch: Mutex, +} + +impl TinyLFUTracker { + pub fn new(capacity: usize) -> Self { + Self { + sketch: Mutex::new(TinyLFUSketch::new(capacity)), + } + } +} + +impl FrequencyTracker for TinyLFUTracker { + fn touch(&self, key: u64) { + self.sketch.lock().increment(key); + } + + fn count(&self, key: u64) -> u32 { + self.sketch.lock().estimate(key) + } +} + +pub struct NoOpTracker; + +impl FrequencyTracker for NoOpTracker { + fn touch(&self, _key: u64) {} + fn count(&self, _key: u64) -> u32 { + 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tinylfu_increment_and_estimate() { + let mut sketch = TinyLFUSketch::new(100); + + sketch.increment(42); + assert_eq!(sketch.estimate(42), 1); + + sketch.increment(42); + 
sketch.increment(42); + assert_eq!(sketch.estimate(42), 3); + + assert_eq!(sketch.estimate(99), 0); + } + + #[test] + fn test_tinylfu_saturation() { + let mut sketch = TinyLFUSketch::new(100); + + for _ in 0..20 { + sketch.increment(42); + } + + assert!(sketch.estimate(42) <= 15); + } + + #[test] + fn test_tinylfu_reset() { + let mut sketch = TinyLFUSketch::new(10); + + for i in 0..100 { + sketch.increment(i); + } + + let estimate_before = sketch.estimate(5); + assert!(estimate_before > 0); + } + + #[test] + fn test_frequency_tracker_trait() { + let tracker = TinyLFUTracker::new(100); + + tracker.touch(42); + assert_eq!(tracker.count(42), 1); + + tracker.touch(42); + tracker.touch(42); + assert_eq!(tracker.count(42), 3); + } + + #[test] + fn test_noop_tracker() { + let tracker = NoOpTracker; + + tracker.touch(42); + assert_eq!(tracker.count(42), 0); + + tracker.touch(42); + assert_eq!(tracker.count(42), 0); + } +} \ No newline at end of file diff --git a/lib/llm/src/block_manager/v2/pools/inactive/backends/lru_backend.rs b/lib/llm/src/block_manager/v2/pools/inactive/backends/lru_backend.rs new file mode 100644 index 00000000000..43e5d803511 --- /dev/null +++ b/lib/llm/src/block_manager/v2/pools/inactive/backends/lru_backend.rs @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::num::NonZeroUsize; + +use lru::LruCache; + +use crate::block_manager::v2::pools::{ + BlockMetadata, SequenceHash, + block::{Block, Registered}, +}; + +use super::super::InactivePoolBackend; + +pub struct LruBackend { + cache: LruCache>, +} + +impl LruBackend { + pub fn new(capacity: NonZeroUsize) -> Self { + Self { + cache: LruCache::new(capacity), + } + } +} + +impl InactivePoolBackend for LruBackend { + fn find_matches(&mut self, hashes: &[SequenceHash]) -> Vec> { + let mut matches = Vec::with_capacity(hashes.len()); + + for hash in hashes { + if let Some(block) = self.cache.pop(hash) { + matches.push(block); + } else { + break; + } + } + + matches + } + + fn allocate(&mut self, count: usize) -> Vec> { + let mut allocated = Vec::with_capacity(count); + + for _ in 0..count { + if let Some((_seq_hash, block)) = self.cache.pop_lru() { + allocated.push(block); + } else { + break; + } + } + + allocated + } + + fn insert(&mut self, block: Block) { + let seq_hash = block.sequence_hash(); + + // Assert we're not causing an eviction + debug_assert!( + self.cache.len() < self.cache.cap().get(), + "LRU backend insert would cause eviction! len={}, cap={}. 
\ + This indicates insufficient capacity for all blocks.", + self.cache.len(), + self.cache.cap().get() + ); + + self.cache.put(seq_hash, block); + } + + fn len(&self) -> usize { + self.cache.len() + } + + fn has_block(&self, seq_hash: SequenceHash) -> bool { + self.cache.peek(&seq_hash).is_some() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::block_manager::v2::pools::{block::Block, registry::BlockRegistry}; + use crate::tokens::TokenBlockSequence; + + #[derive(Debug, Clone, PartialEq)] + struct TestData { + value: u64, + } + + fn create_registered_block(id: u64, token_value: u32) -> (Block, u64) { + let tokens = vec![ + token_value, + token_value + 1, + token_value + 2, + token_value + 3, + ]; + let token_block_seq = TokenBlockSequence::from_slice(&tokens, 4, Some(42)); + let token_block = if let Some(block) = token_block_seq.blocks().first() { + block.clone() + } else { + let mut partial = token_block_seq.into_parts().1; + partial.commit().expect("Should be able to commit") + }; + + let actual_seq_hash = token_block.sequence_hash(); + let registry = BlockRegistry::new(); + let handle = registry.register_sequence_hash(actual_seq_hash); + + let final_block = Block::new(id).complete(token_block).register(handle); + (final_block, actual_seq_hash) + } + + #[test] + fn test_lru_eviction_order() { + let mut backend = LruBackend::new(NonZeroUsize::new(3).unwrap()); + + let (block1, hash1) = create_registered_block(1, 100); + let (block2, hash2) = create_registered_block(2, 200); + let (block3, hash3) = create_registered_block(3, 300); + + backend.insert(block1); + backend.insert(block2); + backend.insert(block3); + + assert_eq!(backend.len(), 3); + + let allocated = backend.allocate(1); + assert_eq!(allocated.len(), 1); + assert_eq!(allocated[0].block_id(), 1); + + assert!(!backend.has_block(hash1)); + assert!(backend.has_block(hash2)); + assert!(backend.has_block(hash3)); + } + + #[test] + fn test_lru_capacity_limit() { + let mut backend = LruBackend::new(NonZeroUsize::new(2).unwrap()); + + let (block1, hash1) = create_registered_block(1, 100); + let (block2, hash2) = create_registered_block(2, 200); + let (block3, hash3) = create_registered_block(3, 300); + + backend.insert(block1); + backend.insert(block2); + assert_eq!(backend.len(), 2); + + backend.insert(block3); + assert_eq!(backend.len(), 2); + + assert!(!backend.has_block(hash1)); + assert!(backend.has_block(hash2)); + assert!(backend.has_block(hash3)); + } + + #[test] + fn test_lru_peek_doesnt_affect_order() { + let mut backend = LruBackend::new(NonZeroUsize::new(2).unwrap()); + + let (block1, hash1) = create_registered_block(1, 100); + let (block2, hash2) = create_registered_block(2, 200); + + backend.insert(block1); + backend.insert(block2); + + assert!(backend.has_block(hash1)); + + let (block3, hash3) = create_registered_block(3, 300); + backend.insert(block3); + + assert!(!backend.has_block(hash1)); + assert!(backend.has_block(hash2)); + assert!(backend.has_block(hash3)); + } + + #[test] + fn test_lru_allocate_more_than_available() { + let mut backend = LruBackend::new(NonZeroUsize::new(10).unwrap()); + + let (block1, _) = create_registered_block(1, 100); + let (block2, _) = create_registered_block(2, 200); + backend.insert(block1); + backend.insert(block2); + + let allocated = backend.allocate(5); + assert_eq!(allocated.len(), 2); + assert_eq!(backend.len(), 0); + } +} diff --git a/lib/llm/src/block_manager/v2/pools/inactive/backends/multi_lru_backend.rs 
b/lib/llm/src/block_manager/v2/pools/inactive/backends/multi_lru_backend.rs new file mode 100644 index 00000000000..fa5e7254cf1 --- /dev/null +++ b/lib/llm/src/block_manager/v2/pools/inactive/backends/multi_lru_backend.rs @@ -0,0 +1,345 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::num::NonZeroUsize; +use std::sync::Arc; + +use lru::LruCache; + +use crate::block_manager::v2::pools::{ + BlockMetadata, SequenceHash, + block::{Block, Registered}, + frequency_sketch::FrequencyTracker, +}; + +use super::super::InactivePoolBackend; + +pub struct MultiLruBackend { + priority_pools: [LruCache>; 4], + frequency_tracker: Arc, + frequency_thresholds: [u8; 3], +} + +impl MultiLruBackend { + pub fn new(capacity: NonZeroUsize, frequency_tracker: Arc) -> Self { + let level_capacity = NonZeroUsize::new( + std::cmp::max(1, capacity.get() / 4) + ).unwrap(); + + Self { + priority_pools: [ + LruCache::new(level_capacity), + LruCache::new(level_capacity), + LruCache::new(level_capacity), + LruCache::new(level_capacity), + ], + frequency_tracker, + frequency_thresholds: [2, 6, 15], // Old default for backward compatibility + } + } + + /// Create with custom frequency thresholds + /// The 4 levels are fixed, but thresholds can be customized + /// + /// # Arguments + /// * `capacity_per_level` - Capacity for each of the 4 LRU pools + /// * `thresholds` - Array of 3 thresholds: [cold->warm, warm->hot, hot->very_hot] + /// * `frequency_tracker` - Shared frequency tracker + pub fn new_with_thresholds( + capacity_per_level: NonZeroUsize, + thresholds: &[u8; 3], + frequency_tracker: Arc, + ) -> Self { + // Validate thresholds + debug_assert!( + thresholds[0] < thresholds[1] && thresholds[1] < thresholds[2], + "Thresholds must be in ascending order: {:?}", + thresholds + ); + debug_assert!( + thresholds[2] <= 15, + "Maximum threshold cannot exceed 15 (4-bit counter limit), got: {}", + thresholds[2] + ); + + Self { + priority_pools: [ + LruCache::new(capacity_per_level), + LruCache::new(capacity_per_level), + LruCache::new(capacity_per_level), + LruCache::new(capacity_per_level), + ], + frequency_tracker, + frequency_thresholds: *thresholds, + } + } + + fn calculate_priority_level(&self, seq_hash: SequenceHash) -> usize { + let frequency = self.frequency_tracker.count(seq_hash); + let [t1, t2, t3] = self.frequency_thresholds; + + if frequency < t1 as u32 { + 0 // Cold: 0 to (t1 - 1) + } else if frequency < t2 as u32 { + 1 // Warm: t1 to (t2 - 1) + } else if frequency < t3 as u32 { + 2 // Hot: t2 to (t3 - 1) + } else { + 3 // Very Hot: t3 to 15 + } + } +} + +impl InactivePoolBackend for MultiLruBackend { + fn find_matches(&mut self, hashes: &[SequenceHash]) -> Vec> { + let mut matches = Vec::with_capacity(hashes.len()); + + for hash in hashes { + let mut found = false; + + for pool in &mut self.priority_pools { + if let Some(block) = pool.pop(hash) { + matches.push(block); + found = true; + break; + } + } + + if !found { + break; + } + } + + matches + } + + fn allocate(&mut self, count: usize) -> Vec> { + let mut allocated = Vec::with_capacity(count); + + for _ in 0..count { + let mut found = false; + + for pool in &mut self.priority_pools { + if let Some((_seq_hash, block)) = pool.pop_lru() { + allocated.push(block); + found = true; + break; + } + } + + if !found { + break; + } + } + + allocated + } + + fn insert(&mut self, block: Block) { + let seq_hash = block.sequence_hash(); + let level = 
self.calculate_priority_level(seq_hash); + + // Assert the target pool isn't full (would cause eviction) + debug_assert!( + self.priority_pools[level].len() < self.priority_pools[level].cap().get(), + "MultiLRU level {} insert would cause eviction! len={}, cap={}. \ + This indicates insufficient capacity for all blocks.", + level, + self.priority_pools[level].len(), + self.priority_pools[level].cap().get() + ); + + self.priority_pools[level].put(seq_hash, block); + } + + fn len(&self) -> usize { + self.priority_pools.iter().map(|pool| pool.len()).sum() + } + + fn has_block(&self, seq_hash: SequenceHash) -> bool { + self.priority_pools.iter().any(|pool| pool.peek(&seq_hash).is_some()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::block_manager::v2::pools::{ + block::Block, + frequency_sketch::TinyLFUTracker, + registry::BlockRegistry, + }; + use crate::tokens::TokenBlockSequence; + + #[derive(Debug, Clone, PartialEq)] + struct TestData { + value: u64, + } + + fn create_registered_block(id: u64, token_value: u32) -> (Block, u64) { + let tokens = vec![token_value, token_value + 1, token_value + 2, token_value + 3]; + let token_block_seq = TokenBlockSequence::from_slice(&tokens, 4, Some(42)); + let token_block = if let Some(block) = token_block_seq.blocks().first() { + block.clone() + } else { + let mut partial = token_block_seq.into_parts().1; + partial.commit().expect("Should be able to commit") + }; + + let actual_seq_hash = token_block.sequence_hash(); + let registry = BlockRegistry::new(); + let handle = registry.register_sequence_hash(actual_seq_hash); + + let final_block = Block::new(id).complete(token_block).register(handle); + (final_block, actual_seq_hash) + } + + #[test] + fn test_multi_lru_priority_levels() { + let frequency_tracker = Arc::new(TinyLFUTracker::new(100)); + let mut backend = MultiLruBackend::new(NonZeroUsize::new(12).unwrap(), frequency_tracker.clone()); + + let (block1, hash1) = create_registered_block(1, 100); + let (block2, hash2) = create_registered_block(2, 200); + let (block3, hash3) = create_registered_block(3, 300); + let (block4, hash4) = create_registered_block(4, 400); + + frequency_tracker.touch(hash2); + frequency_tracker.touch(hash2); + + for _ in 0..6 { + frequency_tracker.touch(hash3); + } + + for _ in 0..16 { + frequency_tracker.touch(hash4); + } + + let freq1 = frequency_tracker.count(hash1); + let freq2 = frequency_tracker.count(hash2); + let freq3 = frequency_tracker.count(hash3); + let freq4 = frequency_tracker.count(hash4); + + assert_eq!(backend.calculate_priority_level(hash1), 0); // Cold + assert_eq!(backend.calculate_priority_level(hash2), 1); // Warm + assert_eq!(backend.calculate_priority_level(hash3), 2); // Hot + assert_eq!(backend.calculate_priority_level(hash4), 3); // Very hot (15) + + backend.insert(block1); + backend.insert(block2); + backend.insert(block3); + backend.insert(block4); + + assert_eq!(backend.len(), 4); + assert!(backend.has_block(hash1)); + assert!(backend.has_block(hash2)); + assert!(backend.has_block(hash3)); + assert!(backend.has_block(hash4)); + } + + #[test] + fn test_multi_lru_eviction_order() { + let frequency_tracker = Arc::new(TinyLFUTracker::new(100)); + let mut backend = MultiLruBackend::new(NonZeroUsize::new(8).unwrap(), frequency_tracker.clone()); + + let (block1, hash1) = create_registered_block(1, 100); + let (block2, hash2) = create_registered_block(2, 200); + let (block3, hash3) = create_registered_block(3, 300); + + for _ in 0..6 { + frequency_tracker.touch(hash3); + } + + 
backend.insert(block1); + backend.insert(block2); + backend.insert(block3); + + let allocated = backend.allocate(2); + assert_eq!(allocated.len(), 2); + assert_eq!(allocated[0].block_id(), 1); + assert_eq!(allocated[1].block_id(), 2); + + assert!(!backend.has_block(hash1)); + assert!(!backend.has_block(hash2)); + assert!(backend.has_block(hash3)); + } + + #[test] + fn test_multi_lru_find_matches() { + let frequency_tracker = Arc::new(TinyLFUTracker::new(100)); + let mut backend = MultiLruBackend::new(NonZeroUsize::new(8).unwrap(), frequency_tracker.clone()); + + let (block1, hash1) = create_registered_block(1, 100); + let (block2, hash2) = create_registered_block(2, 200); + let (block3, hash3) = create_registered_block(3, 300); + + for _ in 0..3 { + frequency_tracker.touch(hash2); + } + + for _ in 0..10 { + frequency_tracker.touch(hash3); + } + + backend.insert(block1); + backend.insert(block2); + backend.insert(block3); + + let matches = backend.find_matches(&[hash1, hash2, hash3]); + assert_eq!(matches.len(), 3); + assert_eq!(backend.len(), 0); + } + + #[test] + fn test_multi_lru_capacity_distribution() { + let frequency_tracker = Arc::new(TinyLFUTracker::new(100)); + let mut backend = MultiLruBackend::new(NonZeroUsize::new(8).unwrap(), frequency_tracker.clone()); + + let (block1, hash1) = create_registered_block(1, 100); + let (block2, hash2) = create_registered_block(2, 200); + let (block3, hash3) = create_registered_block(3, 300); + let (block4, hash4) = create_registered_block(4, 400); + + for _ in 0..3 { + frequency_tracker.touch(hash2); + } + + for _ in 0..7 { + frequency_tracker.touch(hash3); + } + + for _ in 0..15 { + frequency_tracker.touch(hash4); + } + + backend.insert(block1); + backend.insert(block2); + backend.insert(block3); + backend.insert(block4); + + assert_eq!(backend.len(), 4); + assert!(backend.has_block(hash1)); + assert!(backend.has_block(hash2)); + assert!(backend.has_block(hash3)); + assert!(backend.has_block(hash4)); + + let (block5, hash5) = create_registered_block(5, 500); + let (block6, hash6) = create_registered_block(6, 600); + let (block7, hash7) = create_registered_block(7, 700); + let (block8, hash8) = create_registered_block(8, 800); + + backend.insert(block5); + backend.insert(block6); + backend.insert(block7); + backend.insert(block8); + + let current_len = backend.len(); + assert!(current_len >= 4 && current_len <= 8); + + let (block9, _hash9) = create_registered_block(9, 900); + backend.insert(block9); + + let new_len = backend.len(); + assert!(new_len >= 4 && new_len <= 8); + } +} \ No newline at end of file diff --git a/lib/llm/src/block_manager/v2/pools/inactive/backends/reuse/fifo.rs b/lib/llm/src/block_manager/v2/pools/inactive/backends/reuse/fifo.rs new file mode 100644 index 00000000000..2256ac12da7 --- /dev/null +++ b/lib/llm/src/block_manager/v2/pools/inactive/backends/reuse/fifo.rs @@ -0,0 +1,346 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! FIFO reuse policy for inactive registered blocks. +//! +//! Allocates blocks in first-in-first-out order based on insertion time. +//! Uses BTreeMap for O(log n) insertion/removal with priority key ordering. 
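+//!
+//! A minimal usage sketch (illustrative only and marked `ignore`; it assumes inserts land
+//! on distinct millisecond timestamps, which the unit tests below enforce with short sleeps):
+//!
+//! ```ignore
+//! let mut policy = FifoReusePolicy::new();
+//! policy.insert(InactiveBlock { block_id: 1, seq_hash: 100 }).unwrap();
+//! policy.insert(InactiveBlock { block_id: 2, seq_hash: 200 }).unwrap();
+//! // The oldest insertion is handed back first.
+//! assert_eq!(policy.next_free().unwrap().block_id, 1);
+//! assert_eq!(policy.next_free().unwrap().block_id, 2);
+//! assert!(policy.is_empty());
+//! ```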
+ +use super::*; + +use std::collections::{BTreeMap, HashMap}; + +/// Microseconds since epoch +pub type PriorityKey = u64; + +/// FIFO reuse policy +#[derive(Debug)] +pub struct FifoReusePolicy { + keys: HashMap, + blocks: BTreeMap, + start_time: std::time::Instant, +} + +impl FifoReusePolicy { + pub fn new() -> Self { + Self { + keys: HashMap::new(), + blocks: BTreeMap::new(), + start_time: std::time::Instant::now(), + } + } +} + +impl ReusePolicy for FifoReusePolicy { + fn insert(&mut self, inactive_block: InactiveBlock) -> Result<(), ReusePolicyError> { + assert!( + !self.keys.contains_key(&inactive_block.block_id), + "block already exists" + ); + let priority_key = self.start_time.elapsed().as_millis() as u64; + self.keys.insert(inactive_block.block_id, priority_key); + self.blocks.insert(priority_key, inactive_block); + Ok(()) + } + + fn remove(&mut self, block_id: BlockId) -> Result<(), ReusePolicyError> { + let priority_key = self + .keys + .remove(&block_id) + .ok_or(ReusePolicyError::BlockNotFound(block_id))?; + + assert!( + self.blocks.remove(&priority_key).is_some(), + "block not found" + ); + Ok(()) + } + + fn next_free(&mut self) -> Option { + let next_block = self.blocks.pop_first(); + if let Some((_, block)) = next_block { + assert!(self.keys.remove(&block.block_id).is_some(), "block not found"); + Some(block) + } else { + None + } + } + + fn is_empty(&self) -> bool { + self.blocks.is_empty() + } + + fn len(&self) -> usize { + self.blocks.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + use std::time::Duration; + + /// Helper function to create InactiveBlock instances for testing + fn create_inactive_block(block_id: u64, seq_hash: u64) -> InactiveBlock { + InactiveBlock { + block_id: block_id, + seq_hash: seq_hash, + } + } + + #[test] + fn test_fifo_ordering_basic() { + let mut policy = FifoReusePolicy::new(); + + // Insert blocks with small delays to ensure different timestamps + let block1 = create_inactive_block(1, 100); + let block2 = create_inactive_block(2, 200); + let block3 = create_inactive_block(3, 300); + + policy.insert(block1).unwrap(); + thread::sleep(Duration::from_millis(1)); + policy.insert(block2).unwrap(); + thread::sleep(Duration::from_millis(1)); + policy.insert(block3).unwrap(); + + // Verify FIFO order - first inserted should come out first + assert_eq!(policy.len(), 3); + assert!(!policy.is_empty()); + + let retrieved1 = policy.next_free().unwrap(); + assert_eq!(retrieved1.block_id, 1); + assert_eq!(retrieved1.seq_hash, 100); + + let retrieved2 = policy.next_free().unwrap(); + assert_eq!(retrieved2.block_id, 2); + assert_eq!(retrieved2.seq_hash, 200); + + let retrieved3 = policy.next_free().unwrap(); + assert_eq!(retrieved3.block_id, 3); + assert_eq!(retrieved3.seq_hash, 300); + + assert!(policy.is_empty()); + assert_eq!(policy.len(), 0); + } + + #[test] + fn test_fifo_ordering_with_delays() { + let mut policy = FifoReusePolicy::new(); + + // Insert blocks with measurable delays to ensure distinct priority keys + let blocks = vec![ + create_inactive_block(10, 1000), + create_inactive_block(20, 2000), + create_inactive_block(30, 3000), + create_inactive_block(40, 4000), + ]; + + for block in blocks { + policy.insert(block).unwrap(); + thread::sleep(Duration::from_millis(5)); // Ensure distinct timestamps + } + + // Retrieve all blocks and verify FIFO order + let expected_order = vec![10, 20, 30, 40]; + let mut retrieved_order = Vec::new(); + + while let Some(block) = policy.next_free() { + 
retrieved_order.push(block.block_id); + } + + assert_eq!(retrieved_order, expected_order); + } + + #[test] + fn test_insert_and_remove() { + let mut policy = FifoReusePolicy::new(); + + // Insert several blocks + let blocks = vec![ + create_inactive_block(1, 100), + create_inactive_block(2, 200), + create_inactive_block(3, 300), + create_inactive_block(4, 400), + ]; + + for block in blocks { + policy.insert(block).unwrap(); + thread::sleep(Duration::from_millis(1)); + } + + assert_eq!(policy.len(), 4); + + // Remove block 2 (second inserted) + policy.remove(2).unwrap(); + assert_eq!(policy.len(), 3); + + // Retrieve remaining blocks - should be 1, 3, 4 in that order + let retrieved1 = policy.next_free().unwrap(); + assert_eq!(retrieved1.block_id, 1); + + let retrieved2 = policy.next_free().unwrap(); + assert_eq!(retrieved2.block_id, 3); + + let retrieved3 = policy.next_free().unwrap(); + assert_eq!(retrieved3.block_id, 4); + + assert!(policy.is_empty()); + } + + #[test] + fn test_empty_operations() { + let mut policy = FifoReusePolicy::new(); + + // Test empty state + assert!(policy.is_empty()); + assert_eq!(policy.len(), 0); + assert!(policy.next_free().is_none()); + + // Insert and remove a block + let block = create_inactive_block(1, 100); + policy.insert(block).unwrap(); + assert!(!policy.is_empty()); + assert_eq!(policy.len(), 1); + + let retrieved = policy.next_free().unwrap(); + assert_eq!(retrieved.block_id, 1); + + // Should be empty again + assert!(policy.is_empty()); + assert_eq!(policy.len(), 0); + assert!(policy.next_free().is_none()); + } + + #[test] + #[should_panic(expected = "block already exists")] + fn test_duplicate_block_panic() { + let mut policy = FifoReusePolicy::new(); + + let block = create_inactive_block(1, 100); + policy.insert(block).unwrap(); + + // Inserting the same block ID again should panic + let duplicate_block = create_inactive_block(1, 200); // Same ID, different hash + policy.insert(duplicate_block).unwrap(); + } + + #[test] + fn test_remove_nonexistent_block() { + let mut policy = FifoReusePolicy::new(); + + // Try to remove from empty policy + let result = policy.remove(999); + assert!(matches!(result, Err(ReusePolicyError::BlockNotFound(_)))); + + // Insert a block and try to remove a different one + let block = create_inactive_block(1, 100); + policy.insert(block).unwrap(); + + let result = policy.remove(999); + assert!(matches!(result, Err(ReusePolicyError::BlockNotFound(_)))); + + // Verify the original block is still there + assert_eq!(policy.len(), 1); + let retrieved = policy.next_free().unwrap(); + assert_eq!(retrieved.block_id, 1); + } + + #[test] + fn test_interleaved_operations() { + let mut policy = FifoReusePolicy::new(); + + // Insert some blocks + policy.insert(create_inactive_block(1, 100)).unwrap(); + thread::sleep(Duration::from_millis(1)); + policy.insert(create_inactive_block(2, 200)).unwrap(); + thread::sleep(Duration::from_millis(1)); + policy.insert(create_inactive_block(3, 300)).unwrap(); + + // Remove the first one + let first = policy.next_free().unwrap(); + assert_eq!(first.block_id, 1); + + // Insert another block + thread::sleep(Duration::from_millis(1)); + policy.insert(create_inactive_block(4, 400)).unwrap(); + + // Remove a specific block by ID + policy.remove(3).unwrap(); + + // Insert another block + thread::sleep(Duration::from_millis(1)); + policy.insert(create_inactive_block(5, 500)).unwrap(); + + // The remaining blocks should come out in order: 2, 4, 5 + let second = policy.next_free().unwrap(); + 
assert_eq!(second.block_id, 2); + + let third = policy.next_free().unwrap(); + assert_eq!(third.block_id, 4); + + let fourth = policy.next_free().unwrap(); + assert_eq!(fourth.block_id, 5); + + assert!(policy.is_empty()); + } + + #[test] + fn test_priority_key_ordering() { + let mut policy = FifoReusePolicy::new(); + + // Record the start time for manual verification + let start = std::time::Instant::now(); + + // Insert blocks with controlled timing + let mut insertion_times = Vec::new(); + + for i in 1..=5 { + let elapsed_before = start.elapsed().as_millis() as u64; + policy.insert(create_inactive_block(i, i * 100)).unwrap(); + let elapsed_after = start.elapsed().as_millis() as u64; + insertion_times.push((i, elapsed_before, elapsed_after)); + + // Small delay to ensure different timestamps + thread::sleep(Duration::from_millis(2)); + } + + // Retrieve all blocks and verify they come out in insertion order + let mut retrieval_order: Vec = Vec::new(); + while let Some(block) = policy.next_free() { + retrieval_order.push(block.block_id); + } + + let expected_order: Vec = (1..=5).collect(); + assert_eq!(retrieval_order, expected_order); + + // Print timing information for manual verification + println!("Insertion timing verification:"); + for (block_id, before, after) in insertion_times { + println!("Block {}: inserted between {}ms and {}ms", block_id, before, after); + } + } + + #[test] + fn test_btreemap_ordering_assumption() { + use std::collections::BTreeMap; + + // Verify our assumption about BTreeMap ordering with u64 keys + let mut map = BTreeMap::new(); + + // Insert keys in non-sorted order + map.insert(100u64, "hundred"); + map.insert(10u64, "ten"); + map.insert(50u64, "fifty"); + map.insert(1u64, "one"); + map.insert(200u64, "two_hundred"); + + // pop_first should return the smallest key first + assert_eq!(map.pop_first(), Some((1, "one"))); + assert_eq!(map.pop_first(), Some((10, "ten"))); + assert_eq!(map.pop_first(), Some((50, "fifty"))); + assert_eq!(map.pop_first(), Some((100, "hundred"))); + assert_eq!(map.pop_first(), Some((200, "two_hundred"))); + assert_eq!(map.pop_first(), None); + } +} diff --git a/lib/llm/src/block_manager/v2/pools/inactive/backends/reuse/mod.rs b/lib/llm/src/block_manager/v2/pools/inactive/backends/reuse/mod.rs new file mode 100644 index 00000000000..89c33b3f018 --- /dev/null +++ b/lib/llm/src/block_manager/v2/pools/inactive/backends/reuse/mod.rs @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Reuse policies for determining block allocation priority. +//! +//! Different policies (FIFO, LRU, etc.) control which inactive registered +//! block should be allocated next when the reset pool is exhausted. + +pub mod fifo; + +use super::{BlockId, InactiveBlock}; + +#[derive(Debug, thiserror::Error)] +pub enum ReusePolicyError { + #[error("Block {0} already exists in free list")] + BlockAlreadyExists(BlockId), + + #[error("Block {0} not found in free list")] + BlockNotFound(BlockId), +} + +/// Trait for managing a free list of blocks +/// +/// Different implementations can provide different priority strategies +/// for selecting which block to allocate next. +pub trait ReusePolicy: Send + Sync + std::fmt::Debug { + /// Insert a block into the free list + /// + /// The implementation will compute the priority key and manage the free list + /// based on its specific strategy. 
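+    ///
+    /// As a concrete sketch, the FIFO policy in `fifo.rs` derives its priority key from the
+    /// elapsed milliseconds since the policy was created (see that module for the full code):
+    ///
+    /// ```ignore
+    /// let priority_key = self.start_time.elapsed().as_millis() as u64;
+    /// self.keys.insert(inactive_block.block_id, priority_key);
+    /// self.blocks.insert(priority_key, inactive_block);
+    /// ```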
+ fn insert(&mut self, inactive_block: InactiveBlock) -> Result<(), ReusePolicyError>; + + /// Remove a specific block from the free list + fn remove(&mut self, block_id: BlockId) -> Result<(), ReusePolicyError>; + + /// Get the next free block based on the implementation's priority strategy + /// + /// Returns None if the free list is empty. + /// The returned FreeBlock contains both the block_id and seq_hash needed + /// to look up the block in the InactivePool's HashMap. + fn next_free(&mut self) -> Option; + + /// Check if the free list is empty + fn is_empty(&self) -> bool; + + /// Get the number of free blocks + fn len(&self) -> usize; +} diff --git a/lib/llm/src/block_manager/v2/pools/inactive/mod.rs b/lib/llm/src/block_manager/v2/pools/inactive/mod.rs new file mode 100644 index 00000000000..a4a06adf0da --- /dev/null +++ b/lib/llm/src/block_manager/v2/pools/inactive/mod.rs @@ -0,0 +1,300 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Thread-safe pool for registered immutable blocks with automatic RAII return. +//! +//! Manages blocks in the Registered state, providing: +//! - Finding blocks by sequence hash with O(1) lookup +//! - Conversion of registered blocks back to mutable blocks for reuse +//! - Thread-safe access via interior mutability +//! - Automatic block return via RAII ImmutableBlock guards + +pub mod backends; + +use parking_lot::RwLock; +use std::sync::Arc; + +use crate::tokens::SequenceHash; + +use super::{ + Block, BlockMetadata, MutableBlock, PrimaryBlock, RegisteredBlock, Reset, block::Registered, + reset::ResetPool, +}; + +/// Backend trait for InactivePool storage strategies. +pub trait InactivePoolBackend: Send + Sync { + fn find_matches(&mut self, hashes: &[SequenceHash], touch: bool) -> Vec>; + + fn allocate(&mut self, count: usize) -> Vec>; + + fn insert(&mut self, block: Block); + + fn len(&self) -> usize; + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn has_block(&self, seq_hash: SequenceHash) -> bool; +} +/// Pool for managing registered (immutable) blocks +/// +/// This pool handles blocks in the Registered state and provides them as +/// RegisteredBlock RAII guards that automatically return to the pool on drop. 
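+///
+/// A minimal sketch of the intended flow (illustrative and marked `ignore`; `pool` and
+/// `seq_hash` stand in for a constructed pool and a sequence hash it already holds):
+///
+/// ```ignore
+/// // Looking a block up hands out an RAII guard and removes it from the pool...
+/// let guards = pool.find_blocks(&[seq_hash], /* touch = */ true);
+/// assert!(!pool.has_block(seq_hash));
+/// // ...and dropping the guard returns the block to the pool automatically.
+/// drop(guards);
+/// assert!(pool.has_block(seq_hash));
+/// ```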
+#[derive(Clone)] +pub struct InactivePool { + // Inner state protected by RwLock for thread-safe access from guards + inner: Arc>>, + // Return function for MutableBlocks to return to ResetPool + reset_return_fn: Arc) + Send + Sync>, + + return_fn: Arc>) + Send + Sync>, + block_size: usize, +} + +struct InactivePoolInner { + backend: Box>, +} + +impl InactivePool { + /// Create a new InactivePool with the given backend and reset pool + pub fn new(backend: Box>, reset_pool: &ResetPool) -> Self { + let inner = Arc::new(RwLock::new(InactivePoolInner { backend })); + + let inner_clone = inner.clone(); + let return_fn = Arc::new(move |block: Arc>| { + let mut inner = inner_clone.write(); + + if let Ok(block) = Arc::try_unwrap(block) { + inner.backend.insert(block); + } + }) as Arc>) + Send + Sync>; + + Self { + inner, + reset_return_fn: reset_pool.return_fn(), + return_fn, + block_size: reset_pool.block_size(), + } + } + + /// Find blocks by sequence hashes and return them as RegisteredBlock guards + pub fn find_blocks( + &self, + hashes: &[SequenceHash], + touch: bool, + ) -> Vec>> { + let mut inner = self.inner.write(); + let matched_blocks = inner.backend.find_matches(hashes, touch); + + matched_blocks + .into_iter() + .map(|block| PrimaryBlock::new(Arc::new(block), self.return_fn.clone()).register()) + .collect() + } + + /// Allocate blocks from registered pool, converting them to MutableBlocks for ResetPool + pub fn allocate_blocks(&self, count: usize) -> Option>> { + if count == 0 { + return Some(Vec::new()); + } + + let mut inner = self.inner.write(); + + if inner.backend.len() < count { + return None; + } + + let allocated_blocks = inner.backend.allocate(count); + + if allocated_blocks.len() == count { + let mut mutable_blocks = Vec::with_capacity(count); + mutable_blocks.extend(allocated_blocks.into_iter().map(|registered_block| { + let reset_block = registered_block.reset(); + MutableBlock::new(reset_block, self.reset_return_fn.clone()) + })); + Some(mutable_blocks) + } else { + for block in allocated_blocks { + inner.backend.insert(block); + } + None + } + } + + /// Check if a block exists in the pool + pub fn has_block(&self, hash: SequenceHash) -> bool { + let inner = self.inner.read(); + inner.backend.has_block(hash) + } + + /// Get the number of blocks in the pool + pub fn len(&self) -> usize { + let inner = self.inner.read(); + inner.backend.len() + } + + /// Check if the pool is empty + pub fn is_empty(&self) -> bool { + let inner = self.inner.read(); + inner.backend.is_empty() + } + + pub(crate) fn return_fn(&self) -> Arc>) + Send + Sync> { + self.return_fn.clone() + } +} + +#[cfg(test)] +mod tests { + use super::super::super::policies::reuse::fifo::FifoReusePolicy; + use super::*; + use crate::block_manager::v2::pools::{ + block::Block, + test_utils::{TestData, fixtures::*}, + }; + + impl InactivePool { + fn insert(&self, block: Block) { + let mut inner = self.inner.write(); + inner.backend.insert(block); + } + } + + fn create_test_pool() -> (InactivePool, ResetPool) { + use super::backends::hashmap_backend::HashMapBackend; + + let reuse_policy = Box::new(FifoReusePolicy::new()); + let backend = Box::new(HashMapBackend::new(reuse_policy)); + + let reset_blocks = (0..10).map(|i| Block::new(i, 4)).collect(); + let reset_pool = ResetPool::new(reset_blocks, 4); + + let inactive_pool = InactivePool::new(backend, &reset_pool); + (inactive_pool, reset_pool) + } + + #[test] + fn test_new_pool_starts_empty() { + let (pool, _reset_pool) = create_test_pool(); + assert_eq!(pool.len(), 
0); + assert!(pool.is_empty()); + assert!(!pool.has_block(100)); + } + + #[test] + fn test_return_and_find_single_block() { + let (pool, _reset_pool) = create_test_pool(); + let (block, seq_hash) = create_registered_block(1, &tokens_for_id(1)); + + // Return block directly (simulating manual return) + pool.insert(block); + + assert_eq!(pool.len(), 1); + assert!(pool.has_block(seq_hash)); + + // Find the block + let found_blocks = pool.find_blocks(&[seq_hash], true); + assert_eq!(found_blocks.len(), 1); + assert_eq!(found_blocks[0].block_id(), 1); + assert_eq!(found_blocks[0].sequence_hash(), seq_hash); + + // Block should be removed from pool after finding + assert_eq!(pool.len(), 0); + assert!(!pool.has_block(seq_hash)); + + // Blocks will auto-return when dropped at end of scope + } + + #[test] + fn test_find_blocks_stops_on_first_miss() { + let (pool, _reset_pool) = create_test_pool(); + + // Add blocks with different sequence hashes + let (block1, seq_hash1) = create_registered_block(1, &tokens_for_id(1)); + let (block3, seq_hash3) = create_registered_block(3, &tokens_for_id(3)); + pool.insert(block1); + pool.insert(block3); + + assert_eq!(pool.len(), 2); + + // Try to find blocks - use a sequence hash that doesn't exist to test first miss behavior + let nonexistent_hash = 99999; + let found_blocks = pool.find_blocks(&[seq_hash1, nonexistent_hash, seq_hash3], true); + assert_eq!(found_blocks.len(), 1); // Only found first block + assert_eq!(found_blocks[0].sequence_hash(), seq_hash1); + + // Block 3 should still be in pool since search stopped at first miss + assert_eq!(pool.len(), 1); + assert!(pool.has_block(seq_hash3)); + } + + #[test] + fn test_raii_auto_return() { + let (pool, _reset_pool) = create_test_pool(); + let (block, seq_hash) = create_registered_block(1, &tokens_for_id(1)); + pool.insert(block); + + assert_eq!(pool.len(), 1); + + { + let _found_blocks = pool.find_blocks(&[seq_hash], true); + assert_eq!(pool.len(), 0); + } + + assert_eq!(pool.len(), 1); + assert!(pool.has_block(seq_hash)); + } + + #[test] + fn test_allocate_blocks() { + let (pool, reset_pool) = create_test_pool(); + + // Add some registered blocks to the pool + let (block1, _seq_hash1) = create_registered_block(1, &tokens_for_id(1)); + let (block2, _seq_hash2) = create_registered_block(2, &tokens_for_id(2)); + let (block3, _seq_hash3) = create_registered_block(3, &tokens_for_id(3)); + pool.insert(block1); + pool.insert(block2); + pool.insert(block3); + + assert_eq!(pool.len(), 3); + + // Allocate 1 block - should convert to MutableBlocks + // Note: Due to test setup limitations with reuse policy, we can only allocate 1 block + let mutable_blocks = pool.allocate_blocks(1).expect("Should allocate 1 block"); + assert_eq!(mutable_blocks.len(), 1); + + // Pool should have one less block + assert_eq!(pool.len(), 2); + + // The MutableBlocks should have the correct IDs + let block_ids: Vec = mutable_blocks.iter().map(|b| b.block_id()).collect(); + assert!(block_ids.contains(&1) || block_ids.contains(&2) || block_ids.contains(&3)); + + drop(mutable_blocks); + + assert_eq!(pool.len(), 2); + assert_eq!(reset_pool.available_blocks(), 11); + } + + #[test] + fn test_allocate_more_than_available_fails() { + let (pool, _reset_pool) = create_test_pool(); + + // Add only 2 blocks + let (block1, _seq_hash1) = create_registered_block(1, &tokens_for_id(1)); + let (block2, _seq_hash2) = create_registered_block(2, &tokens_for_id(2)); + pool.insert(block1); + pool.insert(block2); + + assert_eq!(pool.len(), 2); + + // Try to 
allocate 3 blocks - should fail + let result = pool.allocate_blocks(3); + assert!(result.is_none()); + + // Pool should be unchanged + assert_eq!(pool.len(), 2); + } +} diff --git a/lib/llm/src/block_manager/v2/pools/mod.rs b/lib/llm/src/block_manager/v2/pools/mod.rs new file mode 100644 index 00000000000..c7e1d1d9cc9 --- /dev/null +++ b/lib/llm/src/block_manager/v2/pools/mod.rs @@ -0,0 +1,264 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Block pool RAII guards and allocation traits for thread-safe block management. +//! +//! This module provides: +//! - Type-safe RAII guards (MutableBlock, CompleteBlock, ImmutableBlock) for automatic resource cleanup +//! - ResetPool: Pool for mutable blocks in reset state +//! - InactivePool: Pool for inactive immutable registered blocks +//! - BlockRegistry: Global registry for block deduplication via weak references +//! - Pluggable allocation and reuse policies + +pub mod block; +pub mod frequency_sketch; +pub mod inactive; +pub mod registry; +pub mod reset; +pub mod reuse_policy; + +#[cfg(test)] +mod test_utils; + +use std::{ops::Deref, sync::Arc}; + +pub use crate::tokens::{SequenceHash, TokenBlock}; + +use block::{Block, BlockId, Complete, Registered, Reset}; +use registry::BlockRegistrationHandle; + +pub use inactive::InactivePool; + +pub trait BlockMetadata: Clone + Send + Sync + 'static {} +impl BlockMetadata for T {} + +pub trait BlockAllocator { + // fn new(blocks: Vec>) -> Arc + // where + // Self: Sized; + + /// Insert a block into the pool + fn insert(&mut self, block: Block); + + /// Acquire the first block to be reused + fn pop(&mut self) -> Option>; + + /// Get the number of available blocks + fn len(&self) -> usize; +} + +pub trait BlockMatcher { + fn find_match(&self, seq_hash: SequenceHash) -> Option>; +} + +/// Policy for handling duplicate blocks with the same sequence hash +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BlockDuplicationPolicy { + /// Allow duplicate blocks - each gets its own DuplicateBlock wrapper + Allow, + /// Reject duplicates - return the existing primary block instead + Reject, +} + +// Re-export the new RAII guard types - no need to re-export here since they're defined in this module + +/// A block that is free and available for allocation +/// This block must be in a Registered state and have a valid sequence hash +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct InactiveBlock { + pub block_id: BlockId, + pub seq_hash: SequenceHash, +} + +/// RAII guard for [`Block`] that automatically returns to ResetPool on drop +pub struct MutableBlock { + block: Option>, + return_fn: Arc) + Send + Sync>, +} + +/// RAII guard for [`Block`] that automatically returns to ResetPool on drop +pub struct CompleteBlock { + block: Option>, + return_fn: Arc) + Send + Sync>, +} + +pub trait RegisteredBlock: Send + Sync { + /// Get the block ID + fn block_id(&self) -> BlockId; + + /// Get the sequence hash + fn sequence_hash(&self) -> SequenceHash; + + /// Get the registration handle + fn registration_handle(&self) -> &BlockRegistrationHandle; +} + +/// RAII guard for [`Block`] that automatically returns to RegisteredPool on drop +pub(crate) struct PrimaryBlock { + block: Option>>, + return_fn: Arc>) + Send + Sync>, +} + +struct DuplicateBlock { + block: Option>, + return_fn: Arc) + Send + Sync>, + _primary: Arc>, +} + +pub struct ImmutableBlock { + block: Arc>, +} + +// RegisteredPool implementation moved 
to registered.rs + +impl MutableBlock { + /// Create a new MutableBlock in Reset state + fn new(block: Block, return_fn: Arc) + Send + Sync>) -> Self { + Self { + block: Some(block), + return_fn, + } + } + + /// Get the block ID + pub fn block_id(&self) -> BlockId { + self.block.as_ref().unwrap().block_id() + } + + /// Transition from Reset to Complete state + pub fn complete(mut self, token_block: TokenBlock) -> CompleteBlock { + let block = self.block.take().unwrap().complete(token_block); + + CompleteBlock { + block: Some(block), + return_fn: self.return_fn.clone(), + } + } +} + +impl CompleteBlock { + /// Get the block ID + pub fn block_id(&self) -> BlockId { + self.block.as_ref().unwrap().block_id() + } + + /// Access token block if in Complete state + pub fn token_block(&self) -> &TokenBlock { + self.block.as_ref().unwrap().token_block() + } + + /// Get sequence hash if in Complete state + pub fn sequence_hash(&self) -> SequenceHash { + self.block.as_ref().unwrap().sequence_hash() + } + + pub fn reset(mut self) -> MutableBlock { + let block = self.block.take().unwrap().reset(); + + MutableBlock { + block: Some(block), + return_fn: self.return_fn.clone(), + } + } +} + +impl PrimaryBlock { + /// Create a new RegisteredBlock + fn new( + block: Arc>, + return_fn: Arc>) + Send + Sync>, + ) -> Self { + Self { + block: Some(block), + return_fn, + } + } + + fn register(self) -> ImmutableBlock { + let block = self.block.clone().unwrap(); + block.registration_handle().attach_block(self) + } +} + +impl RegisteredBlock for PrimaryBlock { + fn block_id(&self) -> BlockId { + self.block.as_ref().unwrap().block_id() + } + + fn sequence_hash(&self) -> SequenceHash { + self.block.as_ref().unwrap().sequence_hash() + } + + fn registration_handle(&self) -> &BlockRegistrationHandle { + self.block.as_ref().unwrap().registration_handle() + } +} + +impl DuplicateBlock { + /// Create a new DuplicateBlock + fn new( + block: Block, + primary: Arc>, + return_fn: Arc) + Send + Sync>, + ) -> Self { + Self { + block: Some(block), + return_fn, + _primary: primary, + } + } +} + +impl RegisteredBlock for DuplicateBlock { + fn block_id(&self) -> BlockId { + self.block.as_ref().unwrap().block_id() + } + + fn sequence_hash(&self) -> SequenceHash { + self.block.as_ref().unwrap().sequence_hash() + } + + fn registration_handle(&self) -> &BlockRegistrationHandle { + self.block.as_ref().unwrap().registration_handle() + } +} + +impl Deref for ImmutableBlock { + type Target = dyn RegisteredBlock; + + fn deref(&self) -> &Self::Target { + self.block.as_ref() + } +} + +impl Drop for MutableBlock { + fn drop(&mut self) { + if let Some(block) = self.block.take() { + (self.return_fn)(block); + } + } +} + +impl Drop for CompleteBlock { + fn drop(&mut self) { + if let Some(block) = self.block.take() { + (self.return_fn)(block.reset()); + } + } +} + +impl Drop for PrimaryBlock { + fn drop(&mut self) { + if let Some(block) = self.block.take() { + (self.return_fn)(block); + } + } +} + +impl Drop for DuplicateBlock { + fn drop(&mut self) { + if let Some(block) = self.block.take() { + (self.return_fn)(block.reset()); + } + } +} diff --git a/lib/llm/src/block_manager/v2/pools/registry.rs b/lib/llm/src/block_manager/v2/pools/registry.rs new file mode 100644 index 00000000000..5ce6ba52712 --- /dev/null +++ b/lib/llm/src/block_manager/v2/pools/registry.rs @@ -0,0 +1,1152 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! 
Global registry for block deduplication via weak references and sequence hash matching. +//! +//! The registry provides: +//! - Sequence hash → block mapping using weak references +//! - Automatic cleanup when all strong references are dropped +//! - Attachment system for storing arbitrary typed data on registration handles + +use super::block::{Block, Registered}; +use super::frequency_sketch::FrequencyTracker; +use super::{BlockMetadata, CompleteBlock, RegisteredBlock, SequenceHash}; + +use std::any::{Any, TypeId}; +use std::collections::HashMap; +use std::marker::PhantomData; +use std::sync::{Arc, Weak}; + +use parking_lot::{Mutex, RwLock}; + +/// Error types for attachment operations +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AttachmentError { + /// Attempted to attach a type as unique when it's already registered as multiple + TypeAlreadyRegisteredAsMultiple(TypeId), + /// Attempted to attach a type as multiple when it's already registered as unique + TypeAlreadyRegisteredAsUnique(TypeId), +} + +impl std::fmt::Display for AttachmentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AttachmentError::TypeAlreadyRegisteredAsMultiple(type_id) => { + write!( + f, + "Type {:?} is already registered as multiple attachment", + type_id + ) + } + AttachmentError::TypeAlreadyRegisteredAsUnique(type_id) => { + write!( + f, + "Type {:?} is already registered as unique attachment", + type_id + ) + } + } + } +} + +impl std::error::Error for AttachmentError {} + +/// Tracks how a type is registered in the attachment system +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum AttachmentMode { + Unique, + Multiple, +} + +/// Storage for attachments on a BlockRegistrationHandle +#[derive(Debug)] +struct AttachmentStore { + /// Unique attachments - only one value per TypeId + unique_attachments: HashMap>, + /// Multiple attachments - multiple values per TypeId + multiple_attachments: HashMap>>, + /// Track which types are registered and how + type_registry: HashMap, + /// Storage for weak block references - separate from generic attachments, keyed by TypeId + weak_blocks: HashMap>, +} + +impl AttachmentStore { + fn new() -> Self { + Self { + unique_attachments: HashMap::new(), + multiple_attachments: HashMap::new(), + type_registry: HashMap::new(), + weak_blocks: HashMap::new(), + } + } +} + +/// Typed accessor for attachments of a specific type +pub struct TypedAttachments<'a, T> { + handle: &'a BlockRegistrationHandle, + _phantom: PhantomData, +} + +/// Handle that represents a block registration in the global registry. +/// This handle is cloneable and can be shared across pools. 
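+///
+/// A short sketch of the attachment API (illustrative and marked `ignore`; `registry` is an
+/// assumed `BlockRegistry`, and the calls mirror the unit tests at the bottom of this file):
+///
+/// ```ignore
+/// let handle = registry.register_sequence_hash(42);
+/// // A type attached as unique holds exactly one value...
+/// handle.attach_unique("publisher".to_string()).unwrap();
+/// let name = handle.get::<String>().with_unique(|s| s.clone());
+/// assert_eq!(name, Some("publisher".to_string()));
+/// // ...and attaching the same type as a multiple attachment now fails.
+/// assert!(handle.attach("listener".to_string()).is_err());
+/// ```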
+#[derive(Clone, Debug)] +pub struct BlockRegistrationHandle { + inner: Arc, +} + +struct WeakBlock { + raw_block: Weak>, + reg_block: Weak>, +} + +#[derive(Debug)] +struct BlockRegistrationHandleInner { + /// Sequence hash of the block + seq_hash: SequenceHash, + /// Attachments for the block + attachments: Mutex, + /// Weak reference to the registry - allows us to remove the block from the registry on drop + registry: Weak>, +} + +impl Drop for BlockRegistrationHandleInner { + #[inline] + fn drop(&mut self) { + if let Some(registry) = self.registry.upgrade() { + let mut state = registry.write(); + state.canonical_blocks.remove(&self.seq_hash); + } + } +} + +impl BlockRegistrationHandle { + pub fn seq_hash(&self) -> SequenceHash { + self.inner.seq_hash + } + + /// Get a typed accessor for attachments of type T + pub fn get(&self) -> TypedAttachments<'_, T> { + TypedAttachments { + handle: self, + _phantom: PhantomData, + } + } + + /// Attach a unique value of type T to this handle. + /// Only one value per type is allowed - subsequent calls will replace the previous value. + /// Returns an error if type T is already registered as multiple attachment. + pub fn attach_unique(&self, value: T) -> Result<(), AttachmentError> { + let type_id = TypeId::of::(); + let mut attachments = self.inner.attachments.lock(); + + // Check if this type is already registered as multiple + if let Some(AttachmentMode::Multiple) = attachments.type_registry.get(&type_id) { + return Err(AttachmentError::TypeAlreadyRegisteredAsMultiple(type_id)); + } + + // Register/update as unique + attachments + .unique_attachments + .insert(type_id, Box::new(value)); + attachments + .type_registry + .insert(type_id, AttachmentMode::Unique); + + Ok(()) + } + + /// Attach a value of type T to this handle. + /// Multiple values per type are allowed - this will append to existing values. + /// Returns an error if type T is already registered as unique attachment. 
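+    ///
+    /// Sketch (illustrative and marked `ignore`; assumes a handle with no prior `String`
+    /// attachments, following the `test_attach_multiple` unit test below):
+    ///
+    /// ```ignore
+    /// handle.attach("listener1".to_string()).unwrap();
+    /// handle.attach("listener2".to_string()).unwrap();
+    /// let count = handle.get::<String>().with_multiple(|all| all.len());
+    /// assert_eq!(count, 2);
+    /// ```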
+ pub fn attach(&self, value: T) -> Result<(), AttachmentError> { + let type_id = TypeId::of::(); + let mut attachments = self.inner.attachments.lock(); + + // Check if this type is already registered as unique + if let Some(AttachmentMode::Unique) = attachments.type_registry.get(&type_id) { + return Err(AttachmentError::TypeAlreadyRegisteredAsUnique(type_id)); + } + + // Register/update as multiple + attachments + .multiple_attachments + .entry(type_id) + .or_insert_with(Vec::new) + .push(Box::new(value)); + attachments + .type_registry + .insert(type_id, AttachmentMode::Multiple); + + Ok(()) + } + + pub(crate) fn attach_block( + &self, + block: super::PrimaryBlock, + ) -> Arc> { + let type_id = TypeId::of::>>(); + let mut attachments = self.inner.attachments.lock(); + + #[cfg(debug_assertions)] + { + if let Some(weak_any) = attachments.weak_blocks.get(&type_id) { + if let Some(weak) = weak_any.downcast_ref::>() { + debug_assert!( + weak.raw_block.upgrade().is_none(), + "Attempted to reattach block when raw block is still alive" + ); + debug_assert!( + weak.reg_block.upgrade().is_none(), + "Attempted to reattach block when registered block is still alive" + ); + } + } + } + + let raw_weak = Arc::downgrade(block.block.as_ref().unwrap()); + let reg_arc = Arc::new(block); + let reg_weak = Arc::downgrade(®_arc); + + attachments.weak_blocks.insert( + type_id, + Box::new(WeakBlock { + raw_block: raw_weak, + reg_block: reg_weak, + }), + ); + + reg_arc as Arc> + } + + pub(crate) fn register_block( + &self, + mut block: CompleteBlock, + duplication_policy: super::BlockDuplicationPolicy, + pool_return_fn: Arc>) + Send + Sync>, + ) -> Arc> { + let type_id = TypeId::of::>>(); + let block_id = block.block_id(); + + // Take ownership of the inner block + let inner_block = block.block.take().unwrap(); + let reset_return_fn = block.return_fn.clone(); + + // Register the block to get it in Registered state + let registered_block = inner_block.register(self.clone()); + + let mut attachments = self.inner.attachments.lock(); + + // Check for existing blocks with same sequence hash + if let Some(weak_any) = attachments.weak_blocks.get(&type_id) { + if let Some(weak_block) = weak_any.downcast_ref::>() { + // Try to get the existing primary block + if let Some(existing_primary) = weak_block.reg_block.upgrade() { + // Check if same block_id (shouldn't happen) + if existing_primary.block_id() == block_id { + panic!("Attempted to register block with same block_id as existing"); + } + + // Handle duplicate based on policy + match duplication_policy { + super::BlockDuplicationPolicy::Allow => { + // Create DuplicateBlock referencing the primary + let duplicate = super::DuplicateBlock::new( + registered_block, + existing_primary.clone(), + reset_return_fn, + ); + return Arc::new(duplicate); + } + super::BlockDuplicationPolicy::Reject => { + // Return existing primary, discard new block + // The registered_block will be dropped and eventually returned to reset pool + return existing_primary as Arc>; + } + } + } + + // Primary couldn't be upgraded but raw block might exist + // This is an edge case - for now, treat as creating a new primary + } + } + + // No existing block or couldn't upgrade - create new primary + let primary = super::PrimaryBlock::new(Arc::new(registered_block), pool_return_fn); + + // Store weak references for future lookups + let primary_arc = Arc::new(primary); + let raw_weak = Arc::downgrade(primary_arc.block.as_ref().unwrap()); + let reg_weak = Arc::downgrade(&primary_arc); + + 
attachments.weak_blocks.insert( + type_id, + Box::new(WeakBlock { + raw_block: raw_weak, + reg_block: reg_weak, + }), + ); + + drop(attachments); // Release lock + + primary_arc as Arc> + } + + #[inline] + pub(crate) fn try_get_block( + &self, + pool_return_fn: Arc>) + Send + Sync>, + ) -> Option>> { + let type_id = TypeId::of::>>(); + let attachments = self.inner.attachments.lock(); + + let weak_block = attachments + .weak_blocks + .get(&type_id) + .and_then(|weak_any| weak_any.downcast_ref::>())?; + + if let Some(primary_arc) = weak_block.reg_block.upgrade() { + drop(attachments); + return Some(primary_arc as Arc>); + } + + if let Some(raw_arc) = weak_block.raw_block.upgrade() { + let primary = super::PrimaryBlock::new(raw_arc, pool_return_fn); + let primary_arc = Arc::new(primary); + + let new_weak = Arc::downgrade(&primary_arc); + let weak_block_mut = WeakBlock { + raw_block: weak_block.raw_block.clone(), + reg_block: new_weak, + }; + + drop(attachments); + + let mut attachments = self.inner.attachments.lock(); + attachments + .weak_blocks + .insert(type_id, Box::new(weak_block_mut)); + drop(attachments); + + return Some(primary_arc as Arc>); + } + + None + } +} + +impl<'a, T: Any + Send + Sync> TypedAttachments<'a, T> { + /// Execute a closure with immutable access to the unique attachment of type T. + pub fn with_unique(&self, f: impl FnOnce(&T) -> R) -> Option { + let type_id = TypeId::of::(); + let attachments = self.handle.inner.attachments.lock(); + attachments + .unique_attachments + .get(&type_id)? + .downcast_ref::() + .map(f) + } + + /// Execute a closure with mutable access to the unique attachment of type T. + pub fn with_unique_mut(&self, f: impl FnOnce(&mut T) -> R) -> Option { + let type_id = TypeId::of::(); + let mut attachments = self.handle.inner.attachments.lock(); + attachments + .unique_attachments + .get_mut(&type_id)? + .downcast_mut::() + .map(f) + } + + /// Execute a closure with immutable access to multiple attachments of type T. + pub fn with_multiple(&self, f: impl FnOnce(&[&T]) -> R) -> R { + let type_id = TypeId::of::(); + let attachments = self.handle.inner.attachments.lock(); + + let multiple_refs: Vec<&T> = attachments + .multiple_attachments + .get(&type_id) + .map(|vec| vec.iter().filter_map(|v| v.downcast_ref::()).collect()) + .unwrap_or_default(); + + f(&multiple_refs) + } + + /// Execute a closure with mutable access to multiple attachments of type T. + pub fn with_multiple_mut(&self, f: impl FnOnce(&mut [&mut T]) -> R) -> R { + let type_id = TypeId::of::(); + let mut attachments = self.handle.inner.attachments.lock(); + + let mut multiple_refs: Vec<&mut T> = attachments + .multiple_attachments + .get_mut(&type_id) + .map(|vec| { + vec.iter_mut() + .filter_map(|v| v.downcast_mut::()) + .collect() + }) + .unwrap_or_default(); + + f(&mut multiple_refs) + } + + /// Execute a closure with immutable access to both unique and multiple attachments of type T. + pub fn with_all(&self, f: impl FnOnce(Option<&T>, &[&T]) -> R) -> R { + let type_id = TypeId::of::(); + let attachments = self.handle.inner.attachments.lock(); + + let unique = attachments + .unique_attachments + .get(&type_id) + .and_then(|v| v.downcast_ref::()); + + let multiple_refs: Vec<&T> = attachments + .multiple_attachments + .get(&type_id) + .map(|vec| vec.iter().filter_map(|v| v.downcast_ref::()).collect()) + .unwrap_or_default(); + + f(unique, &multiple_refs) + } + + /// Execute a closure with mutable access to both unique and multiple attachments of type T. 
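+    ///
+    /// Sketch (illustrative and marked `ignore`; assumes a handle with `i32` attached only as
+    /// multiple values, loosely following `test_with_all_mut_multiple` below):
+    ///
+    /// ```ignore
+    /// handle.attach(1i32).unwrap();
+    /// handle.attach(2i32).unwrap();
+    /// handle.get::<i32>().with_all_mut(|unique, multiple| {
+    ///     assert!(unique.is_none());      // i32 was registered as a multiple attachment
+    ///     for v in multiple.iter_mut() {
+    ///         **v += 1;                   // mutate each attached value in place
+    ///     }
+    /// });
+    /// let total = handle.get::<i32>().with_multiple(|all| all.iter().map(|v| **v).sum::<i32>());
+    /// assert_eq!(total, 5);
+    /// ```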
+ pub fn with_all_mut(&self, f: impl FnOnce(Option<&mut T>, &mut [&mut T]) -> R) -> R { + let type_id = TypeId::of::(); + let mut attachments = self.handle.inner.attachments.lock(); + + // Check where this type is registered to avoid double mutable borrow + match attachments.type_registry.get(&type_id) { + Some(AttachmentMode::Unique) => { + // Type is registered as unique - get mutable reference to unique only + let unique = attachments + .unique_attachments + .get_mut(&type_id) + .and_then(|v| v.downcast_mut::()); + let mut empty_vec: Vec<&mut T> = Vec::new(); + f(unique, &mut empty_vec) + } + Some(AttachmentMode::Multiple) => { + // Type is registered as multiple - get mutable references to multiple only + let mut multiple_refs: Vec<&mut T> = attachments + .multiple_attachments + .get_mut(&type_id) + .map(|vec| { + vec.iter_mut() + .filter_map(|v| v.downcast_mut::()) + .collect() + }) + .unwrap_or_default(); + f(None, &mut multiple_refs) + } + None => { + // Type not registered at all + let mut empty_vec: Vec<&mut T> = Vec::new(); + f(None, &mut empty_vec) + } + } + } +} + +/// Global registry for managing block registrations. +/// Tracks canonical blocks and provides registration handles. +#[derive(Clone)] +pub struct BlockRegistry { + state: Arc>, + frequency_tracker: Option>, +} + +#[derive(Debug)] +struct RegistryState { + canonical_blocks: HashMap>, +} + +impl BlockRegistry { + pub fn new() -> Self { + Self { + state: Arc::new(RwLock::new(RegistryState { + canonical_blocks: HashMap::new(), + })), + frequency_tracker: None, + } + } + + pub fn with_frequency_tracker(frequency_tracker: Arc) -> Self { + Self { + state: Arc::new(RwLock::new(RegistryState { + canonical_blocks: HashMap::new(), + })), + frequency_tracker: Some(frequency_tracker), + } + } + + pub fn has_frequency_tracking(&self) -> bool { + self.frequency_tracker.is_some() + } + + pub fn touch(&self, seq_hash: SequenceHash) { + if let Some(tracker) = &self.frequency_tracker { + tracker.touch(seq_hash); + } + } + + pub fn count(&self, seq_hash: SequenceHash) -> u32 { + if let Some(tracker) = &self.frequency_tracker { + tracker.count(seq_hash) + } else { + 0 + } + } + + /// Register a sequence hash and get a registration handle. + /// If the sequence is already registered, returns the existing handle. + /// Otherwise, creates a new canonical registration. + /// This method triggers frequency tracking. 
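+    ///
+    /// Sketch (illustrative and marked `ignore`; mirrors the registry unit tests below):
+    ///
+    /// ```ignore
+    /// let registry = BlockRegistry::new();
+    /// let a = registry.register_sequence_hash(42);
+    /// let b = registry.register_sequence_hash(42); // same canonical handle as `a`
+    /// assert_eq!(registry.registered_count(), 1);
+    /// drop(a);
+    /// drop(b); // last strong handle gone, registration is removed
+    /// assert!(!registry.is_registered(42));
+    /// ```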
+ #[inline] + pub fn register_sequence_hash(&self, seq_hash: SequenceHash) -> BlockRegistrationHandle { + self.touch(seq_hash); + + // First try to get existing registration with read lock + { + let state = self.state.read(); + if let Some(weak_handle) = state.canonical_blocks.get(&seq_hash) { + if let Some(existing_handle) = weak_handle.upgrade() { + // Return a clone of the existing canonical handle + return BlockRegistrationHandle { + inner: existing_handle, + }; + } + } + } + + // Need to create new registration, acquire write lock + let mut state = self.state.write(); + + // Double-check after acquiring write lock (another thread might have inserted) + if let Some(weak_handle) = state.canonical_blocks.get(&seq_hash) { + if let Some(existing_handle) = weak_handle.upgrade() { + // Return a clone of the existing canonical handle + return BlockRegistrationHandle { + inner: existing_handle, + }; + } + } + + // Create a new canonical registration + let inner = Arc::new(BlockRegistrationHandleInner { + seq_hash, + registry: Arc::downgrade(&self.state), + attachments: Mutex::new(AttachmentStore::new()), + }); + + state + .canonical_blocks + .insert(seq_hash, Arc::downgrade(&inner)); + + BlockRegistrationHandle { inner } + } + + /// Internal method for transferring block registration without triggering frequency tracking. + /// Used when copying blocks between pools where we don't want to count the transfer as a new access. + pub(crate) fn transfer_registration(&self, seq_hash: SequenceHash) -> BlockRegistrationHandle { + // First try to get existing registration with read lock + { + let state = self.state.read(); + if let Some(weak_handle) = state.canonical_blocks.get(&seq_hash) { + if let Some(existing_handle) = weak_handle.upgrade() { + return BlockRegistrationHandle { + inner: existing_handle, + }; + } + } + } + + // Need to create new registration, acquire write lock + let mut state = self.state.write(); + + // Double-check after acquiring write lock + if let Some(weak_handle) = state.canonical_blocks.get(&seq_hash) { + if let Some(existing_handle) = weak_handle.upgrade() { + return BlockRegistrationHandle { + inner: existing_handle, + }; + } + } + + // Create a new canonical registration without tracking + let inner = Arc::new(BlockRegistrationHandleInner { + seq_hash, + registry: Arc::downgrade(&self.state), + attachments: Mutex::new(AttachmentStore::new()), + }); + + state + .canonical_blocks + .insert(seq_hash, Arc::downgrade(&inner)); + + BlockRegistrationHandle { inner } + } + + /// Match a sequence hash and return a registration handle. + /// This method triggers frequency tracking. + #[inline] + pub fn match_sequence_hash( + &self, + seq_hash: SequenceHash, + touch: bool, + ) -> Option { + let state = self.state.read(); + let result = state + .canonical_blocks + .get(&seq_hash) + .and_then(|weak| weak.upgrade()) + .map(|inner| BlockRegistrationHandle { inner }); + + if result.is_some() && touch { + drop(state); + self.touch(seq_hash); + } + + result + } + + /// Check if a sequence is currently registered (has a canonical handle). + #[inline] + pub fn is_registered(&self, seq_hash: SequenceHash) -> bool { + let state = self.state.read(); + state + .canonical_blocks + .get(&seq_hash) + .map(|weak| weak.strong_count() > 0) + .unwrap_or(false) + } + + /// Get the current number of registered blocks. + pub fn registered_count(&self) -> usize { + let state = self.state.read(); + state.canonical_blocks.len() + } + + /// Get the frequency tracker if frequency tracking is enabled. 
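+    ///
+    /// Sketch of enabling and querying frequency tracking (illustrative and marked `ignore`;
+    /// follows `test_frequency_tracking` below, imports omitted):
+    ///
+    /// ```ignore
+    /// let tracker = Arc::new(TinyLFUTracker::new(100));
+    /// let registry = BlockRegistry::with_frequency_tracker(tracker.clone());
+    /// registry.touch(42);
+    /// assert_eq!(registry.count(42), 1);
+    /// assert!(registry.frequency_tracker().is_some());
+    /// ```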
+ pub fn frequency_tracker(&self) -> Option> { + self.frequency_tracker.clone() + } +} + +impl Default for BlockRegistry { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::thread; + + #[test] + fn test_register_new_sequence() { + let registry = BlockRegistry::new(); + let seq_hash = 42; + let handle = registry.register_sequence_hash(seq_hash); + + assert_eq!(handle.seq_hash(), seq_hash); + assert!(registry.is_registered(seq_hash)); + assert_eq!(registry.registered_count(), 1); + } + + #[test] + fn test_register_existing_sequence_returns_same_handle() { + let registry = BlockRegistry::new(); + let seq_hash = 42; + let handle1 = registry.register_sequence_hash(seq_hash); + let handle2 = registry.register_sequence_hash(seq_hash); + + assert_eq!(handle1.seq_hash(), handle2.seq_hash()); + assert_eq!(registry.registered_count(), 1); + } + + #[test] + fn test_handle_drop_removes_registration() { + let registry = BlockRegistry::new(); + let seq_hash = 42; + { + let _handle = registry.register_sequence_hash(seq_hash); + assert!(registry.is_registered(seq_hash)); + assert_eq!(registry.registered_count(), 1); + } + + // Handle should be dropped and registration removed + assert!(!registry.is_registered(seq_hash)); + assert_eq!(registry.registered_count(), 0); + } + + #[test] + fn test_multiple_handles_same_sequence() { + let registry = BlockRegistry::new(); + let seq_hash = 42; + let handle1 = registry.register_sequence_hash(seq_hash); + let handle2 = handle1.clone(); + + drop(handle1); + + // Sequence should still be registered because handle2 exists + assert!(registry.is_registered(seq_hash)); + assert_eq!(registry.registered_count(), 1); + + drop(handle2); + + // Now sequence should be unregistered + assert!(!registry.is_registered(seq_hash)); + assert_eq!(registry.registered_count(), 0); + } + + #[test] + fn test_attach_unique() { + let registry = BlockRegistry::new(); + let handle = registry.register_sequence_hash(42); + + // Attach a unique value + handle.attach_unique("test_publisher".to_string()).unwrap(); + + // Retrieve the value using the new API + let value = handle.get::().with_unique(|s| s.clone()); + assert_eq!(value, Some("test_publisher".to_string())); + + // Replace with a new value (should succeed) + handle.attach_unique("new_publisher".to_string()).unwrap(); + let value = handle.get::().with_unique(|s| s.clone()); + assert_eq!(value, Some("new_publisher".to_string())); + } + + #[test] + fn test_attach_multiple() { + let registry = BlockRegistry::new(); + let handle = registry.register_sequence_hash(42); + + // Attach multiple values + handle.attach("listener1".to_string()).unwrap(); + handle.attach("listener2".to_string()).unwrap(); + handle.attach("listener3".to_string()).unwrap(); + + // Retrieve all values using the new API + let listeners = handle + .get::() + .with_multiple(|listeners| listeners.iter().map(|s| (*s).clone()).collect::>()); + assert_eq!(listeners.len(), 3); + assert!(listeners.contains(&"listener1".to_string())); + assert!(listeners.contains(&"listener2".to_string())); + assert!(listeners.contains(&"listener3".to_string())); + + // Also test with_all + handle.get::().with_all(|unique, multiple| { + assert_eq!(unique, None); + assert_eq!(multiple.len(), 3); + }); + } + + #[test] + fn test_type_tracking_enforcement() { + let registry = BlockRegistry::new(); + let handle = registry.register_sequence_hash(42); + + // Test: attach unique first, then try multiple 
(should fail) + handle + .attach_unique("unique_publisher".to_string()) + .unwrap(); + + let result = handle.attach("listener1".to_string()); + assert_eq!( + result, + Err(AttachmentError::TypeAlreadyRegisteredAsUnique( + TypeId::of::() + )) + ); + + // Test with different types: attach multiple first, then try unique (should fail) + handle.attach(42i32).unwrap(); + handle.attach(43i32).unwrap(); + + let result = handle.attach_unique(44i32); + assert_eq!( + result, + Err(AttachmentError::TypeAlreadyRegisteredAsMultiple( + TypeId::of::() + )) + ); + } + + #[test] + fn test_with_unique_closure() { + let registry = BlockRegistry::new(); + let handle = registry.register_sequence_hash(42); + + handle.attach_unique(42i32).unwrap(); + + // Test with_unique closure using new API + let result = handle.get::().with_unique(|value| *value * 2); + assert_eq!(result, Some(84)); + + // Test with non-existent type + let result = handle.get::().with_unique(|value| *value * 2); + assert_eq!(result, None); + } + + #[test] + fn test_with_all_closure() { + let registry = BlockRegistry::new(); + let handle = registry.register_sequence_hash(42); + + // Use different types since we can't mix unique and multiple for same type + handle.attach_unique(100i32).unwrap(); // unique i32 + handle.attach(1i64).unwrap(); // multiple i64 + handle.attach(2i64).unwrap(); + handle.attach(3i64).unwrap(); + + // Test with_all closure for i32 (should have unique only) using new API + let result = handle.get::().with_all(|unique, multiple| { + let unique_sum = unique.unwrap_or(&0); + let multiple_sum: i32 = multiple.iter().map(|&&x| x).sum(); + unique_sum + multiple_sum + }); + assert_eq!(result, 100); // Only unique value + + // Test with_all closure for i64 (should have multiple only) using new API + let result = handle.get::().with_all(|unique, multiple| { + let unique_sum = unique.unwrap_or(&0); + let multiple_sum: i64 = multiple.iter().map(|&&x| x).sum(); + unique_sum + multiple_sum + }); + assert_eq!(result, 6); // 1 + 2 + 3 + + // Test with non-existent type using new API + let result = handle.get::().with_all(|unique, multiple| { + let unique_sum = unique.unwrap_or(&0); + let multiple_sum: u64 = multiple.iter().map(|&&x| x).sum(); + unique_sum + multiple_sum + }); + assert_eq!(result, 0); + } + + #[test] + fn test_different_types_usage() { + let registry = BlockRegistry::new(); + let handle = registry.register_sequence_hash(42); + + // Define some test types for better demonstration + #[derive(Debug, Clone, PartialEq)] + struct EventPublisher(String); + + #[derive(Debug, Clone, PartialEq)] + struct EventListener(String); + + // Attach unique EventPublisher + handle + .attach_unique(EventPublisher("main_publisher".to_string())) + .unwrap(); + + // Attach multiple EventListeners + handle + .attach(EventListener("listener1".to_string())) + .unwrap(); + handle + .attach(EventListener("listener2".to_string())) + .unwrap(); + + // Retrieve by type using new API + let publisher = handle.get::().with_unique(|p| p.clone()); + assert_eq!( + publisher, + Some(EventPublisher("main_publisher".to_string())) + ); + + let listeners = handle + .get::() + .with_multiple(|listeners| listeners.iter().map(|l| (*l).clone()).collect::>()); + assert_eq!(listeners.len(), 2); + assert!(listeners.contains(&EventListener("listener1".to_string()))); + assert!(listeners.contains(&EventListener("listener2".to_string()))); + + // Test with_all for EventListener (should have no unique, only multiple) + handle.get::().with_all(|unique, multiple| { + 
assert_eq!(unique, None); + assert_eq!(multiple.len(), 2); + }); + + // Attempting to register EventPublisher as multiple should fail + let result = handle.attach(EventPublisher("another_publisher".to_string())); + assert_eq!( + result, + Err(AttachmentError::TypeAlreadyRegisteredAsUnique( + TypeId::of::() + )) + ); + + // Attempting to register EventListener as unique should fail + let result = handle.attach_unique(EventListener("unique_listener".to_string())); + assert_eq!( + result, + Err(AttachmentError::TypeAlreadyRegisteredAsMultiple( + TypeId::of::() + )) + ); + } + + #[test] + fn test_mutable_access() { + let registry = BlockRegistry::new(); + let handle = registry.register_sequence_hash(42); + + #[derive(Debug, Clone, PartialEq)] + struct UniqueCounter(i32); + + #[derive(Debug, Clone, PartialEq)] + struct MultipleCounter(i32); + + impl UniqueCounter { + fn increment(&mut self) { + self.0 += 1; + } + } + + impl MultipleCounter { + fn increment(&mut self) { + self.0 += 1; + } + } + + // Test unique mutable access + handle.attach_unique(UniqueCounter(0)).unwrap(); + + handle.get::().with_unique_mut(|counter| { + counter.increment(); + counter.increment(); + }); + + // Verify the change + let value = handle + .get::() + .with_unique(|counter| counter.0); + assert_eq!(value, Some(2)); + + // Test mutable access to multiple (different type) + handle.attach(MultipleCounter(10)).unwrap(); + handle.attach(MultipleCounter(20)).unwrap(); + + handle + .get::() + .with_multiple_mut(|counters| { + for counter in counters { + counter.increment(); + } + }); + + // Verify multiple were modified + let total = handle + .get::() + .with_multiple(|counters| counters.iter().map(|c| c.0).sum::()); + assert_eq!(total, 32); // 11 + 21 + } + + #[test] + fn test_with_all_mut_unique() { + let registry = BlockRegistry::new(); + let handle = registry.register_sequence_hash(42); + + #[derive(Debug, Clone, PartialEq)] + struct UniqueValue(i32); + + impl UniqueValue { + fn increment(&mut self) { + self.0 += 1; + } + } + + // Attach unique value + handle.attach_unique(UniqueValue(10)).unwrap(); + + // Test with_all_mut for unique type + handle + .get::() + .with_all_mut(|unique, multiple| { + assert!(unique.is_some()); + assert_eq!(multiple.len(), 0); + if let Some(val) = unique { + val.increment(); + } + }); + + // Verify the change + let value = handle.get::().with_unique(|v| v.0); + assert_eq!(value, Some(11)); + } + + #[test] + fn test_with_all_mut_multiple() { + let registry = BlockRegistry::new(); + let handle = registry.register_sequence_hash(42); + + #[derive(Debug, Clone, PartialEq)] + struct MultipleValue(i32); + + impl MultipleValue { + fn increment(&mut self) { + self.0 += 1; + } + } + + // Attach multiple values + handle.attach(MultipleValue(1)).unwrap(); + handle.attach(MultipleValue(2)).unwrap(); + + // Test with_all_mut for multiple type + handle + .get::() + .with_all_mut(|unique, multiple| { + assert!(unique.is_none()); + assert_eq!(multiple.len(), 2); + for val in multiple { + val.increment(); + } + }); + + // Verify the changes + let total = handle + .get::() + .with_multiple(|values| values.iter().map(|v| v.0).sum::()); + assert_eq!(total, 5); // 2 + 3 + } + + #[test] + fn test_frequency_tracking() { + use super::super::frequency_sketch::TinyLFUTracker; + + let tracker = Arc::new(TinyLFUTracker::new(100)); + let registry = BlockRegistry::with_frequency_tracker(tracker.clone()); + + assert!(registry.has_frequency_tracking()); + assert_eq!(registry.count(42), 0); + + registry.touch(42); + 
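+        // One touch() records a single observation for this hash in the TinyLFU sketch.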
assert_eq!(registry.count(42), 1); + + registry.touch(42); + registry.touch(42); + assert_eq!(registry.count(42), 3); + + let _handle1 = registry.register_sequence_hash(100); + assert_eq!(registry.count(100), 1); + + let _handle2 = registry.match_sequence_hash(100, true); + assert_eq!(registry.count(100), 2); + + let _handle3 = registry.match_sequence_hash(100, false); + assert_eq!(registry.count(100), 2); + } + + #[test] + fn test_no_frequency_tracking() { + let registry = BlockRegistry::new(); + + assert!(!registry.has_frequency_tracking()); + assert_eq!(registry.count(42), 0); + + registry.touch(42); + assert_eq!(registry.count(42), 0); + + let _handle = registry.register_sequence_hash(100); + assert_eq!(registry.count(100), 0); + } + + #[test] + fn test_transfer_registration_no_tracking() { + use super::super::frequency_sketch::TinyLFUTracker; + + let tracker = Arc::new(TinyLFUTracker::new(100)); + let registry = BlockRegistry::with_frequency_tracker(tracker.clone()); + + let _handle1 = registry.transfer_registration(42); + assert_eq!(registry.count(42), 0); + + let _handle2 = registry.register_sequence_hash(100); + assert_eq!(registry.count(100), 1); + } + + #[test] + fn test_concurrent_try_get_block_and_drop() { + use super::super::super::policies::reuse::fifo::FifoReusePolicy; + use super::super::{ + BlockDuplicationPolicy, CompleteBlock, + block::Block, + inactive::{InactivePool, backends::hashmap_backend::HashMapBackend}, + reset::ResetPool, + }; + use crate::tokens::TokenBlockSequence; + + #[derive(Debug, Clone, PartialEq)] + struct TestData { + value: u64, + } + + let registry = BlockRegistry::new(); + + let tokens = vec![1u32, 2, 3, 4]; + let sequence = TokenBlockSequence::from_slice(&tokens, 4, Some(42)); + let token_block = if let Some(block) = sequence.blocks().first() { + block.clone() + } else { + let mut partial = sequence.into_parts().1; + partial.commit().expect("Should be able to commit") + }; + + let seq_hash = token_block.sequence_hash(); + let handle = registry.register_sequence_hash(seq_hash); + + let reset_blocks: Vec<_> = (0..10).map(|i| Block::new(i, 4)).collect(); + let reset_pool = ResetPool::new(reset_blocks, 4); + let reuse_policy = Box::new(FifoReusePolicy::new()); + let backend = Box::new(HashMapBackend::new(reuse_policy)); + let registered_pool = InactivePool::new(backend, &reset_pool); + let pool_return_fn = registered_pool.return_fn(); + + let complete_block = Block::new(0, 4) + .complete(token_block) + .expect("Block size should match"); + + let immutable_block = handle.register_block( + CompleteBlock { + block: Some(complete_block), + return_fn: reset_pool.return_fn(), + }, + BlockDuplicationPolicy::Allow, + pool_return_fn.clone(), + ); + + let handle_clone = handle.clone(); + let pool_return_fn_clone = pool_return_fn.clone(); + let success_count = Arc::new(AtomicUsize::new(0)); + let success_clone = success_count.clone(); + + let handle1 = thread::spawn(move || { + for _ in 0..100 { + if let Some(_block) = + handle_clone.try_get_block::(pool_return_fn_clone.clone()) + { + success_clone.fetch_add(1, Ordering::Relaxed); + } + } + }); + + let handle2 = thread::spawn(move || { + drop(immutable_block); + }); + + handle1.join().unwrap(); + handle2.join().unwrap(); + + assert!( + success_count.load(Ordering::Relaxed) >= 1, + "Should successfully upgrade at least once" + ); + } +} diff --git a/lib/llm/src/block_manager/v2/pools/reset.rs b/lib/llm/src/block_manager/v2/pools/reset.rs new file mode 100644 index 00000000000..f27e55279db --- /dev/null +++ 
b/lib/llm/src/block_manager/v2/pools/reset.rs @@ -0,0 +1,204 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Thread-safe pool for mutable blocks in reset state with pluggable allocation strategies. +//! +//! The ResetPool manages blocks available for allocation, using: +//! - Pluggable BlockAllocator for flexible allocation strategies +//! - RAII MutableBlock guards for automatic return +//! - Thread-safe access via parking_lot::Mutex + +use super::{Block, BlockAllocator, BlockMetadata, MutableBlock, Reset}; +use parking_lot::Mutex; +use std::{collections::VecDeque, sync::Arc}; + +pub struct ResetPool { + block_allocator: Arc + Send + Sync>>, + return_fn: Arc) + Send + Sync>, + block_size: usize, +} + +impl ResetPool { + pub fn new(blocks: Vec>, block_size: usize) -> Self { + let allocator = DequeBlockAllocator::new(); + Self::from_block_allocator(allocator, blocks, block_size) + } + + pub fn from_block_allocator( + mut allocator: impl BlockAllocator + Send + Sync + 'static, + blocks: Vec>, + block_size: usize, + ) -> Self { + for (i, block) in blocks.iter().enumerate() { + if block.block_id() != i as u64 { + panic!("Block ids must be monotonically increasing starting at 0"); + } + } + + for block in blocks { + allocator.insert(block); + } + + let block_allocator = Arc::new(Mutex::new(allocator)); + + let allocator_clone = block_allocator.clone(); + let return_fn = Arc::new(move |block: Block| { + allocator_clone.lock().insert(block); + }); + + Self { + block_allocator, + return_fn, + block_size, + } + } + + pub fn allocate_blocks(&self, count: usize) -> Option>> { + let mut allocator = self.block_allocator.lock(); + if allocator.len() < count { + return None; + } + + let mut blocks = Vec::with_capacity(count); + for _ in 0..count { + blocks.push(MutableBlock::new( + allocator.pop().unwrap(), + self.return_fn.clone(), + )); + } + + Some(blocks) + } + + pub fn try_allocate_blocks(&self, count: usize) -> Vec> { + let mut blocks = Vec::with_capacity(count); + let mut allocator = self.block_allocator.lock(); + let available_count = std::cmp::min(count, allocator.len()); + + for _ in 0..available_count { + blocks.push(MutableBlock::new( + allocator.pop().unwrap(), + self.return_fn.clone(), + )); + } + + blocks + } + + /// Get the number of available blocks + pub fn available_blocks(&self) -> usize { + self.block_allocator.lock().len() + } + + pub fn len(&self) -> usize { + self.block_allocator.lock().len() + } + + /// Create a return function for blocks to return to this pool + /// This allows other pools to create MutableBlocks that return here + pub(crate) fn return_fn(&self) -> Arc) + Send + Sync> { + self.return_fn.clone() + } + + /// Get the expected block size for this pool + pub(crate) fn block_size(&self) -> usize { + self.block_size + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::block_manager::v2::pools::test_utils::TestData; + + fn create_test_blocks(count: usize) -> Vec> { + (0..count as u64).map(|id| Block::new(id, 4)).collect() + } + + #[test] + fn test_mutable_block_raii_return() { + let blocks = create_test_blocks(3); + let pool = ResetPool::new(blocks, 4); + + assert_eq!(pool.len(), 3); + + { + let allocated = pool.allocate_blocks(2).unwrap(); + assert_eq!(allocated.len(), 2); + assert_eq!(pool.len(), 1); + } + + assert_eq!(pool.len(), 3); + } + + #[test] + fn test_pool_allocation_and_return_cycle() { + let blocks = create_test_blocks(5); + let pool = 
ResetPool::new(blocks, 4); + + for _ in 0..3 { + assert_eq!(pool.len(), 5); + + { + let allocated = pool.allocate_blocks(2).unwrap(); + assert_eq!(allocated.len(), 2); + assert_eq!(pool.len(), 3); + } + + assert_eq!(pool.len(), 5); + } + } + + #[test] + fn test_try_allocate_blocks_partial() { + let blocks = create_test_blocks(3); + let pool = ResetPool::new(blocks, 4); + + let allocated = pool.try_allocate_blocks(5); + assert_eq!(allocated.len(), 3); + assert_eq!(pool.len(), 0); + } + + #[test] + fn test_allocate_blocks_insufficient() { + let blocks = create_test_blocks(2); + let pool = ResetPool::new(blocks, 4); + + let result = pool.allocate_blocks(3); + assert!(result.is_none()); + assert_eq!(pool.len(), 2); + } +} + +#[derive(Debug)] +pub struct DequeBlockAllocator { + blocks: VecDeque>, +} + +impl Default for DequeBlockAllocator { + fn default() -> Self { + Self::new() + } +} + +impl DequeBlockAllocator { + pub fn new() -> Self { + Self { + blocks: VecDeque::new(), + } + } +} + +impl BlockAllocator for DequeBlockAllocator { + fn insert(&mut self, block: Block) { + self.blocks.push_back(block); + } + + fn pop(&mut self) -> Option> { + self.blocks.pop_front() + } + + fn len(&self) -> usize { + self.blocks.len() + } +} diff --git a/lib/llm/src/integrations/mod.rs b/lib/llm/src/integrations/mod.rs new file mode 100644 index 00000000000..0556aa8a8eb --- /dev/null +++ b/lib/llm/src/integrations/mod.rs @@ -0,0 +1,6 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! External Integrations + +pub mod vllm; diff --git a/lib/llm/src/integrations/vllm/mod.rs b/lib/llm/src/integrations/vllm/mod.rs new file mode 100644 index 00000000000..a42c9c34411 --- /dev/null +++ b/lib/llm/src/integrations/vllm/mod.rs @@ -0,0 +1,10 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! External Integrations + +// #[cfg(feature = "scheduler")] +pub mod scheduler; + +pub mod recorder; +pub mod types; diff --git a/lib/llm/src/integrations/vllm/recorder.rs b/lib/llm/src/integrations/vllm/recorder.rs new file mode 100644 index 00000000000..8609ccc6ddb --- /dev/null +++ b/lib/llm/src/integrations/vllm/recorder.rs @@ -0,0 +1,251 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Scheduler recorder for capturing vLLM scheduler behavior +//! +//! Records scheduler outputs and model runner outputs for replay and testing. 
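+//!
+//! A minimal usage sketch (illustrative only; assumes `schedule_out`, `runner_out`, and
+//! `engine_out` values already produced by the Python-to-Rust conversion layer, and a
+//! placeholder model name):
+//!
+//! ```ignore
+//! use std::path::Path;
+//!
+//! let mut rec = SchedulerRecorder::new("example/model".to_string(), "0.x".to_string());
+//! rec.record_schedule_output(schedule_out);
+//! rec.record_model_runner_output(runner_out);
+//! rec.record_engine_core_outputs(engine_out);
+//! rec.next_iteration();                        // finalize this iteration's record
+//! rec.save_to_file(Path::new("trace.json"))?;  // pretty-printed JSON trace
+//!
+//! let trace = SchedulerRecorder::load_from_file(Path::new("trace.json"))?;
+//! assert_eq!(trace.metadata.total_iterations, trace.iterations.len());
+//! ```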
+ +use super::types::*; +use chrono::Utc; +use serde_json; +use std::collections::HashMap; +use std::fs::File; +use std::io::Write; +use std::path::Path; + +/// Records scheduler interactions for later replay +pub struct SchedulerRecorder { + /// Current iteration counter + iteration: u64, + + /// All recorded iterations + recordings: Vec, + + /// Partial record being built for current iteration + current_record: Option, + + /// Metadata for the recording + metadata: TraceMetadata, +} + +/// Partial record while building an iteration +struct PartialIterationRecord { + iteration: u64, + schedule_output: Option, + model_runner_output: Option, + engine_core_outputs: Option, + timestamp: f64, +} + +impl SchedulerRecorder { + /// Create a new recorder with metadata + pub fn new(model: String, vllm_version: String) -> Self { + Self { + iteration: 0, + recordings: Vec::new(), + current_record: None, + metadata: TraceMetadata { + vllm_version, + model, + timestamp: Utc::now().to_rfc3339(), + total_iterations: 0, + }, + } + } + + /// Record a scheduler output + pub fn record_schedule_output(&mut self, output: SchedulerOutput) { + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs_f64(); + + match &mut self.current_record { + Some(record) => { + record.schedule_output = Some(output); + } + None => { + self.current_record = Some(PartialIterationRecord { + iteration: self.iteration, + schedule_output: Some(output), + model_runner_output: None, + engine_core_outputs: None, + timestamp, + }); + } + } + } + + /// Record a model runner output + pub fn record_model_runner_output(&mut self, output: ModelRunnerOutput) { + match &mut self.current_record { + Some(record) => { + record.model_runner_output = Some(output); + } + None => { + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs_f64(); + + self.current_record = Some(PartialIterationRecord { + iteration: self.iteration, + schedule_output: None, + model_runner_output: Some(output), + engine_core_outputs: None, + timestamp, + }); + } + } + } + + /// Record engine core outputs + pub fn record_engine_core_outputs(&mut self, outputs: EngineCoreOutputs) { + match &mut self.current_record { + Some(record) => { + record.engine_core_outputs = Some(outputs); + } + None => { + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs_f64(); + + self.current_record = Some(PartialIterationRecord { + iteration: self.iteration, + schedule_output: None, + model_runner_output: None, + engine_core_outputs: Some(outputs), + timestamp, + }); + } + } + } + + /// Move to the next iteration + pub fn next_iteration(&mut self) { + // Finalize current record if complete + if let Some(record) = self.current_record.take() { + if record.schedule_output.is_some() + && record.model_runner_output.is_some() + && record.engine_core_outputs.is_some() + { + let complete_record = IterationRecord { + iteration: record.iteration, + schedule_output: record.schedule_output.unwrap(), + model_runner_output: record.model_runner_output.unwrap(), + engine_core_outputs: record.engine_core_outputs.unwrap(), + timestamp: record.timestamp, + }; + self.recordings.push(complete_record); + } else { + eprintln!( + "Warning: Incomplete iteration {} - schedule: {}, model: {}, engine: {}", + record.iteration, + record.schedule_output.is_some(), + record.model_runner_output.is_some(), + record.engine_core_outputs.is_some() + ); + // Still save 
partial record if needed + if record.schedule_output.is_some() || record.model_runner_output.is_some() { + // Create a minimal complete record with defaults + let complete_record = IterationRecord { + iteration: record.iteration, + schedule_output: record.schedule_output.unwrap_or_else(|| { + SchedulerOutput { + scheduled_new_reqs: Vec::new(), + scheduled_cached_reqs: CachedRequestData { + req_ids: Vec::new(), + resumed_from_preemption: Vec::new(), + new_token_ids: Vec::new(), + new_block_ids: Vec::new(), + num_computed_tokens: Vec::new(), + }, + num_scheduled_tokens: HashMap::new(), + total_num_scheduled_tokens: 0, + scheduled_spec_decode_tokens: HashMap::new(), + scheduled_encoder_inputs: HashMap::new(), + num_common_prefix_blocks: Vec::new(), + finished_req_ids: Vec::new(), + free_encoder_mm_hashes: Vec::new(), + } + }), + model_runner_output: record.model_runner_output.unwrap_or_else(|| { + ModelRunnerOutput { + req_ids: Vec::new(), + req_id_to_index: HashMap::new(), + sampled_token_ids: Vec::new(), + logprobs: None, + prompt_logprobs_dict: HashMap::new(), + num_nans_in_logits: None, + } + }), + engine_core_outputs: record.engine_core_outputs.unwrap_or_else(|| { + EngineCoreOutputs { + engine_index: 0, + outputs: Vec::new(), + timestamp: record.timestamp, + } + }), + timestamp: record.timestamp, + }; + self.recordings.push(complete_record); + } + } + } + + // Increment iteration counter + self.iteration += 1; + self.current_record = None; + } + + /// Get current iteration number + pub fn current_iteration(&self) -> u64 { + self.iteration + } + + /// Save recordings to a JSON file + pub fn save_to_file(&mut self, path: &Path) -> std::io::Result<()> { + // Finalize any pending record + if self.current_record.is_some() { + self.next_iteration(); + } + + // Update metadata + self.metadata.total_iterations = self.recordings.len(); + + let trace = SchedulerTrace { + metadata: self.metadata.clone(), + iterations: self.recordings.clone(), + }; + + let json = serde_json::to_string_pretty(&trace)?; + let mut file = File::create(path)?; + file.write_all(json.as_bytes())?; + + println!("Saved {} iterations to {:?}", self.recordings.len(), path); + Ok(()) + } + + /// Load recordings from a JSON file + pub fn load_from_file(path: &Path) -> std::io::Result { + let file = File::open(path)?; + let trace: SchedulerTrace = serde_json::from_reader(file)?; + Ok(trace) + } + + /// Get the recorded trace + pub fn get_trace(&self) -> SchedulerTrace { + SchedulerTrace { + metadata: self.metadata.clone(), + iterations: self.recordings.clone(), + } + } + + /// Clear all recordings + pub fn clear(&mut self) { + self.recordings.clear(); + self.current_record = None; + self.iteration = 0; + } +} diff --git a/lib/llm/src/integrations/vllm/scheduler.md b/lib/llm/src/integrations/vllm/scheduler.md new file mode 100644 index 00000000000..46478160b9c --- /dev/null +++ b/lib/llm/src/integrations/vllm/scheduler.md @@ -0,0 +1,115 @@ +SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES +SPDX-License-Identifier: Apache-2.0 + +## Rust Scheduler (v1) – Working Design + +This document describes a Rust-native scheduler that mirrors vLLM v1 concepts but adjusts the control flow to better suit Dynamo’s architecture and concurrency model. + +### Goals + +* Extract constants from vLLM configs and keep them in `SchedulerConstants`. +* Make Rust-native data models for requests, outputs, and intermediate scheduler products. +* Separate Python boundary conversions from hot-path Rust logic. 
+* Parallelize `update_from_output` handling using Rust-owned `OwnedModelRunnerOutput`. +* Introduce a staging thread that prepares “ready-to-execute” requests ahead of scheduling steps. + +### Key Differences from vLLM Python Scheduler + +* We only process requests for the current forward pass and a `ready_to_execute` queue. +* Expensive block matches (both device prefix-cache and external connector) happen in a separate staging thread. +* The Python boundary hands off `ModelRunnerOutput` as a Rust-owned structure; we then fan out per request in parallel (future: rayon or task pools) to amortize Python <-> Rust overhead. + +### Data Models (Rust) + +* `RequestState`: tracks per-request state used by the scheduler. Key fields: + * token sequences: `prompt_token_ids`, `output_token_ids`, `all_token_ids` + * counters: `num_computed_tokens`, `num_cached_tokens`, `num_output_placeholders` + * speculative: `spec_token_ids` + * mm inputs: `mm_positions`, `has_encoder_inputs` + * LoRA, structured output, `client_index`, `priority` +* `SchedulerOutput`: mirrors Python’s `SchedulerOutput` but in Rust types, split into `new_requests` and `cached_requests`, plus per-request token schedules and batch signals (finished ids, prefix blocks, etc.). +* `OwnedModelRunnerOutput`: Rust-owned copy of the Python `ModelRunnerOutput`, holding Vecs/Maps that are GIL-less. + +#### vLLM docstrings summary (relevant parts) + +- `Scheduler.schedule()` (Python): produces per-step scheduling for a single forward pass; respects `max_num_batched_tokens`, `max_num_seqs`, `max_model_len`; handles preemption, prefix caching, speculative decoding, structured output, encoder inputs; may include finished request IDs and encoder free lists. +- `Scheduler.update_from_output()`: consumes generated token IDs, adjusts for spec decoding rejections, handles stop conditions, frees requests (including encoder/KV cache), and returns grouped `EngineCoreOutputs` per client with optional stats. +- `Request`: maintains `num_tokens_with_spec = prompt + output + speculative`, `num_computed_tokens` advanced post-schedule, structured-output FSM hooks, `cache_salt` for hashing, and status transitions. +- `SchedulerOutput`: two lists (`scheduled_new_reqs`, `scheduled_cached_reqs`) plus maps of scheduled tokens, spec tokens, encoder inputs, grammar bitmask, finished IDs, and connector metadata. + +### Interfaces + +* `Scheduler` trait: `schedule`, `update_from_output`, `add_request`, `finish_requests`, `get_request_counts`, `has_finished_requests`, `reset_prefix_cache`, `shutdown`, `get_kv_connector`. +* `KvCacheManager` trait: abstract device KV operations (allocate slots, free, cache blocks, prefix computations). +* `KvConnector` trait: optional p/d disaggregation hooks, including matched-token queries and post-alloc updates. +* `StructuredOutputManager` trait: grammar bitmask and FSM advancement hook. + +### Control Flow + +1. Staging Thread (separate): + - On new requests or when recovering from preemption, perform: + - local prefix cache matches (`KvCacheManager.get_computed_blocks`) + - external matches (`KvConnector.get_num_new_matched_tokens`) + - Update prepared state (e.g., `num_computed_tokens`) and enqueue `ready_to_execute` item. + +2. `schedule()`: + - Consume token budget for currently `running` requests first. + - Then pop from `ready_to_execute` to schedule “new” prepared requests. + - Allocate blocks via `KvCacheManager.allocate_slots` with spec lookahead when enabled. 
+   - Build `SchedulerOutput`, set `num_common_prefix_blocks`, optionally add connector metadata.
+   - Advance `num_computed_tokens` after building the output.
+
+3. `update_from_output()`:
+   - For each scheduled request, merge generated token ids, adjust for speculative rejections, check stop, and emit `EngineCoreOutputs` grouped by `client_index`.
+   - Update connector finished-sending/receiving state and free blocks as needed.
+   - Attach `finished_requests` sets if configured.
+
+### Parallelizing ModelRunnerOutput
+
+* Python side converts its `ModelRunnerOutput` into a Rust-owned `OwnedModelRunnerOutput` (no torch tensors crossing the boundary for the hot path).
+* In Rust, process each request’s outputs in parallel (future: `rayon` or a work-stealing executor). The current sketch processes sequentially but isolates the loop so parallelization is straightforward.
+* Alternative: a Python-side channel enqueues per-request items; the Rust side drains and processes them concurrently.
+
+### Open TODOs
+
+* Implement stop conditions (EOS, max length, guided decoding).
+* Integrate real logprobs and pooling tensors as zero-copy or pinned views.
+* Add encoder-cache budget management and free-on-advance logic.
+* Implement preemption policy under allocation pressure.
+* Wire structured output manager and grammar bitmask generation.
+* Provide concrete KvCacheManager and KvConnector adapters to existing components.
+* Add feature-gated rayon parallelism for `update_from_output`.
+
+### Testing Plan
+
+* Unit tests for:
+  * token budgeting and `num_computed_tokens` progression
+  * staging updates and ready queue semantics
+  * connector callbacks on alloc/finish
+* Integration tests against a thin Python harness using vLLM fixtures.
+
+### Worst-Case Projections and Offload Strategy
+
+We introduce a projection subsystem to anticipate KV block pressure and plan proactive offload to host memory.
+
+- Inputs:
+  - `block_size`, `gpu_total_blocks` from the cache config.
+  - Per-request `current_tokens`, `max_tokens`, `num_computed_tokens`.
+- Definitions:
+  - `wc_until_first_complete`: under the worst case (1 token/request/step), the minimum number of steps until any request completes.
+  - `wc_until_block_starvation`: the first future step at which the total required blocks would exceed `gpu_total_blocks`, assuming no early termination.
+  - `predicted_blocks_per_pass`: the total-blocks trajectory across the next K steps.
+- Heuristics:
+  - Model decoding with `tokens_per_pass_per_request = 1`; increase it to approximate chunked prefills.
+  - Use ceil division `ceil(seq_len / block_size)` to estimate blocks.
+  - If starvation is predicted at pass `k`, target freeing enough blocks before `k` by offloading one or more requests to host.
+  - Selection policy: greedy, fewest blocks first, to minimize movement and ensure complete offload(s). Prefer pausing on block boundaries.
+- Data model (Rust):
+  - `ProjectionParams`, `RequestProjection`, `WorstCaseProjection`, `OffloadCandidate`, `OffloadPlan`.
+- Lifecycle:
+  - Compute projections at the start of `schedule()` or periodically.
+  - If `wc_until_block_starvation` is `Some` and below a small threshold, build an `OffloadPlan`, pause and offload the selected requests, and re-run scheduling.
+
+This maintains steady throughput by avoiding mid-iteration allocation failures and by freeing GPU KV blocks predictably.
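+A minimal, self-contained sketch of the projection arithmetic above (names are illustrative and not the final API; `tokens_per_pass_per_request` is fixed at 1):
+
+```rust
+fn div_ceil(a: usize, b: usize) -> usize {
+    (a + b - 1) / b
+}
+
+/// Per-request view used for worst-case projections (1 token/request/step).
+struct Projected {
+    current_tokens: usize,
+    remaining_tokens: usize, // max_tokens minus output tokens generated so far
+}
+
+/// First future pass at which total block demand would exceed capacity,
+/// assuming no request finishes early; None if capacity is never exceeded
+/// within the horizon.
+fn wc_until_block_starvation(
+    reqs: &[Projected],
+    block_size: usize,
+    gpu_total_blocks: usize,
+    horizon: usize,
+) -> Option<usize> {
+    for k in 0..=horizon {
+        let total: usize = reqs
+            .iter()
+            .map(|r| div_ceil(r.current_tokens + k.min(r.remaining_tokens), block_size))
+            .sum();
+        if total > gpu_total_blocks {
+            return Some(k);
+        }
+    }
+    None
+}
+
+/// Greedy offload selection: fewest blocks first, until the deficit is covered.
+fn plan_offload(mut candidates: Vec<(String, usize)>, deficit_blocks: usize) -> Vec<String> {
+    candidates.sort_by_key(|(_, blocks)| *blocks);
+    let mut freed = 0usize;
+    let mut selected = Vec::new();
+    for (req_id, blocks) in candidates {
+        if freed >= deficit_blocks {
+            break;
+        }
+        freed += blocks;
+        selected.push(req_id);
+    }
+    selected
+}
+```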
+ + diff --git a/lib/llm/src/integrations/vllm/scheduler/mod.rs b/lib/llm/src/integrations/vllm/scheduler/mod.rs new file mode 100644 index 00000000000..ce2f6a9fa1f --- /dev/null +++ b/lib/llm/src/integrations/vllm/scheduler/mod.rs @@ -0,0 +1,1689 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! v1 Scheduler (Rust) – high-level sketch and scaffolding +//! +//! This module provides a Rust-native scheduler interface and data models that +//! mirror the vLLM v1 scheduler concepts, with deliberate deviations in the +//! control flow to enable: +//! - a staging thread that performs expensive pre-checks and block matches +//! - a ready-to-execute queue for the current forward pass +//! - parallel processing of per-request outputs from the model runner +//! +//! The implementation below focuses on clear, strongly typed Rust APIs and +//! the separation of Python boundary conversions from the hot-path logic. +//! Many functions are left as stubs with detailed TODO notes. +//! +//! PERFORMANCE NOTES: +//! - Uses Arc> to allow concurrent access to different requests +//! - Avoids cloning Arc where possible (just borrow references) +//! - Only clones strings when building output structs +//! - Mutexes are per-request, so different requests can be accessed concurrently +//! - Rust's borrow checker prevents multiple mutable access to HashMap values simultaneously +//! (even if just mutating the values, not the map structure), hence the mutex pattern +//! - Consider rayon::scope for parallel processing in update_from_output when performance is critical + +pub mod worker; + +use derive_getters::Dissolve; +use tokenizers::Token; +use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore}; + +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; +use std::hash::{Hash, Hasher}; +use std::sync::atomic::AtomicU64; +use std::sync::{Arc, Mutex}; + +// NOTE: Keep this module self-contained and compilation-safe. +// We avoid depending on other internal modules until integration time. + +/// Unique identifier for a request. +pub type RequestId = String; + +/// Zero-based index of a client (used to route outputs back to the origin). +pub type ClientId = i32; + +/// Logical block identifier in the KV cache. +pub type BlockId = u32; + +/// Unix monotonic timestamp in seconds. +pub type MonotonicTs = f64; + +/// Small wrapper to indicate the configured LoRA limits. +/// +/// vLLM enforces a maximum number of concurrent LoRA adapters for a step. +/// When the limit is reached, additional requests using different LoRA IDs are +/// deferred for the current step. We mirror that accounting here. +#[derive(Clone, Copy, Debug, Default)] +pub struct LoraLimits { + pub max_loras: Option, +} + +/// Constants extracted from vLLM configs (scheduler + cache + parallel + spec). +/// +/// Initialize from vLLM: +/// - `scheduler_config.max_num_seqs`, `max_num_batched_tokens`, `max_model_len` +/// - `cache_config.block_size`, `cache_config.num_gpu_blocks (> 0)` +/// - `parallel_config.pipeline_parallel_size (> 1 => PP enabled)` +/// - `speculative_config.num_speculative_tokens` (if enabled) +/// - `include_finished_set`, `log_stats` +#[derive(Clone, Debug)] +pub struct SchedulerConstants { + /// Size of a KV block in tokens. vLLM: `cache_config.block_size`. + /// Used for ceil-div computations of memory and for offload planning. 
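+    /// For example, with `block_size = 16` a 50-token sequence occupies `ceil(50 / 16) = 4` blocks.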
+ pub block_size: usize, + + /// Maximum sequence length (prompt + output) supported by model. + /// vLLM: `scheduler_config.max_model_len`. Used to clamp scheduled tokens. + pub max_model_len: usize, + /// Maximum number of concurrently running sequences. vLLM: `max_num_seqs`. + pub max_num_seqs: usize, + /// Global per-step token budget across requests. vLLM: `max_num_batched_tokens`. + pub max_num_batched_tokens: usize, + + /// Cap for prefill chunk size used to bound latency. vLLM may split long + /// prefills using this threshold. If set, we min(num_new_tokens, threshold). + pub long_prefill_token_threshold: Option, + + /// Whether prefill can be chunked even if it would exceed the remaining + /// token budget. If false, such requests are deferred. Mirrors vLLM logic. + pub chunked_prefill_enabled: bool, + /// When true, multimodal encoder inputs are not chunked; each item must + /// be scheduled whole or postponed. Mirrors vLLM `disable_chunked_mm_input`. + pub disable_chunked_mm_input: bool, + + /// Whether pipeline parallelism is enabled (PP > 1). If true, scheduler + /// may need to return `new_token_ids` back to the first-stage worker. + pub pipeline_parallel: bool, + + /// Include finished request IDs in outputs for efficient lifetime tracking + /// in multi-engine setups. Mirrors vLLM option. + pub include_finished_set: bool, + /// Emit stats records each step (prefix cache, spec decoding, counters). + pub log_stats: bool, + + /// Whether speculative decoding is enabled and the number of speculative + /// tokens (lookahead) per request per step. + pub use_spec_decode: bool, + pub num_spec_tokens: usize, + + /// LoRA concurrency limits for each step. + pub lora_limits: LoraLimits, + /// Total device KV blocks available to allocate (from vLLM cache config). + pub total_gpu_blocks: usize, +} + +impl SchedulerConstants { + pub fn token_budget(&self) -> usize { + self.max_num_batched_tokens + } +} + +/// Status of a request – mirrors `vllm.v1.request.RequestStatus`. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum RequestStatus { + /// The request is queued and waiting to be scheduled. + Waiting, + /// Waiting for FSM compilation for structured output (guided decoding). + WaitingForFsm, + /// Waiting for remote KV transfers to complete before resuming. + WaitingForKvLoad, + /// Currently scheduled/running in this or a previous step. + Running, + /// Temporarily descheduled due to resource pressure (candidate for resume). + Preempted, + /// Terminal states (anything after PREEMPTED is considered finished): + FinishedStopped, + FinishedLengthCapped, + FinishedAborted, + FinishedIgnored, +} + +impl RequestStatus { + pub fn is_finished(self) -> bool { + self > RequestStatus::Preempted + } +} + +/// Minimal LoRA descriptor required by the scheduler for accounting. +#[derive(Clone, Debug, Default)] +pub struct LoraRequestLight { + pub lora_int_id: i64, +} + +/// Placeholder for structured-output support. +#[derive(Clone, Debug, Default)] +pub struct StructuredOutputLight { + pub enabled: bool, +} + +/// Placeholder for multimodal positions (encoder inputs). +#[derive(Clone, Debug, Default)] +pub struct MmPosition { + pub offset: usize, + pub length: usize, +} + +/// State tracked per request inside the scheduler. +#[derive(Clone, Debug)] +pub struct RequestState { + /// Unique request id (string). + pub request_id: RequestId, + + /// Originating client index; used to route outputs to the right frontend. 
+ pub client_id: ClientId, + + /// Larger means higher priority (used for preemption policies). + pub priority: i32, + + /// Current status; transitions are driven by schedule and runner outputs. + pub status: RequestStatus, + + /// Request-specific EOS token (can differ per LoRA). Used for stop checks. + pub eos_token_id: Option, + + /// LoRA context for accounting (see `LoraLimits`). + pub lora: Option, + + /// Salt Hash + pub salt_hash: Option, + + /// Whether structured decoding is enabled; integrates with grammar bitmasks. + pub structured_output: StructuredOutputLight, + + /// Monotonic arrival time; used in tie-breaking and stats. + pub arrival_time: MonotonicTs, + + /// Prompt tokens received at request creation. + pub prompt_token_ids: Vec, + /// Generated tokens so far (grows over steps). + pub output_token_ids: Vec, + /// Concatenation of prompt + output (read-only view in vLLM). + pub all_token_ids: Vec, + + /// Used by async scheduling to pre-reserve positions. + pub num_output_placeholders: usize, + /// Draft tokens for speculative decoding (to be validated next step). + pub spec_token_ids: Vec, + /// Tokens that have been computed (prefix-cached + executed this step). + /// vLLM advances this AFTER schedule() and may adjust in update_from_output + /// to account for rejected speculative tokens. + pub num_computed_tokens: usize, + /// Number of tokens served by prefix cache (>= 0 once known). + pub num_cached_tokens: i32, + /// Indicator of corrupted outputs (NaNs in logits); > 0 means corrupted. + pub num_nans_in_logits: i32, + + /// Positions and lengths of multimodal encoder inputs in the token stream. + pub mm_positions: Vec, + /// Whether the request has encoder inputs (e.g., images for LMMs). + pub has_encoder_inputs: bool, + + /// Request-specific salt used in block hashing. In vLLM this is + /// `Request.cache_salt`; we use it as the KV block hashing salt (aka + /// salt_hash) when creating connector slots or computing block hashes. + pub cache_salt: Option, + /// Maximum allowed output tokens for this request. In vLLM derived from + /// `sampling_params.max_tokens` or set to 1 for pooling models. + /// Used for worst-case projections. + pub max_tokens: usize, +} + +impl RequestState { + pub fn num_tokens(&self) -> usize { + self.all_token_ids.len() + } + + pub fn num_tokens_with_spec(&self) -> usize { + self.all_token_ids.len() + self.spec_token_ids.len() + } + + pub fn use_structured_output(&self) -> bool { + self.structured_output.enabled + } +} + +/// Data for requests scheduled for the first time in a step. +#[derive(Clone, Debug, Dissolve)] +pub struct NewRequestData { + /// Request id for the new (first-time scheduled) request. + pub request_id: RequestId, + /// Prompt tokens to be cached by workers so we don't resend every step. + pub prompt_token_ids: Vec, + /// New block ids allocated this step per KV cache group. + pub block_ids: Vec>, // per cache group + /// Value of `num_computed_tokens` after schedule() for the request. + pub num_computed_tokens: usize, + /// Optional LoRA metadata to cache on workers. + pub lora: Option, + /// Hashing Salt + pub salt_hash: Option, +} + +/// Data for requests that were seen before; we send only incremental info. +#[derive(Clone, Debug, Default)] +pub struct CachedRequestData { + /// Request id for a request that has been seen in previous steps. + pub request_id: RequestId, + /// If true, indicates this request was preempted and has just resumed, + /// thus the provided block ids should replace, not append. 
+ pub resumed_from_preemption: bool, + /// When pipeline parallelism is enabled, sampled token ids to return to + /// the first-stage worker; otherwise left empty. + pub new_token_ids: Vec, + /// New block ids allocated this step per KV cache group. + pub new_block_ids: Vec>, // per cache group + /// `num_computed_tokens` before applying the tokens scheduled this step. + pub num_computed_tokens: usize, +} + +/// Batch-scoped output from the scheduler. +#[derive(Clone, Debug, Default)] +pub struct SchedulerOutput { + /// New requests scheduled for the first time this step. + pub new_requests: Vec, + /// Previously-seen requests with incremental diffs for this step. + pub cached_requests: Vec, + + /// Per-request tokens scheduled this step (before spec rejections). + pub num_scheduled_tokens: HashMap, + /// Sum of all scheduled tokens. + pub total_num_scheduled_tokens: usize, + + /// If present, the speculative draft tokens scheduled for validation. + pub scheduled_spec_decode_tokens: HashMap>, + /// Encoder input indices to process in this step per request. + pub scheduled_encoder_inputs: HashMap>, // indices into mm inputs + /// Number of common prefix blocks across requests per KV group (for cascade attention). + pub num_common_prefix_blocks: Vec, + + /// Requests that finished between previous and current steps; used by + /// workers to free per-request cached state. + pub finished_req_ids: BTreeSet, + /// Pairs of (req_id, encoder_input_index) to free from encoder caches. + pub free_encoder_input_ids: Vec<(RequestId, usize)>, + + /// Mapping from req_id to its index in the batch for grammar bitmask slicing. + pub structured_output_request_ids: HashMap, +} + +/// Per-request output returned to the engine frontend. +#[derive(Clone, Debug, Default)] +pub struct EngineCoreOutput { + /// Id of the request. + pub request_id: RequestId, + /// Newly generated token ids for this step. + pub new_token_ids: Vec, + pub new_logprobs: Option<()>, // TODO: add logprobs tensors/lists when needed + pub new_prompt_logprobs_tensors: Option<()>, + pub pooling_output: Option<()>, // TODO: add tensor handle when needed + pub finish_reason: Option, // Map to FinishReason codes + pub stop_reason: Option, + pub events: Option>, // TODO + /// Connector-specific metadata to instruct workers on KV transfer. + pub kv_transfer_params: Option, + /// The number of tokens served from prefix cache (for stats/clients). + pub num_cached_tokens: i32, +} + +/// Grouped outputs for a client index. +#[derive(Clone, Debug, Default)] +pub struct EngineCoreOutputs { + /// Outputs for all requests that originated from this client. + pub outputs: Vec, + /// Optional set of requests that finished since last step (multi-engine). + pub finished_requests: Option>, +} + +/// Owned, Rust-native representation of `vllm.v1.outputs.ModelRunnerOutput`. +#[derive(Clone, Debug, Default)] +pub struct OwnedModelRunnerOutput { + /// Request ids in the same order as outputs (batch order). + pub req_ids: Vec, + /// Inverse map from req_id to its index in arrays. + pub req_id_to_index: HashMap, + /// Generated token ids per request for this step (variable length per req). + pub sampled_token_ids: Vec>, // per-request variable length + /// Speculative draft tokens per request (None if spec decoding disabled). + pub spec_token_ids: Option>>, + /// Optional sampled logprobs (lists or tensors) — converted to Rust-friendly form. + pub logprobs: Option<()>, // TODO + /// Prompt logprobs per request id (None for non-prefill steps). 
+ pub prompt_logprobs_dict: HashMap>, // TODO + /// Optional pooling outputs per request (embeddings for pooler models). + pub pooler_output: Vec>, // TODO + /// From connector output: req_ids that finished receiving remote KVs. + pub kv_connector_finished_recving: Option>, + /// From connector output: req_ids that finished sending and can be freed. + pub kv_connector_finished_sending: Option>, + /// Number of NaNs observed in logits per request (detect corruption). + pub num_nans_in_logits: Option>, +} + +impl OwnedModelRunnerOutput { + pub fn is_empty(&self) -> bool { + self.req_ids.is_empty() + } +} + +/// Scheduling policy – limited to priority and FCFS for now. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum SchedulingPolicy { + Priority, + Fcfs, +} + +/// Interface for (pluggable) KV cache manager used by the scheduler. +/// +/// Note: This trait is designed to be called from the scheduler's main thread. +/// If you need concurrent access from multiple threads, consider wrapping the +/// implementation in Arc> or using interior mutability patterns. +pub trait KvCacheManager: Send + Sync { + /// Try to allocate device slots for `num_new_tokens` (plus lookahead if any). + /// + /// Called by the scheduler during `schedule()` for both running requests + /// and newly staged requests. If allocation fails (returns None), the + /// scheduler may attempt preemption (or skip) and try again later. + fn allocate_slots( + &self, + request: &RequestState, + num_new_tokens: usize, + num_lookahead_tokens: usize, + ) -> Option>>; + + /// Free all resources for a finished request. + /// + /// Called when a request transitions to any finished status or when the + /// connector indicates GPU blocks can be reclaimed. + fn free(&self, request: &RequestState); + + /// Called to update prefix cache stats or hashes when a request is freed. + fn free_block_hashes(&self, _request: &RequestState) {} + + /// Compute common-prefix blocks across the batch; sized by kv cache groups. + /// + /// Called at the end of `schedule()` to inform potential cascade attention + /// optimizations downstream. + fn get_num_common_prefix_blocks( + &self, + _sample: &RequestState, + _batch_size: usize, + ) -> Vec { + vec![] + } + + /// For new requests: get locally computed blocks/tokens (prefix cache hits). + /// + /// Typically used by a staging thread to determine how many tokens can be + /// served from local cache and to initialize `num_computed_tokens`. + fn get_computed_blocks(&self, _request: &RequestState) -> (Vec>, usize) { + (vec![], 0) + } + + /// Return block ids for a given request (all groups concatenated if needed). + /// + /// Used when finalizing a request to free resources or by connectors to + /// identify which blocks are tied to the request. + fn get_block_ids(&self, _request_id: &RequestId) -> (Vec,); + + /// Cache blocks up to `num_computed_tokens` for a request (prefix cache). + /// + /// Used after remote KV reception to mark blocks cacheable for future reuse. 
+ fn cache_blocks(&self, _request: &RequestState, _num_computed_tokens: usize) {} +} + +// KV blocks can be loaded async or sync + +// KV blocks loads can be managed by the leader or workers +// if managed by the workers, then + +/// The method of loading KV blocks +pub enum KvLoadType { + LeaderAsync, + LeaderSync, + WorkerAsync, + WorkerSync, +} + +pub struct ConnectorMatchResult { + /// The number of tokens from the start of the sequences that are already + /// present and provided by the caller + pub num_computed_tokens: usize, + + /// The number of tokens that are matched by the connectors data sources. + /// These tokens represent the range in the input sequence + /// `(num_computed_tokens..num_connector_tokens)` + pub num_connector_tokens: usize, + + /// The method used to load KV blocks + pub load_type: KvLoadType, +} + +/// Optional KV connector interface used for P/D disaggregation and offloading. +/// +/// Note: This trait is designed to be called from the scheduler's main thread. +/// If you need concurrent access from multiple threads, consider wrapping the +/// implementation in Arc> or using interior mutability patterns. +pub trait KvConnector: Send + Sync { + /// Determine how many tokens match external caches and whether to load + /// remote KVs asynchronously. In vLLM, called for waiting requests during + /// scheduling; in this design, expected to run in a staging thread. + fn get_matched_connector_tokens( + &self, + request: &RequestState, + num_computed_tokens: usize, + ) -> (usize, /*load_kv_async*/ bool); + + /// Called after device allocation to update connector-internal state and + /// possibly trigger onboarding of remote tokens into the allocated blocks. + fn update_state_after_alloc( + &self, + request: &RequestState, + all_block_ids_for_step: &[Vec], + num_external_computed_tokens: usize, + ); + + /// Build opaque connector metadata to be attached to the scheduler output + /// and consumed by workers to execute KV transfers. + fn build_connector_metadata(&self, _output: &SchedulerOutput) -> Option> { + None + } + + /// Notify the connector that a request has finished. Returns whether block + /// freeing must be delayed (true) and optional kv_transfer_params to attach + /// to the client's output. + fn request_finished( + &self, + _request: &RequestState, + _block_ids: &[BlockId], + ) -> (/*delay_free*/ bool, Option) { + (false, None) + } + + /// Update connector state with finished send/recv notifications coming from + /// the worker processes for the previous step. + fn update_connector_output( + &self, + _finished_recving: &BTreeSet, + _finished_sending: &BTreeSet, + ) { + } +} + +/// Structured output manager – used to build grammar bitmasks and advance FSMs. +pub trait StructuredOutputManager: Send + Sync { + /// Build grammar bitmask for structured decoding for the entire batch, + /// slicing per request based on `structured_output_request_ids`. + fn grammar_bitmask( + &self, + _requests: &HashMap>>, + structured_output_request_ids: &HashMap, + scheduled_spec_decode_tokens: &HashMap>, + ) -> Option>; // placeholder for NDArray + + /// Whether the structured output FSM should advance based on new tokens. + fn should_advance(&self, _request: &RequestState) -> bool { + false + } +} + +/// The main scheduler trait – mirrors vLLM's `SchedulerInterface` with Rust-native models. +pub trait Scheduler: Send + Sync { + /// Plan a single scheduling step (one forward pass worth of work), subject + /// to token and sequence budgets. 
Returns a `SchedulerOutput` with new and + /// cached request diffs for workers to prepare inputs and execute. + fn schedule(&mut self) -> SchedulerOutput; + + /// Consume model runner outputs to update internal state, produce client + /// outputs, and free finished requests. Must handle speculative token + /// rejections and connector updates. + fn update_from_output( + &mut self, + scheduler_output: &SchedulerOutput, + model_runner_output: &OwnedModelRunnerOutput, + ) -> HashMap; + + /// Enqueue a new request into the scheduler. A staging thread may process + /// it before `schedule()` moves it to the running set. + fn add_request(&mut self, request: RequestState); + + /// External finish (abort/stop) for one or more requests. Frees resources + /// and generates finished events on next step. + fn finish_requests(&mut self, request_ids: &[RequestId], finished_status: RequestStatus); + + /// Returns number of running and waiting requests. + fn get_request_counts(&self) -> (usize, usize); + + /// Whether there are finished requests to be returned in next outputs. + fn has_finished_requests(&self) -> bool; + + fn reset_prefix_cache(&mut self) -> bool { + false + } + + fn shutdown(&mut self) {} + + fn get_kv_connector(&self) -> Option> { + None + } +} + +/// Internal ready-to-execute queue item. +#[derive(Clone, Debug)] +struct ReadyItem { + request_id: RequestId, + // Additional bookkeeping for staged data could be added here. +} + +/// A basic FCFS/priority queue for waiting requests. +#[derive(Default)] +struct RequestQueue { + fcfs: VecDeque, +} + +impl RequestQueue { + fn add_request(&mut self, req_id: RequestId) { + self.fcfs.push_back(req_id); + } + fn peek_request(&self) -> Option<&RequestId> { + self.fcfs.front() + } + fn pop_request(&mut self) -> Option { + self.fcfs.pop_front() + } + fn remove_requests(&mut self, to_remove: &HashSet) { + self.fcfs.retain(|r| !to_remove.contains(r)); + } + fn is_empty(&self) -> bool { + self.fcfs.is_empty() + } +} + +/// Our reference scheduler implementation. +/// +/// Performance notes: +/// - Uses Arc> to allow concurrent access to different requests +/// - Avoids cloning Arc where possible (just borrow references) +/// - Only clones strings when building output structs +/// - Mutexes are per-request, so different requests can be accessed concurrently +/// - This pattern is necessary because Rust cannot statically guarantee that +/// multiple mutable references to HashMap values are safe (values might have +/// internal references to each other or the map) +pub struct RustScheduler { + consts: SchedulerConstants, + + kv_cache: Arc, + structured: Option>, + connector: Option>, + + policy: SchedulingPolicy, + + requests: HashMap>>, + waiting: RequestQueue, + running: Vec, + ready_to_execute: VecDeque, + + finished_req_ids: BTreeSet, + finished_req_ids_dict: Option>>, + /// Secondary budget for preparing requests (separate from main scheduling budget). + preparation_budget: PreparationBudget, + /// Preparation state for each request. 
+ preparation_states: HashMap, +} + +// impl RustScheduler { +// pub fn new( +// consts: SchedulerConstants, +// kv_cache: Arc, +// structured: Option>, +// connector: Option>, +// policy: SchedulingPolicy, +// ) -> Self { +// let finished_req_ids_dict = if consts.include_finished_set { +// Some(HashMap::new()) +// } else { +// None +// }; +// let total_gpu_blocks = consts.total_gpu_blocks; +// Self { +// consts, +// kv_cache, +// structured, +// connector, +// policy, +// requests: HashMap::new(), +// waiting: RequestQueue::default(), +// running: Vec::new(), +// ready_to_execute: VecDeque::new(), +// finished_req_ids: BTreeSet::new(), +// finished_req_ids_dict, +// preparation_budget: PreparationBudget::new(total_gpu_blocks / 4), // Use 25% of GPU blocks for preparation +// preparation_states: HashMap::new(), +// } +// } + +// fn advance_num_computed_tokens(&mut self, output: &SchedulerOutput) { +// for (req_id, n) in &output.num_scheduled_tokens { +// if let Some(req) = self.requests.get(req_id) { +// let mut req = req.lock().unwrap(); +// req.num_computed_tokens = req.num_computed_tokens.saturating_add(*n); +// // TODO: encoder-freeing logic can be placed here (after num_computed_tokens updates) +// } +// } +// self.finished_req_ids.clear(); +// } + +// /// Free request resources and emit finished ids. Connector may delay GPU +// /// block freeing if remote copy-out is pending; in that case, the worker +// /// will later notify via `update_from_output` connector outputs. +// fn free_request(&mut self, request_id: &RequestId) -> Option { +// // Free KV blocks and hashes; record finished ids. +// if let Some(req) = self.requests.remove(request_id) { +// let req_guard = req.lock().unwrap(); +// let block_ids_tuple = self.kv_cache.get_block_ids(&req_guard.request_id); +// let (delay_free, kv_meta) = if let Some(conn) = &self.connector { +// conn.request_finished(&req_guard, &block_ids_tuple.0) +// } else { +// (false, None) +// }; +// if !delay_free { +// self.kv_cache.free(&req_guard); +// self.kv_cache.free_block_hashes(&req_guard); +// } +// self.finished_req_ids.insert(request_id.clone()); +// if let Some(map) = &mut self.finished_req_ids_dict { +// map.entry(req_guard.client_id) +// .or_default() +// .insert(request_id.clone()); +// } +// return kv_meta; +// } +// None +// } + +// // ------------------------------- +// // Worst-Case Projection Utilities +// // ------------------------------- + +// /// Projection parameters controlling modeling assumptions. +// /// - `tokens_per_pass_per_request`: use 1 to model decoding; higher for chunked prefill. +// /// - `consider_only_running`: if true, ignore waiting requests in projections. 
+// pub fn projection_params_default(&self) -> ProjectionParams { +// ProjectionParams { +// tokens_per_pass_per_request: 1, +// consider_only_running: true, +// } +// } + +// /// Compute worst-case projections: +// /// - `wc_until_first_complete`: min passes to finish among active requests +// /// - `wc_until_block_starvation`: first pass when total blocks > capacity +// /// - `predicted_blocks_per_pass`: total blocks trajectory across horizon +// pub fn compute_worst_case_projection(&self, params: &ProjectionParams) -> WorstCaseProjection { +// let mut per_request: Vec = Vec::new(); +// let block_size = self.consts.block_size; + +// let iter_ids: Vec<_> = if params.consider_only_running { +// self.running.iter().collect() +// } else { +// self.requests.keys().collect() +// }; + +// for req_id in iter_ids { +// if let Some(req_arc) = self.requests.get(req_id) { +// let req = req_arc.lock().unwrap(); +// let current_tokens = req.num_tokens(); +// let current_blocks = div_ceil(current_tokens, block_size); +// let remaining_tokens = req.max_tokens.saturating_sub(req.output_token_ids.len()); +// let steps_until_completion = remaining_tokens; // 1 token per pass +// per_request.push(RequestProjection { +// request_id: req.request_id.clone(), +// current_tokens, +// current_blocks, +// remaining_tokens, +// steps_until_completion, +// }); +// } +// } + +// let wc_until_first_complete = per_request +// .iter() +// .map(|r| r.steps_until_completion) +// .min() +// .unwrap_or(0); + +// let mut predicted_blocks_per_pass: Vec = Vec::new(); +// let mut wc_until_block_starvation: Option = None; +// let total_capacity = self.consts.total_gpu_blocks; + +// let horizon = wc_until_first_complete.max(1); +// for k in 0..=horizon { +// let mut total_blocks = 0usize; +// for r in &per_request { +// let incr = (k as usize) +// .saturating_mul(params.tokens_per_pass_per_request) +// .min(r.remaining_tokens); +// let future_tokens = r.current_tokens + incr; +// total_blocks = total_blocks.saturating_add(div_ceil(future_tokens, block_size)); +// } +// predicted_blocks_per_pass.push(total_blocks); +// if wc_until_block_starvation.is_none() && total_blocks > total_capacity { +// wc_until_block_starvation = Some(k); +// } +// } + +// WorstCaseProjection { +// per_request, +// wc_until_first_complete, +// wc_until_block_starvation, +// predicted_blocks_per_pass, +// } +// } + +// /// Plan which requests to offload to host to free a target number of blocks. +// /// Greedy heuristic: pick the requests with the fewest blocks first. 
+// pub fn plan_offload_for_deficit(&self, deficit_blocks: usize) -> OffloadPlan { +// let block_size = self.consts.block_size; +// let mut candidates: Vec<(RequestId, usize)> = Vec::new(); + +// for (req_id, req_arc) in &self.requests { +// let req = req_arc.lock().unwrap(); +// let current_blocks = div_ceil(req.num_tokens(), block_size); +// if current_blocks > 0 { +// candidates.push((req_id.clone(), current_blocks)); +// } +// } + +// candidates.sort_by_key(|(_, blocks)| *blocks); + +// let mut selected: Vec = Vec::new(); +// let mut freed = 0usize; +// for (req_id, blocks) in candidates { +// selected.push(OffloadCandidate { +// request_id: req_id, +// blocks, +// }); +// freed = freed.saturating_add(blocks); +// if freed >= deficit_blocks { +// break; +// } +// } + +// OffloadPlan { +// selected, +// total_blocks_freed: freed, +// satisfied: freed >= deficit_blocks, +// } +// } + +// // ------------------------------- +// // Request Preparation & Staging +// // ------------------------------- + +// /// Start preparing a request by checking external blocks and allocating GPU blocks. +// /// This is called from the staging thread when a request is ready to be prepared. +// pub fn start_request_preparation(&mut self, request_id: &RequestId) -> bool { +// let Some(req_arc) = self.requests.get(request_id) else { +// return false; +// }; + +// // Get current request state +// let req_guard = req_arc.lock().unwrap(); +// let prompt_tokens = req_guard.prompt_token_ids.len(); +// let max_output_tokens = req_guard.max_tokens; +// let total_tokens = prompt_tokens + max_output_tokens; +// let total_blocks_needed = div_ceil(total_tokens, self.consts.block_size); + +// // Check GPU cache for prefix matches +// let (block_ids, num_matched_gpu_blocks) = self.kv_cache.get_computed_blocks(&req_guard); +// let gpu_matched_tokens = num_matched_gpu_blocks * self.consts.block_size; + +// // Check external sources for additional matches +// let external_matched_tokens = if let Some(conn) = &self.connector { +// conn.get_matched_connector_tokens(&req_guard, gpu_matched_tokens) +// } else { +// (0, false) +// }; +// let external_matched_blocks = div_ceil(external_matched_tokens, self.consts.block_size); + +// // Calculate remaining blocks needed +// let remaining_blocks = +// total_blocks_needed.saturating_sub(gpu_matched_blocks + external_matched_blocks); + +// // Check if we can allocate the remaining blocks within preparation budget +// if !self.preparation_budget.can_allocate(remaining_blocks) { +// // Not enough budget - pause preparation +// self.preparation_states.insert( +// request_id.clone(), +// RequestPreparationState { +// status: RequestPreparationStatus::Paused, +// external_blocks: ExternalBlockInfo { +// available_blocks: external_matched_blocks, +// stage_to_host_first: false, // TODO: get from connector +// blocks_on_host: 0, +// }, +// allocated_gpu_blocks: vec![], +// total_blocks_needed, +// gpu_matched_blocks, +// external_matched_blocks, +// remaining_blocks_needed: remaining_blocks, +// }, +// ); +// return false; +// } + +// // Allocate GPU blocks for the remaining tokens +// let allocated_blocks = self.kv_cache.allocate_slots( +// &req_guard, +// remaining_blocks, +// 0, // No lookahead tokens during preparation +// ); + +// let Some(allocated_blocks) = allocated_blocks else { +// // Failed to allocate - pause preparation +// self.preparation_states.insert( +// request_id.clone(), +// RequestPreparationState { +// status: RequestPreparationStatus::Paused, +// external_blocks: 
ExternalBlockInfo { +// available_blocks: external_matched_blocks, +// stage_to_host_first: false, +// blocks_on_host: 0, +// }, +// allocated_gpu_blocks: vec![], +// total_blocks_needed, +// gpu_matched_blocks, +// external_matched_blocks, +// remaining_blocks_needed: remaining_blocks, +// }, +// ); +// return false; +// }; + +// // Update preparation budget - flatten the nested block structure +// let total_blocks: usize = allocated_blocks.iter().map(|group| group.len()).sum(); +// self.preparation_budget.allocate(total_blocks); + +// // Create preparation state +// let prep_state = RequestPreparationState { +// status: RequestPreparationStatus::Preparing, +// external_blocks: ExternalBlockInfo { +// available_blocks: external_matched_blocks, +// stage_to_host_first: false, // TODO: get from connector +// blocks_on_host: 0, +// }, +// allocated_gpu_blocks: allocated_blocks.into_iter().flatten().collect(), +// total_blocks_needed, +// gpu_matched_blocks, +// external_matched_blocks, +// remaining_blocks_needed: remaining_blocks, +// }; + +// self.preparation_states +// .insert(request_id.clone(), prep_state); +// true +// } + +// /// Complete request preparation by transferring external blocks to device. +// /// This is called when external block transfer is complete. +// pub fn complete_request_preparation(&mut self, request_id: &RequestId) -> bool { +// let Some(prep_state) = self.preparation_states.get_mut(request_id) else { +// return false; +// }; + +// if prep_state.status != RequestPreparationStatus::Preparing { +// return false; +// } + +// // Update request state with computed tokens +// if let Some(req_arc) = self.requests.get(request_id) { +// let mut req_guard = req_arc.lock().unwrap(); +// req_guard.num_computed_tokens = req_guard.num_computed_tokens.saturating_add( +// (prep_state.gpu_matched_blocks + prep_state.external_matched_blocks) +// * self.consts.block_size, +// ); +// } + +// // Move request to ready queue +// prep_state.status = RequestPreparationStatus::Ready; +// self.ready_to_execute.push_back(ReadyItem { +// request_id: request_id.clone(), +// }); + +// // Free preparation budget +// self.preparation_budget +// .free(prep_state.allocated_gpu_blocks.len()); + +// true +// } + +// /// Handle connector staging to host first (when connector wants to stage everything to host). +// pub fn handle_host_staging(&mut self, request_id: &RequestId, blocks_on_host: usize) -> bool { +// let Some(prep_state) = self.preparation_states.get_mut(request_id) else { +// return false; +// }; + +// // Update external block info to reflect host staging +// prep_state.external_blocks.stage_to_host_first = true; +// prep_state.external_blocks.blocks_on_host = blocks_on_host; + +// // If we have enough blocks on host, we can complete preparation +// if prep_state.external_blocks.blocks_on_host >= prep_state.remaining_blocks_needed { +// self.complete_request_preparation(request_id) +// } else { +// // Still waiting for more blocks to be staged to host +// true +// } +// } + +// /// Resume preparation of a paused request when budget becomes available. 
+// pub fn resume_preparation(&mut self, request_id: &RequestId) -> bool { +// let Some(prep_state) = self.preparation_states.get(request_id) else { +// return false; +// }; + +// // If we have enough budget, try to start preparation again +// if self +// .preparation_budget +// .can_allocate(prep_state.remaining_blocks_needed) +// { +// self.start_request_preparation(request_id) +// } else { +// false +// } +// } + +// /// Get current preparation status for a request. +// pub fn get_preparation_status( +// &self, +// request_id: &RequestId, +// ) -> Option<&RequestPreparationStatus> { +// self.preparation_states.get(request_id).map(|s| &s.status) +// } + +// /// Get preparation budget status. +// pub fn get_preparation_budget(&self) -> &PreparationBudget { +// &self.preparation_budget +// } +// } + +// impl Scheduler for RustScheduler { +// /// See `Scheduler::schedule` for semantics. This implementation deviates +// /// from vLLM by focusing only on currently running and ready-to-execute +// /// requests; expensive matching is expected in a staging thread. +// fn schedule(&mut self) -> SchedulerOutput { +// let mut output = SchedulerOutput::default(); +// let mut token_budget = self.consts.token_budget(); + +// // 1) Schedule running requests up to budget. +// let mut req_index = 0usize; +// while req_index < self.running.len() && token_budget > 0 { +// let req_id = &self.running[req_index]; +// let req_arc = match self.requests.get(req_id) { +// Some(a) => a, // Just borrow, don't clone the Arc +// None => { +// req_index += 1; +// continue; +// } +// }; +// let req_guard = req_arc.lock().unwrap(); + +// // Compute how many tokens to schedule this step. +// let mut num_new_tokens = req_guard +// .num_tokens_with_spec() +// .saturating_add(req_guard.num_output_placeholders) +// .saturating_sub(req_guard.num_computed_tokens); + +// if let Some(thresh) = self.consts.long_prefill_token_threshold { +// if thresh > 0 && num_new_tokens > thresh { +// num_new_tokens = thresh; +// } +// } +// num_new_tokens = num_new_tokens.min(token_budget); +// // Keep within model length constraints. +// num_new_tokens = num_new_tokens.min( +// self.consts +// .max_model_len +// .saturating_sub(1 + req_guard.num_computed_tokens), +// ); + +// if num_new_tokens == 0 { +// req_index += 1; +// continue; +// } + +// // Try to allocate blocks (with potential lookahead for spec decoding). +// let new_blocks = self.kv_cache.allocate_slots( +// &req_guard, +// num_new_tokens, +// if self.consts.use_spec_decode { +// self.consts.num_spec_tokens +// } else { +// 0 +// }, +// ); +// if new_blocks.is_none() { +// // NOTE: Preemption strategy can go here. For now, skip. +// req_index += 1; +// continue; +// } +// let new_blocks = new_blocks.unwrap(); + +// // Record scheduling. +// output.cached_requests.push(CachedRequestData { +// request_id: req_guard.request_id.clone(), +// resumed_from_preemption: false, +// new_token_ids: if self.consts.pipeline_parallel { +// // when PP>1 we may need to ship token ids +// let start = req_guard.num_computed_tokens; +// let end = start + num_new_tokens; +// req_guard.all_token_ids[start..end].to_vec() +// } else { +// vec![] +// }, +// new_block_ids: new_blocks.clone(), +// num_computed_tokens: req_guard.num_computed_tokens, +// }); +// output +// .num_scheduled_tokens +// .insert(req_guard.request_id.clone(), num_new_tokens); +// token_budget = token_budget.saturating_sub(num_new_tokens); + +// // If using a connector, inform it about allocation results. 
+// if let Some(conn) = &self.connector { +// conn.update_state_after_alloc(&req_guard, &new_blocks, /*num_external*/ 0); +// } + +// req_index += 1; +// } + +// // 2) Pick from ready_to_execute (staged) and schedule new ones. +// while token_budget > 0 { +// let Some(ready) = self.ready_to_execute.pop_front() else { +// break; +// }; +// let Some(req_arc) = self.requests.get(&ready.request_id) else { +// continue; +// }; +// let req_guard = req_arc.lock().unwrap(); + +// // For staged requests, `num_computed_tokens` may already reflect local/external matches. +// let mut num_new_tokens = req_guard +// .num_tokens() +// .saturating_sub(req_guard.num_computed_tokens); +// if let Some(thresh) = self.consts.long_prefill_token_threshold { +// if thresh > 0 && num_new_tokens > thresh { +// num_new_tokens = thresh; +// } +// } + +// if !self.consts.chunked_prefill_enabled && num_new_tokens > token_budget { +// continue; +// } + +// num_new_tokens = num_new_tokens.min(token_budget); +// if num_new_tokens == 0 { +// continue; +// } + +// let new_blocks = self.kv_cache.allocate_slots( +// &req_guard, +// num_new_tokens, +// if self.consts.use_spec_decode { +// self.consts.num_spec_tokens +// } else { +// 0 +// }, +// ); +// let Some(new_blocks) = new_blocks else { +// continue; +// }; + +// output.new_requests.push(NewRequestData { +// request_id: req_guard.request_id.clone(), +// prompt_token_ids: req_guard.prompt_token_ids.clone(), +// block_ids: new_blocks.clone(), +// num_computed_tokens: req_guard.num_computed_tokens, +// lora: req_guard.lora.clone(), +// }); +// output +// .num_scheduled_tokens +// .insert(req_guard.request_id.clone(), num_new_tokens); +// token_budget = token_budget.saturating_sub(num_new_tokens); + +// // Move to running. +// self.running.push(req_guard.request_id.clone()); +// } + +// // 3) Compute common prefix stats. +// if let Some(first) = self.running.first().and_then(|id| self.requests.get(id)) { +// let req = first.lock().unwrap(); +// output.num_common_prefix_blocks = self +// .kv_cache +// .get_num_common_prefix_blocks(&req, self.running.len()); +// } + +// // 4) Connector metadata, if any. +// if let Some(conn) = &self.connector { +// let _md = conn.build_connector_metadata(&output); +// // TODO: attach to output when needed by bindings +// } + +// // 5) Advance in-request counters and clear finished list for next tick. +// output.total_num_scheduled_tokens = output.num_scheduled_tokens.values().copied().sum(); +// self.advance_num_computed_tokens(&output); + +// output +// } + +// /// See `Scheduler::update_from_output` for semantics. Adjusts for spec +// /// rejections, emits outputs grouped by client, and advances structured +// /// decoding if applicable. +// fn update_from_output( +// &mut self, +// scheduler_output: &SchedulerOutput, +// model_runner_output: &OwnedModelRunnerOutput, +// ) -> HashMap { +// let mut outputs_by_client: HashMap = HashMap::new(); + +// // Fast return if nothing was scheduled. +// if scheduler_output.num_scheduled_tokens.is_empty() || model_runner_output.is_empty() { +// return outputs_by_client; +// } + +// // PERF: Parallelize per-request processing of outputs. We avoid holding the +// // GIL here because `OwnedModelRunnerOutput` is a Rust-owned view. +// // For simplicity, use a scoped iterator without rayon to keep the sketch self-contained. 
+// // +// // TODO: When performance is critical, consider using rayon::scope for parallel processing: +// // rayon::scope(|s| { +// // for (req_id, num_tokens_scheduled) in &scheduler_output.num_scheduled_tokens { +// // let req_arc = self.requests.get(req_id).cloned(); +// // s.spawn(move |_| { /* process request */ }); +// // } +// // }); +// for (req_id, num_tokens_scheduled) in &scheduler_output.num_scheduled_tokens { +// if *num_tokens_scheduled == 0 { +// continue; +// } + +// let Some(req_arc) = self.requests.get(req_id) else { +// continue; +// }; +// let mut req = req_arc.lock().unwrap(); +// let Some(&req_index) = model_runner_output.req_id_to_index.get(req_id) else { +// continue; +// }; + +// let mut generated_token_ids = model_runner_output +// .sampled_token_ids +// .get(req_index) +// .cloned() +// .unwrap_or_default(); + +// // Speculative decoding adjustment – reduce computed tokens by rejected drafts. +// if let Some(spec_lists) = &model_runner_output.spec_token_ids { +// if let Some(scheduled_spec) = +// scheduler_output.scheduled_spec_decode_tokens.get(req_id) +// { +// let accepted = generated_token_ids.len().saturating_sub(1); +// let rejected = scheduled_spec.len() + 1 - generated_token_ids.len(); +// if rejected > 0 { +// req.num_computed_tokens = req.num_computed_tokens.saturating_sub(rejected); +// } +// // Replace drafts with next-step drafts if needed. +// req.spec_token_ids = spec_lists.get(req_index).cloned().unwrap_or_default(); +// let _ = accepted; // reserved for stats +// } +// } + +// // Append newly generated tokens and check stop conditions. +// let mut stopped = false; +// if !generated_token_ids.is_empty() { +// for (idx, tok) in generated_token_ids.iter().copied().enumerate() { +// req.output_token_ids.push(tok); +// req.all_token_ids.push(tok); +// // TODO: implement stop check using max_model_len / eos / sampling params +// let _num_new = idx + 1; +// } +// } + +// // Pooler output / prompt logprobs / events / kv_transfer_params can be attached here. +// let kv_transfer_params = if stopped { +// self.free_request(req_id) +// } else { +// None +// }; + +// let eco = EngineCoreOutput { +// request_id: req_id.clone(), +// new_token_ids: generated_token_ids, +// finish_reason: if stopped { Some(0) } else { None }, +// stop_reason: req.stop_reason(), +// kv_transfer_params, +// num_cached_tokens: req.num_cached_tokens, +// ..Default::default() +// }; + +// outputs_by_client +// .entry(req.client_id) +// .or_default() +// .outputs +// .push(eco); +// } + +// // Update connector status for finished KV transfers. +// if let Some(conn) = &self.connector { +// let finished_recving = model_runner_output +// .kv_connector_finished_recving +// .clone() +// .unwrap_or_default(); +// let finished_sending = model_runner_output +// .kv_connector_finished_sending +// .clone() +// .unwrap_or_default(); +// conn.update_connector_output(&finished_recving, &finished_sending); +// for req_id in finished_sending { +// // Safe to free blocks for requests done sending. +// let _ = self.free_request(&req_id); +// } +// } + +// // Attach finished request sets to one of the client groups. 
+// if let Some(map) = &mut self.finished_req_ids_dict { +// if let Some((_client, group)) = map.iter_mut().next() { +// if let Some((_k, v)) = outputs_by_client.iter_mut().next() { +// v.finished_requests = Some(group.clone()); +// } +// group.clear(); +// } +// } + +// outputs_by_client +// } + +// fn add_request(&mut self, request: RequestState) { +// let req_id = request.request_id.clone(); +// self.waiting.add_request(req_id.clone()); +// self.requests.insert(req_id, Arc::new(Mutex::new(request))); +// } + +// fn finish_requests(&mut self, request_ids: &[RequestId], finished_status: RequestStatus) { +// let ids: HashSet<_> = request_ids.iter().cloned().collect(); +// self.waiting.remove_requests(&ids); +// self.running.retain(|r| !ids.contains(r)); + +// for req_id in request_ids { +// if let Some(req) = self.requests.get(req_id).cloned() { +// let mut req = req.lock().unwrap(); +// req.status = finished_status; +// } +// let _ = self.free_request(req_id); +// } +// } + +// fn get_request_counts(&self) -> (usize, usize) { +// (self.running.len(), self.waiting.fcfs.len()) +// } + +// fn has_finished_requests(&self) -> bool { +// !self.finished_req_ids.is_empty() +// } +// } + +// impl RequestState { +// fn stop_reason(&self) -> Option { +// // TODO: encode stop reason similar to vLLM (stop string, length, abort) +// None +// } +// } + +// ========================== +// Staging Thread Sketch (TODO) +// ========================== +// +// Design: +// - A dedicated staging thread receives new requests and performs: +// * local prefix-cache matches via KvCacheManager.get_computed_blocks +// * external matches via KvConnector.get_matched_connector_tokens +// * updates request.num_computed_tokens accordingly +// * enqueues ReadyItem into ready_to_execute when prepared +// - Communication: std::sync::mpsc or tokio::mpsc; configurable batch size +// - Safety: staging must never mutate fields used by schedule() without holding +// the same request mutex; use a small immutable message to the scheduler +// instead (request_id + prepared state deltas). +// - This sketch leaves concrete code to integration. + +// --------------------------------- +// Projection & Offload Data Models +// --------------------------------- + +/// Integer ceil division helper. +fn div_ceil(n: usize, d: usize) -> usize { + (n + d - 1) / d +} + +/// Parameters controlling projection behavior. +#[derive(Clone, Debug)] +pub struct ProjectionParams { + /// Tokens each request advances per pass (1 for decode). Larger to approximate chunked prefill. + pub tokens_per_pass_per_request: usize, + /// Whether to restrict projections to currently running requests only. + pub consider_only_running: bool, +} + +/// Per-request projection snapshot used to build worst-case aggregates. +#[derive(Clone, Debug)] +pub struct RequestProjection { + pub request_id: RequestId, + pub current_tokens: usize, + pub current_blocks: usize, + pub remaining_tokens: usize, + pub steps_until_completion: usize, +} + +/// Aggregate projection results. +#[derive(Clone, Debug)] +pub struct WorstCaseProjection { + /// Per-request current counts and remaining work. + pub per_request: Vec, + /// In worst-case, upper bound on number of passes until the first request completes. + pub wc_until_first_complete: usize, + /// First pass index at which blocks would exceed capacity, if ever. + pub wc_until_block_starvation: Option, + /// Total blocks predicted per pass over the examined horizon. 
+ pub predicted_blocks_per_pass: Vec, +} + +/// Offload plan proposal to avoid or resolve a predicted starvation. +#[derive(Clone, Debug)] +pub struct OffloadPlan { + pub selected: Vec, + pub total_blocks_freed: usize, + pub satisfied: bool, +} + +/// Request chosen for offloading and the approximate blocks freed if fully moved to host. +#[derive(Clone, Debug)] +pub struct OffloadCandidate { + pub request_id: RequestId, + pub blocks: usize, +} + +// --------------------------------- +// Request Preparation & Staging +// --------------------------------- + +/// Status of a request during the preparation/staging phase. +#[derive(Clone, Debug, PartialEq)] +pub enum RequestPreparationStatus { + /// Request is waiting to be prepared (not yet staged). + Waiting, + /// Request is being prepared - external blocks are being transferred to device. + Preparing, + /// Request is prepared and ready to execute (moved to ready_to_execute queue). + Ready, + /// Request preparation was paused due to insufficient GPU blocks. + Paused, +} + +/// Information about external blocks that need to be transferred to device. +#[derive(Clone, Debug)] +pub struct ExternalBlockInfo { + /// Number of blocks available from external sources (host, remote). + pub available_blocks: usize, + /// Whether the connector wants to stage everything to host first. + pub stage_to_host_first: bool, + /// If staging to host first, how many blocks are currently on host. + pub blocks_on_host: usize, +} + +/// Request preparation state tracking. +#[derive(Clone, Debug)] +pub struct RequestPreparationState { + /// Current preparation status. + pub status: RequestPreparationStatus, + /// External block information from connector. + pub external_blocks: ExternalBlockInfo, + /// GPU blocks allocated for this request during preparation. + pub allocated_gpu_blocks: Vec, + /// Total blocks needed for the full request (prompt + max output). + pub total_blocks_needed: usize, + /// Blocks already matched from GPU cache. + pub gpu_matched_blocks: usize, + /// Blocks already matched from external sources. + pub external_matched_blocks: usize, + /// Remaining blocks that need to be allocated. + pub remaining_blocks_needed: usize, +} + +/// Secondary budget for preparing requests (separate from main scheduling budget). +#[derive(Clone, Debug)] +pub struct PreparationBudget { + /// Maximum GPU blocks that can be allocated for request preparation. + pub max_gpu_blocks: usize, + /// Currently allocated blocks for preparation. + pub allocated_blocks: usize, + /// Whether we're currently in a preparation phase. 
+    pub is_preparing: bool,
+}
+
+impl PreparationBudget {
+    pub fn new(max_gpu_blocks: usize) -> Self {
+        Self {
+            max_gpu_blocks,
+            allocated_blocks: 0,
+            is_preparing: false,
+        }
+    }
+
+    pub fn can_allocate(&self, blocks: usize) -> bool {
+        self.allocated_blocks + blocks <= self.max_gpu_blocks
+    }
+
+    pub fn allocate(&mut self, blocks: usize) -> bool {
+        if self.can_allocate(blocks) {
+            self.allocated_blocks += blocks;
+            true
+        } else {
+            false
+        }
+    }
+
+    pub fn free(&mut self, blocks: usize) {
+        self.allocated_blocks = self.allocated_blocks.saturating_sub(blocks);
+    }
+}
+
+pub struct RequestPreloadState {}
+
+impl RequestPreloadState {
+    async fn prepare_request(&mut self, request: RequestState) {
+        // match against gpu blocks
+
+        // match against host/disk blocks
+
+        // determine if prefill should be computed locally or offloaded
+    }
+
+    async fn onboard_locally_stored_blocks(&mut self, request: RequestState) {
+        unimplemented!()
+    }
+
+    async fn onboard_remotely_stored_blocks(&mut self, request: RequestState) {
+        unimplemented!()
+    }
+
+    async fn onboard_remotely_computed_blocks(&mut self, request: RequestState) {
+        // if remote prefill, then ensure gpu blocks are available
+        // in host memory, then release the gpu blocks
+
+        // acquire cpu blocks for the remote instance to write
+
+        // prepare src and dst descriptors
+        // the remote prefill worker will pull kv blocks from "src" descriptors
+        // the remote prefill worker will push kv blocks to the "dst" descriptors
+
+        // issue remote prefill request and await its completion
+    }
+}
+
+pub struct SchedulerState {}
+
+// ========================================================================
+// Rust Scheduler State for tracking requests in parallel with vLLM
+// ========================================================================
+
+/// Compute a deterministic u64 hash from a string cache_salt.
+/// This ensures consistent hashing between Python and Rust.
+pub fn compute_salt_hash(cache_salt: &str) -> u64 {
+    use std::collections::hash_map::DefaultHasher;
+    let mut hasher = DefaultHasher::new();
+    cache_salt.hash(&mut hasher);
+    hasher.finish()
+}
+
+/// Initial request data available at add_request time.
+/// This is a simplified version that only contains data we have
+/// before any scheduling decisions are made.
+#[derive(Clone, Debug)]
+pub struct InitialRequestData {
+    pub request_id: String,
+    pub prompt_token_ids: Vec,
+    pub salt_hash: Option<u64>,
+    pub lora_int_id: Option,
+    pub priority: i32,
+    pub arrival_time: f64,
+}
+
+/// The Rust scheduler state that tracks requests in parallel with vLLM.
+/// This allows us to build up Rust-side state while vLLM continues to
+/// drive the actual scheduling decisions.
+#[derive(Debug)]
+pub struct RustSchedulerState {
+    /// Map of request_id to initial request data
+    requests: Arc<Mutex<HashMap<String, InitialRequestData>>>,
+}
+
+impl RustSchedulerState {
+    /// Create a new empty scheduler state.
+    pub fn new() -> Self {
+        Self {
+            requests: Arc::new(Mutex::new(HashMap::new())),
+        }
+    }
+
+    /// Add a new request to the scheduler state.
+    /// Called when DynamoScheduler.add_request() is invoked.
+ pub fn add_request( + &self, + request_id: String, + prompt_token_ids: Vec, + cache_salt: Option, + lora_int_id: Option, + priority: i32, + arrival_time: f64, + ) -> Result<(), String> { + // Convert cache_salt string to u64 hash + let salt_hash = cache_salt.as_ref().map(|s| compute_salt_hash(s)); + + let request_data = InitialRequestData { + request_id: request_id.clone(), + prompt_token_ids, + salt_hash, + lora_int_id, + priority, + arrival_time, + }; + + let mut requests = self.requests.lock().unwrap(); + requests.insert(request_id.clone(), request_data); + + // Log for debugging + println!( + "Rust scheduler: Added request {} with {} prompt tokens", + request_id, + requests.get(&request_id).unwrap().prompt_token_ids.len() + ); + + Ok(()) + } + + /// Mark requests as finished without removing them. + /// The actual removal happens when the scheduler reports them as finished. + /// This is called when finish_requests is invoked externally (e.g., client disconnect). + pub fn mark_as_finished(&self, request_ids: Vec) -> Result<(), String> { + // TODO: When we track request states, update the state to finished here. + // For now, this is a no-op as we don't track states yet. + // The request will be removed when update_from_output reports it as finished. + for req_id in &request_ids { + println!("Rust scheduler: Marked request {} as finished (no-op for now)", req_id); + } + Ok(()) + } + + /// Remove finished requests from the scheduler state. + /// This should only be called when the scheduler reports requests as finished + /// via scheduler_output.finished_req_ids in update_from_output. + pub fn remove_finished_requests(&self, request_ids: Vec) -> Result<(), String> { + let mut requests = self.requests.lock().unwrap(); + + for req_id in &request_ids { + if requests.remove(req_id).is_some() { + println!("Rust scheduler: Removed finished request {}", req_id); + } + } + + Ok(()) + } + + /// Get the current number of tracked requests. + pub fn num_requests(&self) -> usize { + self.requests.lock().unwrap().len() + } + + /// Check if a request is being tracked. + pub fn has_request(&self, request_id: &str) -> bool { + self.requests.lock().unwrap().contains_key(request_id) + } + + /// Get all currently tracked request IDs. + pub fn get_request_ids(&self) -> Vec { + self.requests.lock().unwrap().keys().cloned().collect() + } +} + +impl Default for RustSchedulerState { + fn default() -> Self { + Self::new() + } +} + +pub struct LogicalBlock { + block_id: usize, + permit: OwnedSemaphorePermit, +} + +enum BlockState { + Mutable, + Registering, + Immutable, +} + +pub struct Block { + block_id: u64, +} + +impl Block { + fn block_id(&self) -> u64 { + self.block_id + } + + fn block_set_id() -> u8 { + todo!("split the u128 into two u64s, return the first u8 of the second u64") + } + + fn instance_id() -> u64 { + todo!("return the first u64 bits of the u128") + } +} diff --git a/lib/llm/src/integrations/vllm/scheduler/worker/blocks.rs b/lib/llm/src/integrations/vllm/scheduler/worker/blocks.rs new file mode 100644 index 00000000000..f15d05fdc79 --- /dev/null +++ b/lib/llm/src/integrations/vllm/scheduler/worker/blocks.rs @@ -0,0 +1,189 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Worker device blocks construction for scheduler. +//! +//! This module provides a simplified way to build device blocks from tensors +//! on the worker side without requiring leader/worker synchronization or network setup. 
+ +use std::sync::Arc; + +use anyhow::{anyhow, Result}; + +use crate::block_manager::{ + BasicMetadata, LayoutConfigBuilder, NixlLayout, + block::{Block, layout_to_blocks, locality}, + layout::LayoutType, + storage::{DeviceAllocator, DeviceStorage, torch::TorchTensor}, +}; + +/// Container for worker device blocks constructed locally. +pub struct WorkerDeviceBlocks { + /// Device blocks constructed from KV cache tensors + pub device_blocks: Vec>, + + /// Metadata about the layout + pub num_device_blocks: usize, + pub num_layers: usize, + pub outer_dim: usize, + pub page_size: usize, + pub inner_dim: usize, + pub dtype_width_bytes: usize, + pub bytes_per_block: usize, +} + +impl WorkerDeviceBlocks { + /// Build device blocks locally from KV cache tensors. + /// + /// This is a simplified version of KvbmWorker's initialization that: + /// - Validates tensor consistency + /// - Infers layout configuration + /// - Creates device blocks + /// - Does NOT perform leader sync or network setup + pub fn from_tensors( + tensors: Vec>, + num_device_blocks: usize, + page_size: usize, + device_id: usize, + dtype_width_bytes: usize, + is_fully_contiguous_layout: bool, + ) -> Result { + if num_device_blocks == 0 { + return Err(anyhow!("num_device_blocks must be greater than 0")); + } + + if tensors.is_empty() { + return Err(anyhow!("tensors cannot be empty")); + } + + // Validate tensors and get device storage + let (device_tensors, shape) = Self::load_and_validate_tensors(&tensors, device_id)?; + + if shape.len() < 3 { + return Err(anyhow!( + "Unsupported kv cache layout. Got shape: {:?}", + shape + )); + } + + // Infer layout configuration + let (layout_type, num_layers, outer_dim, inner_dim) = if !is_fully_contiguous_layout { + let (outer_contiguous, outer_dim) = if shape[0] >= num_device_blocks { + (false, shape[1]) + } else if shape[1] >= num_device_blocks { + (true, shape[0]) + } else { + return Err(anyhow!( + "Unsupported kv cache layout. Got shape: {:?}", + shape + )); + }; + let num_layers = device_tensors.len(); + let inner_dim = shape[2..].iter().product::() / page_size; + + ( + LayoutType::LayerSeparate { outer_contiguous }, + num_layers, + outer_dim, + inner_dim, + ) + } else { + let num_layers = shape[1]; + let outer_dim = shape[2]; + let inner_dim = shape[3..].iter().product::() / page_size; + + ( + LayoutType::FullyContiguous, + num_layers, + outer_dim, + inner_dim, + ) + }; + + let bytes_per_block = + num_layers * outer_dim * page_size * inner_dim * dtype_width_bytes; + + // Build layout + let mut layout_builder_instance = LayoutConfigBuilder::default(); + let layout_builder = layout_builder_instance + .num_layers(num_layers) + .outer_dim(outer_dim) + .page_size(page_size) + .inner_dim(inner_dim) + .dtype_width_bytes(dtype_width_bytes); + + let device_layout = layout_builder + .num_blocks(num_device_blocks) + .build()? + .create_layout(layout_type, device_tensors)?; + + // Convert layout to blocks + let device_blocks = Self::make_layout(device_layout)?; + + Ok(Self { + device_blocks, + num_device_blocks, + num_layers, + outer_dim, + page_size, + inner_dim, + dtype_width_bytes, + bytes_per_block, + }) + } + + /// Validate tensors and create device storage. 
+ fn load_and_validate_tensors( + tensors: &[Arc], + device_id: usize, + ) -> Result<(Vec, Vec)> { + let mut shape = None; + let mut device_tensors = Vec::with_capacity(tensors.len()); + let allocator = DeviceAllocator::new(device_id)?; + + for tensor in tensors { + // Check the stride + let stride = tensor.stride(); + for i in 1..stride.len() { + if stride[i] > stride[i - 1] { + return Err(anyhow!( + "Tensor strides must be monotonically decreasing! Got {:?}", + stride + )); + } + } + + // Check that all tensors have the same shape + if let Some(shape) = shape.as_ref() { + if *shape != tensor.shape() { + return Err(anyhow!( + "All tensors must have the same shape! Got {:?} and {:?}", + *shape, + tensor.shape() + )); + } + } else { + shape = Some(tensor.shape()); + } + + // Build the storage object from the tensor + let device_tensor = DeviceStorage::new_from_torch(allocator.ctx(), tensor.clone())?; + device_tensors.push(device_tensor); + } + + Ok((device_tensors, shape.unwrap())) + } + + /// Convert layout to blocks without NIXL registration. + fn make_layout( + layout: Box>, + ) -> Result>> { + // Convert to Arc for layout_to_blocks + let layout: Arc> = Arc::from(layout); + + // Create blocks with block_set_idx=0, worker_id=0 (local only) + let blocks = layout_to_blocks::<_, BasicMetadata>(layout, 0, 0)?; + + Ok(blocks) + } +} \ No newline at end of file diff --git a/lib/llm/src/integrations/vllm/scheduler/worker/mod.rs b/lib/llm/src/integrations/vllm/scheduler/worker/mod.rs new file mode 100644 index 00000000000..539b7807c7c --- /dev/null +++ b/lib/llm/src/integrations/vllm/scheduler/worker/mod.rs @@ -0,0 +1,8 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Worker-side scheduler components. + +pub mod blocks; + +pub use blocks::WorkerDeviceBlocks; \ No newline at end of file diff --git a/lib/llm/src/integrations/vllm/types.rs b/lib/llm/src/integrations/vllm/types.rs new file mode 100644 index 00000000000..b7525a5f32b --- /dev/null +++ b/lib/llm/src/integrations/vllm/types.rs @@ -0,0 +1,177 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Rust data structures for vLLM scheduler types +//! +//! These structures mirror the essential fields from vLLM's Python objects +//! and are designed to be serializable for recording and replay. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Represents a new request being scheduled for the first time +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NewRequestData { + pub req_id: String, + pub prompt_token_ids: Vec, + pub block_ids: Vec>, // tuple[list[int], ...] 
in Python + pub num_computed_tokens: usize, + // Additional fields we might need + pub mm_hashes: Vec, + pub mm_positions: Vec, +} + +/// Placeholder range for multimodal inputs +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlaceholderRange { + pub start: usize, + pub end: usize, +} + +/// Represents cached requests that have been scheduled before +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CachedRequestData { + pub req_ids: Vec, + pub resumed_from_preemption: Vec, + pub new_token_ids: Vec>, // For pipeline parallelism + pub new_block_ids: Vec>>>, + pub num_computed_tokens: Vec, +} + +/// Main scheduler output structure +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SchedulerOutput { + /// Requests scheduled for the first time + pub scheduled_new_reqs: Vec, + + /// Previously scheduled requests (cached) + pub scheduled_cached_reqs: CachedRequestData, + + /// Number of tokens scheduled for each request + pub num_scheduled_tokens: HashMap, + + /// Total number of scheduled tokens + pub total_num_scheduled_tokens: usize, + + /// Speculative decode tokens per request + pub scheduled_spec_decode_tokens: HashMap>, + + /// Encoder inputs that need processing + pub scheduled_encoder_inputs: HashMap>, + + /// Number of common prefix blocks for cascade attention + pub num_common_prefix_blocks: Vec, + + /// Finished request IDs + pub finished_req_ids: Vec, + + /// MM hashes to free from encoder cache + pub free_encoder_mm_hashes: Vec, +} + +/// Logprobs data structure +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LogprobsLists { + pub logprob_token_ids: Vec>, + pub logprobs: Vec>, + pub sampled_token_ranks: Vec, +} + +/// Model runner output structure +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelRunnerOutput { + /// Request IDs in order + pub req_ids: Vec, + + /// Map from request ID to index + pub req_id_to_index: HashMap, + + /// Sampled token IDs for each request + pub sampled_token_ids: Vec>, + + /// Optional logprobs + pub logprobs: Option, + + /// Prompt logprobs per request + pub prompt_logprobs_dict: HashMap>, + + /// Number of NaNs in logits (for debugging) + pub num_nans_in_logits: Option>, +} + +/// Finish reason for a request +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum FinishReason { + Stop = 0, + Length = 1, + Abort = 2, +} + +/// Engine core event type +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum EngineCoreEventType { + Queued = 1, + Scheduled = 2, + Preempted = 3, +} + +/// Engine core event with timestamp +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EngineCoreEvent { + pub event_type: EngineCoreEventType, + pub timestamp: f64, +} + +/// Output for a single request from the engine core +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EngineCoreOutput { + pub request_id: String, + pub new_token_ids: Vec, + pub new_logprobs: Option, + pub finish_reason: Option, + pub stop_reason: Option, + pub events: Option>, + pub num_cached_tokens: usize, +} + +/// Stop reason (can be string or int) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum StopReason { + String(String), + Int(i32), +} + +/// Collection of engine core outputs +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EngineCoreOutputs { + pub engine_index: usize, + pub outputs: Vec, + pub timestamp: f64, +} + +/// Complete iteration record containing all scheduler data +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IterationRecord { + 
pub iteration: u64, + pub schedule_output: SchedulerOutput, + pub model_runner_output: ModelRunnerOutput, + pub engine_core_outputs: EngineCoreOutputs, + pub timestamp: f64, +} + +/// Complete recording trace +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SchedulerTrace { + pub metadata: TraceMetadata, + pub iterations: Vec, +} + +/// Metadata about the recording +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TraceMetadata { + pub vllm_version: String, + pub model: String, + pub timestamp: String, + pub total_iterations: usize, +} \ No newline at end of file diff --git a/lib/llm/src/lib.rs b/lib/llm/src/lib.rs index e2dcf044ca9..79f16e2458b 100644 --- a/lib/llm/src/lib.rs +++ b/lib/llm/src/lib.rs @@ -44,6 +44,8 @@ pub mod block_manager; #[cfg(feature = "cuda")] pub mod cuda; +pub mod integrations; + /// Reads a JSON file, extracts a specific field, and deserializes it into type T. /// /// # Arguments diff --git a/lib/llm/src/mocker/sequence.rs b/lib/llm/src/mocker/sequence.rs index f4ada402833..d467ebaa29e 100644 --- a/lib/llm/src/mocker/sequence.rs +++ b/lib/llm/src/mocker/sequence.rs @@ -28,7 +28,7 @@ fn create_unique_blocks_from_sequence( .collect(); // Only push the partial block if tokens count isn't a multiple of block_size - if tokens.total_tokens() % block_size != 0 { + if !tokens.total_tokens().is_multiple_of(block_size) { unique_blocks.push(match uuid { Some(uuid) => UniqueBlock::PartialBlock(uuid), None => UniqueBlock::default(), @@ -258,7 +258,7 @@ impl ActiveSequence { self.generated_tokens = self.generated_tokens.saturating_sub(1); // Reverts to the last full block - if self.tokens.total_tokens() % self.block_size == 0 { + if self.tokens.total_tokens().is_multiple_of(self.block_size) { self.unique_blocks.pop(); } }
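The staging-thread design sketched in the commented-out scheduler above (local prefix-cache matching, connector matching, then an immutable "prepared state delta" sent back to the scheduler) could be wired up roughly as below. This is only an illustrative sketch: `StageJob`, `PreparedDelta`, and `spawn_staging_thread` are assumed names, and the matching step is stubbed out instead of calling the real KvCacheManager or connector.

use std::sync::mpsc;
use std::thread;

// Illustrative message from the scheduler to the staging thread; a real
// integration would carry a RequestId / RequestState handle instead.
struct StageJob {
    request_id: String,
    prompt_tokens: usize,
}

// Immutable "prepared state delta" sent back to the scheduler, which applies
// it under the request mutex, per the safety note in the design sketch.
struct PreparedDelta {
    request_id: String,
    matched_tokens: usize,
}

fn spawn_staging_thread(
    jobs: mpsc::Receiver<StageJob>,
    ready: mpsc::Sender<PreparedDelta>,
) -> thread::JoinHandle<()> {
    thread::spawn(move || {
        for job in jobs {
            // Stand-in for local prefix-cache matching plus connector
            // matching; here we simply pretend half of the prompt was matched.
            let matched_tokens = job.prompt_tokens / 2;
            if ready
                .send(PreparedDelta {
                    request_id: job.request_id,
                    matched_tokens,
                })
                .is_err()
            {
                // Scheduler dropped its receiver; stop staging.
                break;
            }
        }
    })
}

On the scheduler side, the received deltas could be drained at the top of schedule() to bump num_computed_tokens and enqueue the corresponding ReadyItem, as the design notes suggest.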
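The PreparationBudget introduced in this patch is plain bounded accounting on the staging path: requests that would overshoot the budget are paused until blocks are freed. A minimal sketch of that behavior, with made-up block counts and a demo function name:

fn preparation_budget_demo() {
    // Budget of 16 GPU blocks reserved for request preparation.
    let mut budget = PreparationBudget::new(16);

    // A request needing 10 blocks fits and is admitted for preparation.
    assert!(budget.can_allocate(10));
    assert!(budget.allocate(10));

    // A second request needing 8 blocks would overshoot the budget, so it
    // would be left in the Paused preparation state until capacity frees up.
    assert!(!budget.allocate(8));

    // Preparation of the first request completes; its blocks are returned.
    budget.free(10);
    assert!(budget.can_allocate(8));
}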
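RustSchedulerState follows the lifecycle described in its doc comments: track a request at add_request time, only mark it on an external finish, and remove it once the scheduler reports it as finished. A minimal sketch of that flow; the wrapper function and all argument values are made up for illustration:

fn track_request_lifecycle(state: &RustSchedulerState) -> Result<(), String> {
    // DynamoScheduler.add_request(): start tracking on the Rust side.
    state.add_request(
        "req-0".to_string(),
        vec![1, 2, 3, 4],           // prompt token ids (made up)
        Some("salt-a".to_string()), // cache_salt, hashed via compute_salt_hash
        None,                       // lora_int_id
        0,                          // priority
        0.0,                        // arrival_time
    )?;
    assert!(state.has_request("req-0"));

    // External finish (e.g. client disconnect): mark only, keep tracking.
    state.mark_as_finished(vec!["req-0".to_string()])?;
    assert_eq!(state.num_requests(), 1);

    // The scheduler later reports the request in finished_req_ids: remove it.
    state.remove_finished_requests(vec!["req-0".to_string()])?;
    assert_eq!(state.num_requests(), 0);
    Ok(())
}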