diff --git a/apollo-router/src/allocator.rs b/apollo-router/src/allocator.rs index 81f398abf4..7733c5550e 100644 --- a/apollo-router/src/allocator.rs +++ b/apollo-router/src/allocator.rs @@ -1,11 +1,495 @@ +use std::alloc::{GlobalAlloc, Layout}; +use std::cell::Cell; use std::ffi::CStr; +use std::future::Future; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::sync::oneshot::Sender; #[cfg(feature = "dhat-heap")] use parking_lot::Mutex; +/// Thread-local allocation statistics that can be shared across threads. +/// +/// Supports nested tracking where allocations in a child context are also tracked +/// in all parent contexts up to the root. Uses loop unwrapping rather than recursion +/// for performance. +/// +/// Uses AtomicUsize for all fields to allow lock-free concurrent access from multiple threads +/// that share the same Arc. This is critical for performance in the global +/// allocator hot path where even an uncontended Mutex would add significant overhead. +#[derive(Debug)] +pub(crate) struct AllocationStats { + /// Context name used for metric labeling + name: &'static str, + /// Parent context for nested tracking (None for root) + parent: Option>, + bytes_allocated: AtomicUsize, + bytes_deallocated: AtomicUsize, + bytes_zeroed: AtomicUsize, + bytes_reallocated: AtomicUsize, + /// If bytes_allocated exceeds this limit and an allocation fails, a message will be sent to the channel + pub(crate) limit: Option, +} + +#[derive(Debug)] +pub(crate) struct AllocationLimit { + pub(crate) max_bytes: usize, + pub(crate) sender: Option>, +} + +impl AllocationLimit { + pub(crate) fn new(max_bytes: usize, sender: Sender) -> Self { + Self { + max_bytes, + sender: Some(sender), + } + } +} + +impl AllocationStats { + /// Create a new root allocation stats context with the given name. + fn new(name: &'static str, limit: Option) -> Self { + Self { + name, + parent: None, + bytes_allocated: AtomicUsize::new(0), + bytes_deallocated: AtomicUsize::new(0), + bytes_zeroed: AtomicUsize::new(0), + bytes_reallocated: AtomicUsize::new(0), + limit, + } + } + + /// Create a new child allocation stats context that tracks to a parent. + fn with_parent( + name: &'static str, + parent: Arc, + limit: Option, + ) -> Self { + Self { + name, + parent: Some(parent), + bytes_allocated: AtomicUsize::new(0), + bytes_deallocated: AtomicUsize::new(0), + bytes_zeroed: AtomicUsize::new(0), + bytes_reallocated: AtomicUsize::new(0), + limit, + } + } + + /// Get the context name for this allocation stats. + #[inline] + pub(crate) fn name(&self) -> &'static str { + self.name + } + + /// Get the parent context, if any. + #[inline] + #[allow(dead_code)] + pub(crate) fn parent(&self) -> Option<&Arc> { + self.parent.as_ref() + } + + /// Get the root context by traversing up the parent chain. + /// Returns self if this is already a root context. + pub(crate) fn root(&self) -> &Self { + let mut current = self; + while let Some(parent) = ¤t.parent { + current = parent.as_ref(); + } + current + } + + /// Track allocation in this context and all parent contexts. + /// Uses loop unwrapping instead of recursion for performance. + #[inline] + fn track_alloc(&self, size: usize) { + let mut current = Some(self); + while let Some(stats) = current { + stats.bytes_allocated.fetch_add(size, Ordering::Relaxed); + current = stats.parent.as_ref().map(|p| p.as_ref()); + } + } + + /// Track deallocation in this context and all parent contexts. + /// Uses loop unwrapping instead of recursion for performance. + #[inline] + fn track_dealloc(&self, size: usize) { + let mut current = Some(self); + while let Some(stats) = current { + stats.bytes_deallocated.fetch_add(size, Ordering::Relaxed); + current = stats.parent.as_ref().map(|p| p.as_ref()); + } + } + + /// Track zeroed allocation in this context and all parent contexts. + /// Uses loop unwrapping instead of recursion for performance. + #[inline] + fn track_zeroed(&self, size: usize) { + let mut current = Some(self); + while let Some(stats) = current { + stats.bytes_zeroed.fetch_add(size, Ordering::Relaxed); + current = stats.parent.as_ref().map(|p| p.as_ref()); + } + } + + /// Track reallocation in this context and all parent contexts. + /// Uses loop unwrapping instead of recursion for performance. + #[inline] + fn track_realloc(&self, size: usize) { + let mut current = Some(self); + while let Some(stats) = current { + stats.bytes_reallocated.fetch_add(size, Ordering::Relaxed); + current = stats.parent.as_ref().map(|p| p.as_ref()); + } + } + + /// Get the current number of bytes allocated. + #[inline] + pub(crate) fn bytes_allocated(&self) -> usize { + self.bytes_allocated.load(Ordering::Relaxed) + } + + /// Get the current number of bytes deallocated. + #[inline] + pub(crate) fn bytes_deallocated(&self) -> usize { + self.bytes_deallocated.load(Ordering::Relaxed) + } + + /// Get the current number of bytes allocated with zeroing. + #[inline] + pub(crate) fn bytes_zeroed(&self) -> usize { + self.bytes_zeroed.load(Ordering::Relaxed) + } + + /// Get the current number of bytes reallocated. + #[inline] + pub(crate) fn bytes_reallocated(&self) -> usize { + self.bytes_reallocated.load(Ordering::Relaxed) + } + + /// Get the current net allocated bytes (allocated + zeroed - deallocated). + #[inline] + pub(crate) fn net_allocated(&self) -> usize { + let allocated = self.bytes_allocated(); + let zeroed = self.bytes_zeroed(); + let deallocated = self.bytes_deallocated(); + allocated.saturating_add(zeroed).saturating_sub(deallocated) + } +} + +// Thread-local to track the current task's allocation stats. +// +// ## Why Cell>> instead of Cell>> or Mutex>>? +// +// We use a NonNull pointer instead of Arc because: +// +// 1. **Cell requires Copy**: Cell::get() requires T: Copy, but Arc is not Copy +// because it has a Drop implementation for reference counting. +// +// 2. **TLS destructors conflict with global allocators**: If we stored Option> +// in the thread-local, its Drop implementation would run when the thread exits. +// This Drop could call the allocator (to deallocate the Arc), causing a fatal +// reentrancy error: "the global allocator may not use TLS with destructors". +// +// 3. **Cell is faster than Mutex**: Cell has zero overhead (just a memory read/write), +// while Mutex requires atomic operations and potential thread parking. Since we +// access this on every allocation, performance is critical. +// +// ## Safety invariants: +// +// - The NonNull pointer is only valid while a MemoryTrackedFuture holding the corresponding +// Arc is on the call stack (either in poll() or with_memory_tracking()). +// - We manually manage Arc reference counts when propagating across tasks. +// - The pointer always points to valid AllocationStats when Some. +thread_local! { + static CURRENT_TASK_STATS: Cell>> = const { Cell::new(None) }; +} + +/// Future wrapper that tracks memory allocations for a task. +pub(crate) struct MemoryTrackedFuture { + inner: F, + stats: Arc, +} + +impl MemoryTrackedFuture { + /// Create a new memory tracked future with explicit allocation stats. + fn new(inner: F, stats: Arc) -> Self { + Self { inner, stats } + } +} + +impl Future for MemoryTrackedFuture { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // SAFETY: We're using get_unchecked_mut which is safe because: + // 1. We never move the inner future - we only create a new Pin to it + // 2. The MemoryTrackedFuture struct itself doesn't have any Drop glue that + // would be affected by mutation + let this = unsafe { self.get_unchecked_mut() }; + + // SAFETY: The inner future hasn't moved - it's still in the same memory location. + // We're just creating a Pin pointing to it. + let inner = unsafe { Pin::new_unchecked(&mut this.inner) }; + + // Set the current task's stats for this thread (as NonNull pointer). + // SAFETY: Arc::as_ptr gives us a non-null pointer without affecting the reference count. + // The pointer remains valid because `this.stats` is alive for the duration of poll(). + let stats_ptr = unsafe { NonNull::new_unchecked(Arc::as_ptr(&this.stats) as *mut _) }; + let previous = CURRENT_TASK_STATS.with(|cell| cell.replace(Some(stats_ptr))); + + // Poll the inner future + let result = inner.poll(cx); + + // Restore the previous stats (for nested tracking contexts) + CURRENT_TASK_STATS.with(|cell| cell.set(previous)); + + result + } +} + +/// Get the current task's allocation stats, if available. +/// Returns None if called outside a memory-tracked context. +/// +/// This function clones the Arc (by reconstructing it from the raw pointer and incrementing +/// the reference count), so the caller owns a new reference to the stats. +#[must_use] +pub(crate) fn current() -> Option> { + CURRENT_TASK_STATS.with(|cell| { + cell.get().map(|ptr| { + // SAFETY: The pointer is valid because it was set by MemoryTrackedFuture::poll() + // or with_memory_tracking(), which are both on the call stack. We manually + // increment the reference count and reconstruct an Arc to return a new owned reference. + unsafe { + Arc::increment_strong_count(ptr.as_ptr()); + Arc::from_raw(ptr.as_ptr()) + } + }) + }) +} + +/// Run a synchronous closure with memory tracking. +/// If a parent context exists, creates a child context that tracks to the parent. +/// If no parent exists, creates a new root context with the given name. +/// This is useful for tracking allocations in synchronous code or threads. +pub(crate) fn with_memory_tracking( + name: &'static str, + limit: Option, + f: F, +) -> R +where + F: FnOnce() -> R, +{ + // Check if there's a parent context, and create either a child or root stats + let stats = CURRENT_TASK_STATS.with(|cell| { + if let Some(ptr) = cell.get() { + // Parent context exists - create a child that tracks to the parent + // SAFETY: The pointer is valid because it's managed by a parent MemoryTrackedFuture. + // We clone the Arc by manually incrementing the reference count. + let parent = unsafe { + Arc::increment_strong_count(ptr.as_ptr()); + Arc::from_raw(ptr.as_ptr()) + }; + + Arc::new(AllocationStats::with_parent(name, parent, limit)) + } else { + // No parent context - create a new root + Arc::new(AllocationStats::new(name, limit)) + } + }); + + with_explicit_memory_tracking(stats, f) +} + +/// Run a synchronous closure with memory tracking using an explicit parent. +/// Creates a child context with the given name that tracks to the provided parent. +pub(crate) fn with_parented_memory_tracking( + name: &'static str, + parent: Arc, + f: F, + limit: Option, +) -> R +where + F: FnOnce() -> R, +{ + let stats = Arc::new(AllocationStats::with_parent(name, parent, limit)); + with_explicit_memory_tracking(stats, f) +} + +/// Internal function to run a closure with explicit allocation stats. +/// Sets the thread-local stats, runs the closure, and restores the previous stats. +fn with_explicit_memory_tracking(stats: Arc, f: F) -> R +where + F: FnOnce() -> R, +{ + // Set the current task's stats for this thread (as NonNull pointer) + // SAFETY: Arc::as_ptr never returns null + let stats_ptr = unsafe { NonNull::new_unchecked(Arc::as_ptr(&stats) as *mut _) }; + let previous = CURRENT_TASK_STATS.with(|cell| cell.replace(Some(stats_ptr))); + + // Run the closure + let result = f(); + + // Restore the previous stats + CURRENT_TASK_STATS.with(|cell| cell.set(previous)); + + result +} + +/// Trait to add memory tracking to futures. +pub(crate) trait WithMemoryTracking: Future + Sized { + /// Wraps this future to track memory allocations with a named context. + /// If a parent context exists, creates a child context that tracks to the parent. + /// If no parent exists, creates a new root context with the given name. + fn with_memory_tracking( + self, + name: &'static str, + limit: Option, + ) -> MemoryTrackedFuture; +} + +impl WithMemoryTracking for F { + fn with_memory_tracking( + self, + name: &'static str, + limit: Option, + ) -> MemoryTrackedFuture { + // Check if there's a parent context, and create either a child or root stats + let stats = CURRENT_TASK_STATS.with(|cell| { + if let Some(ptr) = cell.get() { + // Parent context exists - create a child that tracks to the parent + // SAFETY: The pointer is valid because it's managed by a parent MemoryTrackedFuture. + // We clone the Arc by manually incrementing the reference count. + let parent = unsafe { + Arc::increment_strong_count(ptr.as_ptr()); + Arc::from_raw(ptr.as_ptr()) + }; + + Arc::new(AllocationStats::with_parent(name, parent, limit)) + } else { + // No parent context - create a new root + Arc::new(AllocationStats::new(name, limit)) + } + }); + + MemoryTrackedFuture { inner: self, stats } + } +} + +/// Custom allocator wrapper that delegates to tikv-jemalloc. +/// This allows for custom allocation tracking and instrumentation +/// while still using jemalloc as the underlying allocator. +/// +/// The allocator hooks into allocation/deallocation to track memory usage +/// per-task via thread-local storage. This adds minimal overhead (~1-2ns per +/// allocation) compared to using jemalloc directly. +struct CustomAllocator { + inner: tikv_jemallocator::Jemalloc, +} + +impl CustomAllocator { + const fn new() -> Self { + Self { + inner: tikv_jemallocator::Jemalloc, + } + } +} + +unsafe extern "C" { + unsafe fn printf(format: *const libc::c_char, ...) -> libc::c_int; +} + +// SAFETY: All methods below properly delegate to jemalloc and only add tracking +// on top. The tracking uses thread-locals with raw pointers to avoid TLS destructor +// issues (see CURRENT_TASK_STATS documentation above). +unsafe impl GlobalAlloc for CustomAllocator { + #[inline] + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + unsafe { + let ptr = self.inner.alloc(layout); + if !ptr.is_null() { + // Track to the current task's stats if available. + // SAFETY: The pointer was set by MemoryTrackedFuture::poll() or + // with_memory_tracking(), and is guaranteed to be valid during the + // execution of the tracked future/closure. + CURRENT_TASK_STATS.with(|cell| { + if let Some(stats_ptr) = cell.get() { + stats_ptr.as_ref().track_alloc(layout.size()); + + let stats = &mut (*stats_ptr.as_ptr()); + let bytes_allocated = stats.bytes_allocated(); + if let Some(limit) = &mut stats.limit { + if bytes_allocated >= limit.max_bytes { + if let Some(sender) = limit.sender.take() { + let _ = sender.send(bytes_allocated).unwrap(); + + let format: *const libc::c_char = + "total allocated: %u, limit: %u \n\0".as_ptr().cast(); + let _ = printf(format, bytes_allocated, limit.max_bytes); + } + } + } + } + }); + } + ptr + } + } + + #[inline] + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + unsafe { + self.inner.dealloc(ptr, layout); + // Track to the current task's stats if available + CURRENT_TASK_STATS.with(|cell| { + if let Some(stats_ptr) = cell.get() { + stats_ptr.as_ref().track_dealloc(layout.size()); + } + }); + } + } + + #[inline] + unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { + unsafe { + let ptr = self.inner.alloc_zeroed(layout); + if !ptr.is_null() { + // Track to the current task's stats if available + CURRENT_TASK_STATS.with(|cell| { + if let Some(stats_ptr) = cell.get() { + stats_ptr.as_ref().track_zeroed(layout.size()); + } + }); + } + ptr + } + } + + #[inline] + unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { + unsafe { + let new_ptr = self.inner.realloc(ptr, layout, new_size); + if !new_ptr.is_null() { + // Track to the current task's stats if available + CURRENT_TASK_STATS.with(|cell| { + if let Some(stats_ptr) = cell.get() { + stats_ptr.as_ref().track_realloc(new_size); + } + }); + } + new_ptr + } + } +} + #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))] #[global_allocator] -static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; +static ALLOC: CustomAllocator = CustomAllocator::new(); // Note: the dhat-heap and dhat-ad-hoc features should not be both enabled. We name our functions // and variables identically to prevent this from happening. @@ -65,9 +549,10 @@ static malloc_conf: Option<&'static libc::c_char> = Some(unsafe { #[cfg(test)] mod tests { - use std::ffi::CStr; - use super::*; + use std::ffi::CStr; + use std::thread; + use tokio::task; #[test] fn test_malloc_conf_is_valid_c_string() { @@ -93,4 +578,189 @@ mod tests { panic!("malloc_conf should not be None"); } } + + #[tokio::test] + async fn test_async_memory_tracking() { + // Test that allocations within a memory-tracked async context are tracked + let result = async { + let _v = Vec::::with_capacity(10000); + current().expect("stats should be set") + } + .with_memory_tracking("test", None) + .await; + + // Verify context name + assert_eq!(result.name(), "test"); + + // The allocator may allocate more than requested due to alignment, overhead, etc. + // We check that at least the requested amount was allocated. + assert!( + result.bytes_allocated() >= 10000, + "should track at least 10000 bytes, got {}", + result.bytes_allocated() + ); + + // Net allocated should be 0 or close to 0 since the Vec was dropped + assert!( + result.net_allocated() < 100, + "net allocated should be near 0 after Vec is dropped, got {}", + result.net_allocated() + ); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_spawned_task_memory_tracking() { + // Test that memory tracking creates child contexts that propagate to parent + async { + let parent_stats = current().expect("stats should be set in parent"); + assert_eq!(parent_stats.name(), "parent"); + + // Wrap the future BEFORE spawning to create a child context + let child_future = async { + let child_stats = current().expect("stats should be set in child"); + assert_eq!(child_stats.name(), "child"); + let _v = Vec::::with_capacity(5000); + } + .with_memory_tracking("child", None); + + task::spawn(child_future).await.unwrap(); + + let final_stats = current().expect("stats should still be set"); + + // The child's allocations should be tracked to the parent as well + assert!( + final_stats.bytes_allocated() >= 5000, + "child task allocations should be tracked in parent, got {}", + final_stats.bytes_allocated() + ); + + assert!( + Arc::ptr_eq(&parent_stats, &final_stats), + "should be the same Arc" + ); + } + .with_memory_tracking("parent", None) + .await; + } + + #[test] + fn test_sync_memory_tracking() { + // Test that synchronous code can use with_memory_tracking for thread propagation + let stats = with_memory_tracking("sync_test", None, || { + let stats = current().expect("stats should be set"); + assert_eq!(stats.name(), "sync_test"); + + { + let _v = Vec::::with_capacity(8000); + + // The allocator may allocate more than requested due to alignment, overhead, etc. + assert!( + stats.bytes_allocated() >= 8000, + "should track at least 8000 bytes, got {}", + stats.bytes_allocated() + ); + } + + // Net should be near 0 after first Vec is dropped + assert!( + stats.net_allocated() < 100, + "net allocated should be near 0 after Vec is dropped, got {}", + stats.net_allocated() + ); + + let first_allocated = stats.bytes_allocated(); + + // Test propagation to child thread with parented context + let parent_stats = stats.clone(); + let handle = thread::spawn(move || { + with_parented_memory_tracking( + "sync_test_child", + parent_stats, + || { + let child_stats = current().expect("child stats should be set"); + assert_eq!(child_stats.name(), "sync_test_child"); + let _v = Vec::::with_capacity(3000); + }, + None, + ) + }); + handle.join().unwrap(); + + // Should have tracked allocations from both contexts (parent propagation) + assert!( + stats.bytes_allocated() >= first_allocated + 3000, + "should track allocations from both contexts, got {} (expected at least {})", + stats.bytes_allocated(), + first_allocated + 3000 + ); + + stats + }); + + // Net should be near 0 after all Vecs are dropped + // Allow up to 200 bytes for internal allocations (Arc overhead, thread infrastructure, etc.) + assert!( + stats.net_allocated() < 200, + "net allocated should be near 0 after all Vecs are dropped, got {}", + stats.net_allocated() + ); + } + + #[tokio::test] + async fn test_nested_memory_tracking() { + // Test that nested contexts track allocations to all parent contexts + async { + let root_stats = current().expect("root stats should be set"); + assert_eq!(root_stats.name(), "root"); + let _root_vec = Vec::::with_capacity(1000); + + // Create a child context + async { + let child_stats = current().expect("child stats should be set"); + assert_eq!(child_stats.name(), "child"); + let _child_vec = Vec::::with_capacity(2000); + + // Child allocations should be in child stats + assert!( + child_stats.bytes_allocated() >= 2000, + "child should track its own allocations, got {}", + child_stats.bytes_allocated() + ); + + // Create a grandchild context + async { + let grandchild_stats = current().expect("grandchild stats should be set"); + assert_eq!(grandchild_stats.name(), "grandchild"); + let _grandchild_vec = Vec::::with_capacity(3000); + + // Grandchild allocations should be in grandchild stats + assert!( + grandchild_stats.bytes_allocated() >= 3000, + "grandchild should track its own allocations, got {}", + grandchild_stats.bytes_allocated() + ); + } + .with_memory_tracking("grandchild", None) + .await; + + // After grandchild completes, child should have tracked grandchild's allocations + assert!( + child_stats.bytes_allocated() >= 5000, + "child should track child + grandchild allocations, got {}", + child_stats.bytes_allocated() + ); + } + .with_memory_tracking("child", None) + .await; + + // After child completes, root should have tracked all allocations + assert!( + root_stats.bytes_allocated() >= 6000, + "root should track root + child + grandchild allocations, got {}", + root_stats.bytes_allocated() + ); + } + .with_memory_tracking("root", None) + .await; + } } diff --git a/apollo-router/src/compute_job/mod.rs b/apollo-router/src/compute_job/mod.rs index 8b9d4c3a86..105878aac5 100644 --- a/apollo-router/src/compute_job/mod.rs +++ b/apollo-router/src/compute_job/mod.rs @@ -1,10 +1,12 @@ mod metrics; +use std::any::Any; use std::future::Future; use std::ops::ControlFlow; use std::sync::OnceLock; use std::time::Instant; +use apollo_federation::error::FederationError; use opentelemetry::metrics::MeterProvider as _; use opentelemetry::metrics::ObservableGauge; use tokio::sync::oneshot; @@ -21,6 +23,8 @@ use self::metrics::observe_queue_wait_duration; use crate::ageing_priority_queue::AgeingPriorityQueue; use crate::ageing_priority_queue::Priority; use crate::ageing_priority_queue::SendError; +use crate::allocator::AllocationLimit; +use crate::allocator::current; use crate::metrics::meter_provider; use crate::plugins::telemetry::consts::COMPUTE_JOB_EXECUTION_SPAN_NAME; use crate::plugins::telemetry::consts::COMPUTE_JOB_SPAN_NAME; @@ -65,6 +69,7 @@ impl JobStatus<'_, T> { /// to avoid needless resource consumption. pub(crate) fn check_for_cooperative_cancellation(&self) -> ControlFlow<()> { if self.result_sender.is_closed() { + println!("result sender is closed"); ControlFlow::Break(()) } else { ControlFlow::Continue(()) @@ -110,6 +115,25 @@ impl crate::graphql::IntoGraphQLErrors for ComputeBackPressureError { } } +/// Job was cancelled due to cooperative cancellation +#[derive(thiserror::Error, Debug, displaydoc::Display, Clone)] +pub(crate) struct ComputeCooperativeCancellationError; + +impl ComputeCooperativeCancellationError { + pub(crate) fn to_graphql_error(&self) -> crate::graphql::Error { + crate::graphql::Error::builder() + .message("Your request has been cancelled due to cooperative cancellation") + .extension_code("REQUEST_COOPERATIVE_CANCELLATION") + .build() + } +} + +impl crate::graphql::IntoGraphQLErrors for ComputeCooperativeCancellationError { + fn into_graphql_errors(self) -> Result, Self> { + Ok(vec![self.to_graphql_error()]) + } +} + #[derive(Copy, Clone, Hash, Eq, PartialEq, Debug, strum_macros::IntoStaticStr)] #[strum(serialize_all = "snake_case")] pub(crate) enum ComputeJobType { @@ -145,6 +169,8 @@ pub(crate) struct Job { ty: ComputeJobType, queue_start: Instant, job_fn: Box, + allocation_stats: Option>, + cancel_tx: Option>, } pub(crate) fn queue() -> &'static AgeingPriorityQueue { @@ -182,7 +208,35 @@ pub(crate) fn queue() -> &'static AgeingPriorityQueue { job.type = job.ty ); let job_start = Instant::now(); - (job.job_fn)(); + + // Execute job with memory tracking if stats are available + if let Some(stats) = job.allocation_stats { + let max_bytes = + std::env::var("APOLLO_ROUTER_QUERY_PLANNER_MEMORY_LIMIT") + .ok() + .and_then(|s| s.parse::().ok()); + + // Create a child context with the job type as the name + let job_name: &'static str = job.ty.into(); + + crate::allocator::with_parented_memory_tracking( + job_name, + stats, + || { + (job.job_fn)(); + if let Some(allocation_stats) = current() { + record_metrics(&allocation_stats); + } + }, + Option::zip(max_bytes, job.cancel_tx).map( + |(max_bytes, sender)| { + AllocationLimit::new(max_bytes, sender) + }, + ), + ); + } else { + (job.job_fn)(); + } observe_compute_duration(job.ty, job_start.elapsed()); }) }) @@ -198,6 +252,54 @@ pub(crate) fn queue() -> &'static AgeingPriorityQueue { }) } +fn record_metrics(stats: &crate::allocator::AllocationStats) { + let bytes_allocated = stats.bytes_allocated() as u64; + let bytes_deallocated = stats.bytes_deallocated() as u64; + let bytes_zeroed = stats.bytes_zeroed() as u64; + let bytes_reallocated = stats.bytes_reallocated() as u64; + let context_name = stats.name(); + + // Record total bytes allocated + u64_histogram_with_unit!( + "apollo.router.query_planner.memory", + "Memory allocated during query planning", + "By", + bytes_allocated, + allocation.type = "allocated", + context = context_name + ); + + // Record bytes deallocated + u64_histogram_with_unit!( + "apollo.router.query_planner.memory", + "Memory allocated during query planning", + "By", + bytes_deallocated, + allocation.type = "deallocated", + context = context_name + ); + + // Record bytes zeroed + u64_histogram_with_unit!( + "apollo.router.query_planner.memory", + "Memory allocated during query planning", + "By", + bytes_zeroed, + allocation.type = "zeroed", + context = context_name + ); + + // Record bytes reallocated + u64_histogram_with_unit!( + "apollo.router.query_planner.memory", + "Memory allocated during query planning", + "By", + bytes_reallocated, + allocation.type = "reallocated", + context = context_name + ); +} + /// Returns a future that resolves to a `Result` that is `Ok` if `f` returned or `Err` if it panicked. pub(crate) fn execute( compute_job_type: ComputeJobType, @@ -215,7 +317,17 @@ where ); span.in_scope(|| { let mut job_watcher = JobWatcher::new(compute_job_type); - let (tx, rx) = oneshot::channel(); + let (tx, mut rx) = oneshot::channel(); + + let is_cancellable = crate::allocator::current().is_some(); + + let (cancel_tx, cancel_rx) = if is_cancellable { + let (sender, receiver) = tokio::sync::oneshot::channel(); + (Some(sender), Some(receiver)) + } else { + (None, None) + }; + let wrapped_job_fn = Box::new(move || { let status = JobStatus { result_sender: &tx }; // `AssertUnwindSafe` here is correct because this `catch_unwind` @@ -238,6 +350,8 @@ where ty: compute_job_type, job_fn: wrapped_job_fn, queue_start: Instant::now(), + allocation_stats: crate::allocator::current(), + cancel_tx, }; queue @@ -261,7 +375,23 @@ where })?; Ok(async move { - let result = rx.await; + let result: Result>, oneshot::error::RecvError> = + if let Some(mut cancel_rx) = cancel_rx { + println!("waiting for cancel or result"); + tokio::select! { + biased; + cancellation = &mut cancel_rx => { + if let Ok(bytes_requested) = cancellation { + // TODO(memory-tracking): How can we get the operation name here? + tracing::error!("job {compute_job_type_str} cancelled as it exceeded memory limit (requested {bytes_requested} bytes)"); + } + Ok(Err(Box::new(ComputeCooperativeCancellationError))) + } + result = &mut rx => result + } + } else { + rx.await + }; // This local variable MUST exist. Otherwise, only the field from the JobWatcher struct is moved and drop will occur before the outcome is set. // This is predicated on all the fields in the struct being Copy!!! diff --git a/apollo-router/src/plugins/telemetry/metrics/allocation/mod.rs b/apollo-router/src/plugins/telemetry/metrics/allocation/mod.rs new file mode 100644 index 0000000000..3779f505d3 --- /dev/null +++ b/apollo-router/src/plugins/telemetry/metrics/allocation/mod.rs @@ -0,0 +1,197 @@ +//! Memory allocation tracking metrics for router requests. +//! +//! This module provides a Tower layer that wraps router requests with memory tracking, +//! measuring bytes allocated, deallocated, zeroed, and reallocated during request processing. + +use std::task::Context; +use std::task::Poll; + +use opentelemetry_sdk::metrics::Aggregation; +use opentelemetry_sdk::metrics::Instrument; +use opentelemetry_sdk::metrics::Stream; +use tower::Layer; +use tower::Service; + +use crate::allocator::AllocationStats; +use crate::allocator::WithMemoryTracking; +use crate::metrics::aggregation::MeterProviderType; +use crate::plugins::telemetry::reload::metrics::MetricsBuilder; +use crate::services::router; + +/// Memory allocation histogram buckets: 1KB, 10KB, 100KB, 1MB, 10MB, 100MB +const MEMORY_BUCKETS: &[f64] = &[ + 1_000.0, // 1KB + 10_000.0, // 10KB + 100_000.0, // 100KB + 1_000_000.0, // 1MB + 10_000_000.0, // 10MB + 100_000_000.0, // 100MB +]; + +/// Register memory allocation metric views with custom bucket boundaries. +pub(crate) fn register_memory_allocation_views(builder: &mut MetricsBuilder) { + // Create aggregation with memory-specific buckets + let aggregation = Aggregation::ExplicitBucketHistogram { + boundaries: MEMORY_BUCKETS.to_vec(), + record_min_max: true, + }; + + // Register view for router request memory metric + let request_view = opentelemetry_sdk::metrics::new_view( + Instrument::new().name("apollo.router.request.memory"), + Stream::new().aggregation(aggregation.clone()), + ) + .unwrap(); + builder.with_view(MeterProviderType::Public, Box::new(request_view)); + + // Register view for query planner memory metric + let query_planner_view = opentelemetry_sdk::metrics::new_view( + Instrument::new().name("apollo.router.query_planner.memory"), + Stream::new().aggregation(aggregation), + ) + .unwrap(); + builder.with_view(MeterProviderType::Public, Box::new(query_planner_view)); +} + +/// Tower layer that adds memory allocation tracking to router requests. +#[derive(Clone)] +pub(crate) struct AllocationMetricsLayer; + +impl AllocationMetricsLayer { + /// Create a new allocation metrics layer. + pub(crate) fn new() -> Self { + Self + } +} + +impl Layer for AllocationMetricsLayer { + type Service = AllocationMetricsService; + + fn layer(&self, inner: S) -> Self::Service { + AllocationMetricsService { inner } + } +} + +/// Tower service that tracks memory allocations for each router request. +#[derive(Clone)] +pub(crate) struct AllocationMetricsService { + inner: S, +} + +impl Service for AllocationMetricsService +where + S: Service + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: router::Request) -> Self::Future { + let fut = self.inner.call(req); + + Box::pin( + async move { + let result = fut.await; + + // Record allocation metrics if stats are available + if let Some(stats) = crate::allocator::current() { + record_metrics(&stats); + } + + result + } + .with_memory_tracking("router.request", None), + ) + } +} + +/// Record allocation metrics for a specific context. +fn record_metrics(stats: &AllocationStats) { + let bytes_allocated = stats.bytes_allocated() as u64; + let bytes_deallocated = stats.bytes_deallocated() as u64; + let bytes_zeroed = stats.bytes_zeroed() as u64; + let bytes_reallocated = stats.bytes_reallocated() as u64; + let context_name = stats.name(); + + // Record total bytes allocated + u64_histogram_with_unit!( + "apollo.router.request.memory", + "Memory allocated during request processing", + "By", + bytes_allocated, + allocation.type = "allocated", + context = context_name + ); + + // Record bytes deallocated + u64_histogram_with_unit!( + "apollo.router.request.memory", + "Memory allocated during request processing", + "By", + bytes_deallocated, + allocation.type = "deallocated", + context = context_name + ); + + // Record bytes zeroed + u64_histogram_with_unit!( + "apollo.router.request.memory", + "Memory allocated during request processing", + "By", + bytes_zeroed, + allocation.type = "zeroed", + context = context_name + ); + + // Record bytes reallocated + u64_histogram_with_unit!( + "apollo.router.request.memory", + "Memory allocated during request processing", + "By", + bytes_reallocated, + allocation.type = "reallocated", + context = context_name + ); +} + +#[cfg(test)] +mod tests { + use tower::ServiceExt; + + use super::*; + use crate::metrics::FutureMetricsExt; + use crate::services::router; + + #[tokio::test] + async fn test_allocation_metrics_layer() { + async { + // Create a simple service that allocates memory + let service = tower::service_fn(|_req: router::Request| async { + // Allocate some memory during request processing + let _v = Vec::::with_capacity(10000); + Ok::<_, tower::BoxError>(router::Response::fake_builder().build().unwrap()) + }); + + // Wrap with allocation metrics layer + let layer = AllocationMetricsLayer::new(); + let mut service = layer.layer(service); + + // Make a request + let request = router::Request::fake_builder().build().unwrap(); + let _response = service.ready().await.unwrap().call(request).await.unwrap(); + + // Verify metrics were recorded + // Note: We can't easily assert on histogram values, but the test verifies + // the layer compiles and runs without errors + } + .with_metrics() + .await; + } +} diff --git a/apollo-router/src/plugins/telemetry/metrics/mod.rs b/apollo-router/src/plugins/telemetry/metrics/mod.rs index b8d0fac56d..0c5d5fad16 100644 --- a/apollo-router/src/plugins/telemetry/metrics/mod.rs +++ b/apollo-router/src/plugins/telemetry/metrics/mod.rs @@ -1,6 +1,7 @@ use opentelemetry_sdk::metrics::Aggregation; use opentelemetry_sdk::metrics::InstrumentKind; use opentelemetry_sdk::metrics::reader::AggregationSelector; +pub(crate) mod allocation; pub(crate) mod apollo; pub(crate) mod local_type_stats; pub(crate) mod otlp; diff --git a/apollo-router/src/plugins/telemetry/mod.rs b/apollo-router/src/plugins/telemetry/mod.rs index 14c3924550..e8235a0dde 100644 --- a/apollo-router/src/plugins/telemetry/mod.rs +++ b/apollo-router/src/plugins/telemetry/mod.rs @@ -384,6 +384,7 @@ impl PluginPrivate for Telemetry { .clone(); ServiceBuilder::new() + .layer(metrics::allocation::AllocationMetricsLayer::new()) .map_response(move |response: router::Response| { // The current span *should* be the request span as we are outside the instrument block. let span = Span::current(); diff --git a/apollo-router/src/plugins/telemetry/reload/builder.rs b/apollo-router/src/plugins/telemetry/reload/builder.rs index 3c962fb977..5cf185720f 100644 --- a/apollo-router/src/plugins/telemetry/reload/builder.rs +++ b/apollo-router/src/plugins/telemetry/reload/builder.rs @@ -97,6 +97,8 @@ impl<'a> Builder<'a> { let mut builder = MetricsBuilder::new(self.config); builder.configure(&self.config.exporters.metrics.prometheus)?; builder.configure(&self.config.exporters.metrics.otlp)?; + // Register memory allocation views with custom buckets + crate::plugins::telemetry::metrics::allocation::register_memory_allocation_views(&mut builder); builder.configure_views(MeterProviderType::Public)?; let (prometheus_registry, meter_providers, _) = builder.build(); diff --git a/apollo-router/src/query_planner/query_planner_service.rs b/apollo-router/src/query_planner/query_planner_service.rs index 57729b7269..9116afa7ef 100644 --- a/apollo-router/src/query_planner/query_planner_service.rs +++ b/apollo-router/src/query_planner/query_planner_service.rs @@ -24,6 +24,7 @@ use tower::Service; use super::PlanNode; use super::QueryKey; use crate::Configuration; +use crate::allocator::WithMemoryTracking; use crate::apollo_studio_interop::generate_usage_reporting; use crate::compute_job; use crate::compute_job::ComputeJobType; @@ -470,7 +471,7 @@ impl Service for QueryPlannerService { }; // Return the response as an immediate future - Box::pin(fut) + Box::pin(fut.with_memory_tracking("query_planner.request", None)) } }