69 changes: 16 additions & 53 deletions native/core/src/execution/jni_api.rs
@@ -31,8 +31,8 @@ use futures::poll;
 use jni::{
     errors::Result as JNIResult,
     objects::{
-        JByteArray, JClass, JIntArray, JLongArray, JMap, JObject, JObjectArray, JPrimitiveArray,
-        JString, ReleaseMode,
+        JByteArray, JClass, JIntArray, JLongArray, JObject, JObjectArray, JPrimitiveArray, JString,
+        ReleaseMode,
     },
     sys::{jbyteArray, jint, jlong, jlongArray},
     JNIEnv,
Expand Down Expand Up @@ -75,8 +75,6 @@ struct ExecutionContext {
pub input_sources: Vec<Arc<GlobalRef>>,
/// The record batch stream to pull results from
pub stream: Option<SendableRecordBatchStream>,
/// Configurations for DF execution
pub conf: HashMap<String, String>,
/// The Tokio runtime used for async.
pub runtime: Runtime,
/// Native metrics
@@ -99,11 +97,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     e: JNIEnv,
     _class: JClass,
     id: jlong,
-    config_object: JObject,
     iterators: jobjectArray,
     serialized_query: jbyteArray,
     metrics_node: JObject,
     comet_task_memory_manager_obj: JObject,
+    batch_size: jint,
+    debug_native: jboolean,
+    explain_native: jboolean,
+    worker_threads: jint,
+    blocking_threads: jint,
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| {
         // Init JVM classes
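
For readers new to the JNI boundary here: the exported symbol name encodes the JVM class, so org.apache.comet.Native#createPlan must be backed by a Rust function named Java_org_apache_comet_Native_createPlan. A minimal sketch of that naming convention, using a hypothetical addOne method that is not part of Comet:

use jni::objects::JClass;
use jni::sys::jint;
use jni::JNIEnv;

// JNI name mangling: Java_<package with dots as underscores>_<class>_<method>.
// This hypothetical export would back `@native def addOne(x: Int): Int`
// declared on org.apache.comet.Native.
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_Native_addOne(
    _env: JNIEnv,
    _class: JClass,
    x: jint,
) -> jint {
    x + 1
}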
@@ -115,36 +117,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         // Deserialize query plan
         let spark_plan = serde::deserialize_op(bytes.as_slice())?;
 
-        // Sets up context
-        let mut configs = HashMap::new();
-
-        let config_map = JMap::from_env(&mut env, &config_object)?;
-        let mut map_iter = config_map.iter(&mut env)?;
-        while let Some((key, value)) = map_iter.next(&mut env)? {
-            let key: String = env.get_string(&JString::from(key)).unwrap().into();
-            let value: String = env.get_string(&JString::from(value)).unwrap().into();
-            configs.insert(key, value);
-        }
-
-        // Whether we've enabled additional debugging on the native side
-        let debug_native = parse_bool(&configs, "debug_native")?;
-        let explain_native = parse_bool(&configs, "explain_native")?;
-
-        let worker_threads = configs
-            .get("worker_threads")
-            .map(String::as_str)
-            .unwrap_or("4")
-            .parse::<usize>()?;
-        let blocking_threads = configs
-            .get("blocking_threads")
-            .map(String::as_str)
-            .unwrap_or("10")
-            .parse::<usize>()?;
-
         // Use multi-threaded tokio runtime to prevent blocking spawned tasks if any
         let runtime = tokio::runtime::Builder::new_multi_thread()
-            .worker_threads(worker_threads)
-            .max_blocking_threads(blocking_threads)
+            .worker_threads(worker_threads as usize)
+            .max_blocking_threads(blocking_threads as usize)
             .enable_all()
             .build()?;

[Inline review thread on the removed "// Sets up context" block:]

Contributor: Perhaps we can put this in an initConf that is called only once per session? That allows us the flexibility to pass more session configuration parameters if needed down the road? Or maybe that can be done later if the need arises.

Member Author: The problem is that this code is invoked once per task in the executors. I don't know if there is a concept of a session at this point?
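
A note on the thread above: if an initConf-style hook is ever wanted, one way to approximate "run once per executor process" from per-task calls is a process-wide OnceLock. This is only a sketch of the idea; NativeConf, NATIVE_CONF, and get_or_init_conf are hypothetical names, not anything in Comet:

use std::sync::OnceLock;

/// Hypothetical holder for settings that should be fixed per executor process.
struct NativeConf {
    worker_threads: usize,
    blocking_threads: usize,
}

static NATIVE_CONF: OnceLock<NativeConf> = OnceLock::new();

/// The first task to call this wins; later per-task calls reuse the stored values.
fn get_or_init_conf(worker_threads: usize, blocking_threads: usize) -> &'static NativeConf {
    NATIVE_CONF.get_or_init(|| NativeConf {
        worker_threads,
        blocking_threads,
    })
}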
@@ -165,7 +141,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         // We need to keep the session context alive. Some session state like temporary
         // dictionaries are stored in session context. If it is dropped, the temporary
         // dictionaries will be dropped as well.
-        let session = prepare_datafusion_session_context(&configs, task_memory_manager)?;
+        let session = prepare_datafusion_session_context(batch_size as usize, task_memory_manager)?;
 
         let exec_context = Box::new(ExecutionContext {
             id,
@@ -174,32 +150,23 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
             scans: vec![],
             input_sources,
             stream: None,
-            conf: configs,
             runtime,
             metrics,
             session_ctx: Arc::new(session),
-            debug_native,
-            explain_native,
+            debug_native: debug_native == 1,
+            explain_native: explain_native == 1,
             metrics_jstrings: HashMap::new(),
         });
 
         Ok(Box::into_raw(exec_context) as i64)
     })
 }
 
-/// Parse Comet configs and configure DataFusion session context.
+/// Configure DataFusion session context.
 fn prepare_datafusion_session_context(
-    conf: &HashMap<String, String>,
+    batch_size: usize,
     comet_task_memory_manager: Arc<GlobalRef>,
 ) -> CometResult<SessionContext> {
-    // Get the batch size from Comet JVM side
-    let batch_size = conf
-        .get("batch_size")
-        .ok_or(CometError::Internal(
-            "Config 'batch_size' is not specified from Comet JVM side".to_string(),
-        ))?
-        .parse::<usize>()?;
-
     let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs);
 
     // Set Comet memory pool for native
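
The debug_native == 1 and explain_native == 1 comparisons above rely on the standard JNI boolean mapping: a JVM Boolean crosses the boundary as a jboolean (a u8, with JNI_TRUE defined as 1). A minimal sketch of the idiom, using only types and constants from the jni crate:

use jni::sys::{jboolean, JNI_TRUE};

/// A JVM boolean arrives as a u8; JNI_TRUE (1) means true.
fn to_bool(flag: jboolean) -> bool {
    flag == JNI_TRUE
}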
@@ -209,7 +176,7 @@ fn prepare_datafusion_session_context(
     // Get Datafusion configuration from Spark Execution context
     // can be configured in Comet Spark JVM using Spark --conf parameters
     // e.g: spark-shell --conf spark.datafusion.sql_parser.parse_float_as_decimal=true
-    let mut session_config = SessionConfig::new()
+    let session_config = SessionConfig::new()
         .with_batch_size(batch_size)
         // DataFusion partial aggregates can emit duplicate rows so we disable the
         // skip partial aggregation feature because this is not compatible with Spark's
@@ -222,11 +189,7 @@ fn prepare_datafusion_session_context(
             &ScalarValue::Float64(Some(1.1)),
         );
 
-    for (key, value) in conf.iter().filter(|(k, _)| k.starts_with("datafusion.")) {
-        session_config = session_config.set_str(key, value);
-    }
-
-    let runtime = RuntimeEnv::try_new(rt_config).unwrap();
+    let runtime = RuntimeEnv::try_new(rt_config)?;
 
     let mut session_ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime));
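
For reference, the deleted loop relied on DataFusion's string-keyed option API. A minimal sketch of that mechanism, reusing the key from the spark-shell example above; this is illustrative only, not code kept by the PR:

use datafusion::prelude::SessionConfig;

/// Build a session config with a batch size and one string-keyed option.
fn example_session_config(batch_size: usize) -> SessionConfig {
    SessionConfig::new()
        .with_batch_size(batch_size)
        // set_str accepts any registered DataFusion option by its full key.
        .set_str("datafusion.sql_parser.parse_float_as_decimal", "true")
}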
32 changes: 6 additions & 26 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -54,43 +54,23 @@ class CometExecIterator(
       new CometBatchIterator(iterator, nativeUtil)
     }.toArray
   private val plan = {
-    val configs = createNativeConf
     nativeLib.createPlan(
      id,
-      configs,
      cometBatchIterators,
      protobufQueryPlan,
      nativeMetrics,
-      new CometTaskMemoryManager(id))
+      new CometTaskMemoryManager(id),
+      batchSize = COMET_BATCH_SIZE.get(),
+      debug = COMET_DEBUG_ENABLED.get(),
+      explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
+      workerThreads = COMET_WORKER_THREADS.get(),
+      blockingThreads = COMET_BLOCKING_THREADS.get())
   }
 
   private var nextBatch: Option[ColumnarBatch] = None
   private var currentBatch: ColumnarBatch = null
   private var closed: Boolean = false
 
-  /**
-   * Creates a new configuration map to be passed to the native side.
-   */
-  private def createNativeConf: java.util.HashMap[String, String] = {
-    val result = new java.util.HashMap[String, String]()
-    val conf = SparkEnv.get.conf
-
-    result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
-    result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
-    result.put("explain_native", String.valueOf(COMET_EXPLAIN_NATIVE_ENABLED.get()))
-    result.put("worker_threads", String.valueOf(COMET_WORKER_THREADS.get()))
-    result.put("blocking_threads", String.valueOf(COMET_BLOCKING_THREADS.get()))
-
-    // Strip mandatory prefix spark. which is not required for DataFusion session params
-    conf.getAll.foreach {
-      case (k, v) if k.startsWith("spark.datafusion") =>
-        result.put(k.replaceFirst("spark\\.", ""), v)
-      case _ =>
-    }
-
-    result
-  }
-
   def getNextBatch(): Option[ColumnarBatch] = {
     nativeUtil.getNextBatch(
       numOutputCols,
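
The deleted createNativeConf also forwarded any spark.datafusion.* settings after stripping the mandatory spark. prefix. The same key rewrite, re-expressed in Rust purely as an illustration (the original was Scala, and this helper exists nowhere in the PR):

/// "spark.datafusion.foo" -> Some("datafusion.foo"); other keys -> None.
fn strip_spark_prefix(key: &str) -> Option<&str> {
    key.strip_prefix("spark.")
        .filter(|stripped| stripped.starts_with("datafusion."))
}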
10 changes: 6 additions & 4 deletions spark/src/main/scala/org/apache/comet/Native.scala
@@ -19,8 +19,6 @@
 
 package org.apache.comet
 
-import java.util.Map
-
 import org.apache.spark.CometTaskMemoryManager
 import org.apache.spark.sql.comet.CometMetricNode
 
@@ -47,11 +45,15 @@ class Native extends NativeBase {
    */
   @native def createPlan(
      id: Long,
-      configMap: Map[String, String],
      iterators: Array[CometBatchIterator],
      plan: Array[Byte],
      metrics: CometMetricNode,
-      taskMemoryManager: CometTaskMemoryManager): Long
+      taskMemoryManager: CometTaskMemoryManager,
+      batchSize: Int,
+      debug: Boolean,
+      explain: Boolean,
+      workerThreads: Int,
+      blockingThreads: Int): Long
 
   /**
    * Execute a native query plan based on given input Arrow arrays.