diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 0824235718..8c415ec465 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -504,8 +504,8 @@ object CometConf extends ShimCometConf { .doc( "The type of memory pool to be used for Comet native execution. " + "Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', " + - "'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, " + - "this config is 'greedy_task_shared'.") + "'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global'. For off-heap " + + "types are 'unified' and `fair_unified`.") .stringConf .createWithDefault("greedy_task_shared") diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 8245e7b76b..2ddff3b98e 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -48,7 +48,7 @@ Comet provides the following configuration settings. | spark.comet.exec.hashJoin.enabled | Whether to enable hashJoin by default. | true | | spark.comet.exec.initCap.enabled | Whether to enable initCap by default. | false | | spark.comet.exec.localLimit.enabled | Whether to enable localLimit by default. | true | -| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, this config is 'greedy_task_shared'. | greedy_task_shared | +| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global'. For off-heap types are 'unified' and `fair_unified`. 
| greedy_task_shared | | spark.comet.exec.project.enabled | Whether to enable project by default. | true | | spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the Comet Tuning Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false | | spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. lz4, zstd, and snappy are supported. Compression can be disabled by setting spark.shuffle.compress=false. | lz4 | diff --git a/native/Cargo.lock b/native/Cargo.lock index ee734a0760..eae27c36d3 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -910,6 +910,7 @@ dependencies = [ "num", "object_store", "once_cell", + "parking_lot", "parquet", "paste", "pprof", diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index d3f17a7056..6ab5dcd4c5 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -77,6 +77,7 @@ datafusion-comet-proto = { workspace = true } object_store = { workspace = true } url = { workspace = true } chrono = { workspace = true } +parking_lot = "0.12.3" [dev-dependencies] pprof = { version = "0.14.0", features = ["flamegraph"] } diff --git a/native/core/src/execution/fair_memory_pool.rs b/native/core/src/execution/fair_memory_pool.rs new file mode 100644 index 0000000000..9b50f8bd93 --- /dev/null +++ b/native/core/src/execution/fair_memory_pool.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + fmt::{Debug, Formatter, Result as FmtResult}, + sync::Arc, +}; + +use jni::objects::GlobalRef; + +use crate::{ + errors::CometResult, + jvm_bridge::{jni_call, JVMClasses}, +}; +use datafusion::{ + common::DataFusionError, + execution::memory_pool::{MemoryPool, MemoryReservation}, +}; +use datafusion_common::resources_err; +use datafusion_execution::memory_pool::MemoryConsumer; +use parking_lot::Mutex; + +/// A DataFusion fair `MemoryPool` implementation for Comet. Internally this is +/// implemented via delegating calls to [`crate::jvm_bridge::CometTaskMemoryManager`]. 
+pub struct CometFairMemoryPool {
+    /// JNI global reference to the JVM task memory manager object that `acquire` and
+    /// `release` call into.
+    task_memory_manager_handle: Arc<GlobalRef>,
+    /// Total size of the pool in bytes; a single reservation may hold at most
+    /// `pool_size / num` bytes, where `num` is the number of registered consumers.
+    pool_size: usize,
+    /// Mutable bookkeeping, guarded by a mutex.
+    state: Mutex<CometFairPoolState>,
+}
+
+/// Bookkeeping shared by all consumers of a [`CometFairMemoryPool`].
+struct CometFairPoolState {
+    /// Bytes currently reserved through this pool.
+    used: usize,
+    /// Number of currently registered consumers.
+    num: usize,
+}
+
+impl Debug for CometFairMemoryPool {
+    /// Prints the pool size together with a snapshot of the current bookkeeping.
+    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
+        let state = self.state.lock();
+        f.debug_struct("CometFairMemoryPool")
+            .field("pool_size", &self.pool_size)
+            .field("used", &state.used)
+            .field("num", &state.num)
+            .finish()
+    }
+}
+
+impl CometFairMemoryPool {
+    /// Creates an empty pool of `pool_size` bytes backed by the given task memory manager.
+    pub fn new(
+        task_memory_manager_handle: Arc<GlobalRef>,
+        pool_size: usize,
+    ) -> CometFairMemoryPool {
+        Self {
+            task_memory_manager_handle,
+            pool_size,
+            state: Mutex::new(CometFairPoolState { used: 0, num: 0 }),
+        }
+    }
+
+    /// Asks the JVM task memory manager for `additional` bytes and returns the number of
+    /// bytes it reports as acquired, which may be less than requested.
+    fn acquire(&self, additional: usize) -> CometResult<i64> {
+        let mut env = JVMClasses::get_env()?;
+        let handle = self.task_memory_manager_handle.as_obj();
+        // SAFETY: NOTE(review) - presumably sound because `handle` comes from a live
+        // `GlobalRef` to the task memory manager; confirm against the `jni_call!` contract.
+        unsafe {
+            jni_call!(&mut env,
+                comet_task_memory_manager(handle).acquire_memory(additional as i64) -> i64)
+        }
+    }
+
+    /// Returns `size` bytes to the JVM task memory manager.
+    fn release(&self, size: usize) -> CometResult<()> {
+        let mut env = JVMClasses::get_env()?;
+        let handle = self.task_memory_manager_handle.as_obj();
+        // SAFETY: NOTE(review) - same assumption as in `acquire`; confirm.
+        unsafe {
+            jni_call!(&mut env, comet_task_memory_manager(handle).release_memory(size as i64) -> ())
+        }
+    }
+}
+
+// SAFETY: NOTE(review) - the fields are an `Arc<GlobalRef>`, a `usize` and a `Mutex`;
+// these impls assume the JNI global reference may be used from any thread. Confirm, and
+// drop them if the compiler already derives `Send`/`Sync` for this struct.
+unsafe impl Send for CometFairMemoryPool {}
+unsafe impl Sync for CometFairMemoryPool {}
+
+impl MemoryPool for CometFairMemoryPool {
+    /// Counts a new consumer, which lowers the fair per-reservation limit.
+    fn register(&self, _: &MemoryConsumer) {
+        let mut state = self.state.lock();
+        state.num = state
+            .num
+            .checked_add(1)
+            .expect("unexpected amount of register happened");
+    }
+
+    /// Removes a consumer from the count, which raises the fair per-reservation limit.
+    fn unregister(&self, _: &MemoryConsumer) {
+        let mut state = self.state.lock();
+        state.num = state
+            .num
+            .checked_sub(1)
+            .expect("unexpected amount of unregister happened");
+    }
+
+    /// Infallible variant of [`Self::try_grow`].
+    ///
+    /// # Panics
+    /// Panics if `try_grow` returns an error.
+    fn grow(&self, reservation: &MemoryReservation, additional: usize) {
+        self.try_grow(reservation, additional).unwrap();
+    }
+
+    /// Releases `subtractive` bytes of `reservation` back to the JVM.
+    ///
+    /// # Panics
+    /// Panics if the reservation holds fewer than `subtractive` bytes, if the JVM call
+    /// fails, or if the pool's own accounting would underflow.
+    fn shrink(&self, reservation: &MemoryReservation, subtractive: usize) {
+        if subtractive == 0 {
+            return;
+        }
+        let mut state = self.state.lock();
+        let size = reservation.size();
+        if size < subtractive {
+            panic!("Failed to release {subtractive} bytes where only {size} bytes reserved")
+        }
+        // Keep the underlying error in the panic message so the failure can be diagnosed.
+        self.release(subtractive)
+            .unwrap_or_else(|e| panic!("Failed to release {subtractive} bytes: {e:?}"));
+        state.used = state
+            .used
+            .checked_sub(subtractive)
+            .expect("released more bytes than the pool has reserved");
+    }
+
+    /// Tries to reserve `additional` more bytes for `reservation`.
+    ///
+    /// # Errors
+    /// Returns a resources-exhausted error when no consumer is registered, when the
+    /// reservation would exceed its fair share (`pool_size / num`), or when the JVM grants
+    /// fewer bytes than requested. Errors from the JNI calls are propagated.
+    fn try_grow(
+        &self,
+        reservation: &MemoryReservation,
+        additional: usize,
+    ) -> Result<(), DataFusionError> {
+        if additional == 0 {
+            return Ok(());
+        }
+        let mut state = self.state.lock();
+        let num = state.num;
+        let size = reservation.size();
+        // `checked_div` yields `None` only when no consumer is registered; report that as
+        // an error instead of panicking.
+        let limit = match self.pool_size.checked_div(num) {
+            Some(limit) => limit,
+            None => {
+                return resources_err!(
+                    "Failed to acquire {additional} bytes because no consumer is registered"
+                );
+            }
+        };
+        // A request whose total overflows `usize` can never fit under the limit.
+        let fits = matches!(size.checked_add(additional), Some(total) if total <= limit);
+        if !fits {
+            return resources_err!(
+                "Failed to acquire {additional} bytes where {size} bytes already reserved and the fair limit is {limit} bytes, {num} registered"
+            );
+        }
+
+        let acquired = self.acquire(additional)?;
+        // If the number of bytes we acquired is less than the requested, return an error,
+        // and hopefully will trigger spilling from the caller side.
+        if acquired < additional as i64 {
+            // Release the acquired bytes before returning the error. Skip the JNI call when
+            // nothing was granted; this also keeps a non-positive value out of the cast.
+            if acquired > 0 {
+                self.release(acquired as usize)?;
+            }
+
+            return resources_err!(
+                "Failed to acquire {} bytes, only got {} bytes. Reserved: {} bytes",
+                additional,
+                acquired,
+                state.used
+            );
+        }
+        state.used = state
+            .used
+            .checked_add(additional)
+            .expect("reserved byte count overflowed usize");
+        Ok(())
+    }
+
+    /// Returns the number of bytes currently reserved through this pool.
+    fn reserved(&self) -> usize {
+        self.state.lock().used
+    }
+}
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index fe29d8da14..25dadb1d9d 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -63,6 +63,7 @@ use std::num::NonZeroUsize; use std::sync::Mutex; use tokio::runtime::Runtime; +use crate::execution::fair_memory_pool::CometFairMemoryPool; use crate::execution::operators::ScanExec; use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec}; use crate::execution::spark_plan::SparkPlan; @@ -108,6 +109,7 @@ struct ExecutionContext { #[derive(PartialEq, Eq)] enum MemoryPoolType { Unified, + FairUnified, Greedy, FairSpill, GreedyTaskShared, @@ -285,11 +287,14 @@ fn parse_memory_pool_config( memory_limit: i64, memory_limit_per_task: i64, ) -> CometResult { + let pool_size = memory_limit as usize; let memory_pool_config = if use_unified_memory_manager { - MemoryPoolConfig::new(MemoryPoolType::Unified, 0) + match memory_pool_type.as_str() { + "fair_unified" => MemoryPoolConfig::new(MemoryPoolType::FairUnified, pool_size), + _ => MemoryPoolConfig::new(MemoryPoolType::Unified, 0), + } } else { // Use the memory pool from DF - let pool_size = memory_limit as usize; let pool_size_per_task = memory_limit_per_task as usize; match memory_pool_type.as_str() { "fair_spill_task_shared" => { @@ -327,6 +332,12 @@ fn create_memory_pool( let memory_pool = CometMemoryPool::new(comet_task_memory_manager); Arc::new(memory_pool) } + MemoryPoolType::FairUnified => { + // Set Comet fair memory pool for native + let memory_pool = + CometFairMemoryPool::new(comet_task_memory_manager, memory_pool_config.pool_size); + Arc::new(memory_pool) + } MemoryPoolType::Greedy => Arc::new(TrackConsumersPool::new( 
GreedyMemoryPool::new(memory_pool_config.pool_size), NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs index a74ec3017e..23b16f5f9a 100644 --- a/native/core/src/execution/mod.rs +++ b/native/core/src/execution/mod.rs @@ -29,6 +29,7 @@ pub(crate) mod util; pub use datafusion_comet_spark_expr::timezone; pub(crate) mod utils; +mod fair_memory_pool; mod memory_pool; pub use memory_pool::*; diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala index 0dadb22179..b9fc56b94b 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala @@ -94,6 +94,7 @@ class CometTPCHQuerySuite extends QueryTest with TPCBase with ShimCometTPCHQuery conf.set(CometConf.COMET_SHUFFLE_MODE.key, "jvm") conf.set(MEMORY_OFFHEAP_ENABLED.key, "true") conf.set(MEMORY_OFFHEAP_SIZE.key, "2g") + conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g") } protected override def createSparkSession: TestSparkSession = {