diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 47d87fe1af..b2db35d2e1 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -31,8 +31,8 @@ use futures::poll;
 use jni::{
     errors::Result as JNIResult,
     objects::{
-        JByteArray, JClass, JIntArray, JLongArray, JMap, JObject, JObjectArray, JPrimitiveArray,
-        JString, ReleaseMode,
+        JByteArray, JClass, JIntArray, JLongArray, JObject, JObjectArray, JPrimitiveArray, JString,
+        ReleaseMode,
     },
-    sys::{jbyteArray, jint, jlong, jlongArray},
+    sys::{jboolean, jbyteArray, jint, jlong, jlongArray},
     JNIEnv,
@@ -75,8 +75,6 @@ struct ExecutionContext {
     pub input_sources: Vec<Arc<GlobalRef>>,
     /// The record batch stream to pull results from
     pub stream: Option<SendableRecordBatchStream>,
-    /// Configurations for DF execution
-    pub conf: HashMap<String, String>,
     /// The Tokio runtime used for async.
     pub runtime: Runtime,
     /// Native metrics
@@ -99,11 +97,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     e: JNIEnv,
     _class: JClass,
     id: jlong,
-    config_object: JObject,
     iterators: jobjectArray,
     serialized_query: jbyteArray,
     metrics_node: JObject,
     comet_task_memory_manager_obj: JObject,
+    batch_size: jint,
+    debug_native: jboolean,
+    explain_native: jboolean,
+    worker_threads: jint,
+    blocking_threads: jint,
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| {
         // Init JVM classes
@@ -115,36 +117,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         // Deserialize query plan
         let spark_plan = serde::deserialize_op(bytes.as_slice())?;
 
-        // Sets up context
-        let mut configs = HashMap::new();
-
-        let config_map = JMap::from_env(&mut env, &config_object)?;
-        let mut map_iter = config_map.iter(&mut env)?;
-        while let Some((key, value)) = map_iter.next(&mut env)? {
-            let key: String = env.get_string(&JString::from(key)).unwrap().into();
-            let value: String = env.get_string(&JString::from(value)).unwrap().into();
-            configs.insert(key, value);
-        }
-
-        // Whether we've enabled additional debugging on the native side
-        let debug_native = parse_bool(&configs, "debug_native")?;
-        let explain_native = parse_bool(&configs, "explain_native")?;
-
-        let worker_threads = configs
-            .get("worker_threads")
-            .map(String::as_str)
-            .unwrap_or("4")
-            .parse::<usize>()?;
-        let blocking_threads = configs
-            .get("blocking_threads")
-            .map(String::as_str)
-            .unwrap_or("10")
-            .parse::<usize>()?;
-
         // Use multi-threaded tokio runtime to prevent blocking spawned tasks if any
         let runtime = tokio::runtime::Builder::new_multi_thread()
-            .worker_threads(worker_threads)
-            .max_blocking_threads(blocking_threads)
+            .worker_threads(worker_threads as usize)
+            .max_blocking_threads(blocking_threads as usize)
             .enable_all()
             .build()?;
 
@@ -165,7 +141,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         // We need to keep the session context alive. Some session state like temporary
         // dictionaries are stored in session context. If it is dropped, the temporary
         // dictionaries will be dropped as well.
-        let session = prepare_datafusion_session_context(&configs, task_memory_manager)?;
+        let session = prepare_datafusion_session_context(batch_size as usize, task_memory_manager)?;
 
         let exec_context = Box::new(ExecutionContext {
             id,
@@ -174,12 +150,11 @@
             scans: vec![],
             input_sources,
             stream: None,
-            conf: configs,
             runtime,
             metrics,
             session_ctx: Arc::new(session),
-            debug_native,
-            explain_native,
+            debug_native: debug_native == 1,
+            explain_native: explain_native == 1,
             metrics_jstrings: HashMap::new(),
         });
 
@@ -187,19 +162,11 @@
     })
 }
 
-/// Parse Comet configs and configure DataFusion session context.
+/// Configure DataFusion session context.
 fn prepare_datafusion_session_context(
-    conf: &HashMap<String, String>,
+    batch_size: usize,
     comet_task_memory_manager: Arc<GlobalRef>,
 ) -> CometResult<SessionContext> {
-    // Get the batch size from Comet JVM side
-    let batch_size = conf
-        .get("batch_size")
-        .ok_or(CometError::Internal(
-            "Config 'batch_size' is not specified from Comet JVM side".to_string(),
-        ))?
-        .parse::<usize>()?;
-
     let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs);
 
     // Set Comet memory pool for native
@@ -209,7 +176,7 @@ fn prepare_datafusion_session_context(
     // Get Datafusion configuration from Spark Execution context
     // can be configured in Comet Spark JVM using Spark --conf parameters
     // e.g: spark-shell --conf spark.datafusion.sql_parser.parse_float_as_decimal=true
-    let mut session_config = SessionConfig::new()
+    let session_config = SessionConfig::new()
         .with_batch_size(batch_size)
         // DataFusion partial aggregates can emit duplicate rows so we disable the
         // skip partial aggregation feature because this is not compatible with Spark's
@@ -222,11 +189,7 @@ fn prepare_datafusion_session_context(
         &ScalarValue::Float64(Some(1.1)),
     );
 
-    for (key, value) in conf.iter().filter(|(k, _)| k.starts_with("datafusion.")) {
-        session_config = session_config.set_str(key, value);
-    }
-
-    let runtime = RuntimeEnv::try_new(rt_config).unwrap();
+    let runtime = RuntimeEnv::try_new(rt_config)?;
 
     let mut session_ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime));
 
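A note on the jboolean handling above: in the `jni` crate's sys types, `jint` is `i32` and
`jboolean` is `u8` (with JNI_TRUE defined as 1), which is why the native side casts the thread
counts and compares `debug_native == 1` rather than receiving a `bool` directly. Below is a
minimal standalone sketch of that boundary pattern, assuming only a `tokio` dependency with the
"full" feature; the `Jint`/`Jboolean` aliases and the `build_runtime` helper are illustrative
stand-ins, not code from this patch.

    // Stand-ins for the `jni` crate's sys types: jint = i32, jboolean = u8.
    type Jint = i32;
    type Jboolean = u8;

    fn build_runtime(
        worker_threads: Jint,
        blocking_threads: Jint,
    ) -> std::io::Result<tokio::runtime::Runtime> {
        // Same builder calls as the patch: cast the JNI scalars once at the
        // boundary, then configure the multi-threaded runtime.
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(worker_threads as usize)
            .max_blocking_threads(blocking_threads as usize)
            .enable_all()
            .build()
    }

    fn main() -> std::io::Result<()> {
        // JNI booleans are unsigned bytes with JNI_TRUE == 1, hence the
        // `debug_native == 1` comparisons in the patch above.
        let debug_native: Jboolean = 0;
        let debug = debug_native == 1;

        let rt = build_runtime(4, 10)?;
        rt.block_on(async { assert!(!debug) });
        Ok(())
    }

Passing primitives this way also avoids the per-entry JMap iteration and get_string decoding that
the removed code performed on every plan creation.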
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index b1f22726a6..da2729ce57 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -54,43 +54,23 @@ class CometExecIterator(
       new CometBatchIterator(iterator, nativeUtil)
     }.toArray
   private val plan = {
-    val configs = createNativeConf
     nativeLib.createPlan(
       id,
-      configs,
       cometBatchIterators,
       protobufQueryPlan,
       nativeMetrics,
-      new CometTaskMemoryManager(id))
+      new CometTaskMemoryManager(id),
+      batchSize = COMET_BATCH_SIZE.get(),
+      debug = COMET_DEBUG_ENABLED.get(),
+      explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
+      workerThreads = COMET_WORKER_THREADS.get(),
+      blockingThreads = COMET_BLOCKING_THREADS.get())
   }
 
   private var nextBatch: Option[ColumnarBatch] = None
   private var currentBatch: ColumnarBatch = null
   private var closed: Boolean = false
 
-  /**
-   * Creates a new configuration map to be passed to the native side.
-   */
-  private def createNativeConf: java.util.HashMap[String, String] = {
-    val result = new java.util.HashMap[String, String]()
-    val conf = SparkEnv.get.conf
-
-    result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
-    result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
-    result.put("explain_native", String.valueOf(COMET_EXPLAIN_NATIVE_ENABLED.get()))
-    result.put("worker_threads", String.valueOf(COMET_WORKER_THREADS.get()))
-    result.put("blocking_threads", String.valueOf(COMET_BLOCKING_THREADS.get()))
-
-    // Strip mandatory prefix spark. which is not required for DataFusion session params
-    conf.getAll.foreach {
-      case (k, v) if k.startsWith("spark.datafusion") =>
-        result.put(k.replaceFirst("spark\\.", ""), v)
-      case _ =>
-    }
-
-    result
-  }
-
   def getNextBatch(): Option[ColumnarBatch] = {
     nativeUtil.getNextBatch(
       numOutputCols,
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index 03a9dea0c6..929a93ebea 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -19,8 +19,6 @@
 
 package org.apache.comet
 
-import java.util.Map
-
 import org.apache.spark.CometTaskMemoryManager
 import org.apache.spark.sql.comet.CometMetricNode
 
@@ -47,11 +45,15 @@ class Native extends NativeBase {
    */
   @native def createPlan(
       id: Long,
-      configMap: Map[String, String],
       iterators: Array[CometBatchIterator],
       plan: Array[Byte],
       metrics: CometMetricNode,
-      taskMemoryManager: CometTaskMemoryManager): Long
+      taskMemoryManager: CometTaskMemoryManager,
+      batchSize: Int,
+      debug: Boolean,
+      explain: Boolean,
+      workerThreads: Int,
+      blockingThreads: Int): Long
 
   /**
    * Execute a native query plan based on given input Arrow arrays.
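For completeness, the new createPlan signature relies on the default JVM-to-JNI primitive mapping
(Scala Int to jint, Boolean to jboolean, Long to jlong), so plan creation no longer decodes any
strings. Below is a small self-contained sketch of the typed argument bundle using hypothetical
names; the thread defaults mirror the fallbacks in the removed Rust parsing ("worker_threads" 4,
"blocking_threads" 10), while the batch size shown is only a placeholder, not Comet's configured
default.

    // Hypothetical stand-in for the typed arguments now passed to
    // Native.createPlan; none of these names exist in Comet.
    final case class PlanArgs(
        batchSize: Int,
        debug: Boolean,
        explain: Boolean,
        workerThreads: Int,
        blockingThreads: Int)

    object PlanArgsDemo {
      // 4 worker threads and 10 blocking threads match the defaults the
      // removed Rust parsing code fell back to; 8192 is a placeholder.
      val defaults: PlanArgs =
        PlanArgs(batchSize = 8192, debug = false, explain = false,
          workerThreads = 4, blockingThreads = 10)

      def main(args: Array[String]): Unit = {
        // Named arguments keep the call site readable, as in CometExecIterator.
        val a = defaults.copy(debug = true)
        println(s"createPlan would receive: $a")
      }
    }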