Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions datafusion/spark/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ edition = { workspace = true }
[package.metadata.docs.rs]
all-features = true

[features]
default = []
core = ["datafusion"]

# Note: add additional linter rules in lib.rs.
# Rust does not support workspace + new linter rules in subcrates yet
# https://github.com/rust-lang/cargo/issues/13157
Expand All @@ -43,6 +47,8 @@ arrow = { workspace = true }
bigdecimal = { workspace = true }
chrono = { workspace = true }
crc32fast = "1.4"
# Optional dependency for SessionStateBuilderSpark extension trait
datafusion = { workspace = true, optional = true, default-features = false }
datafusion-catalog = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
Expand All @@ -59,6 +65,8 @@ url = { workspace = true }
[dev-dependencies]
arrow = { workspace = true, features = ["test_utils"] }
criterion = { workspace = true }
# for SessionStateBuilderSpark tests
datafusion = { workspace = true, default-features = false }

[[bench]]
harness = false
Expand Down
32 changes: 32 additions & 0 deletions datafusion/spark/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,42 @@
//! ```
//!
//![`Expr`]: datafusion_expr::Expr
//!
//! # Example: enabling Apache Spark features with SessionStateBuilder
//!
//! The recommended way to enable Apache Spark compatibility is to use the
//! `SessionStateBuilderSpark` extension trait. This registers all
//! Apache Spark functions (scalar, aggregate, window, and table) as well as the Apache Spark
//! expression planner.
//!
//! Enable the `core` feature in your `Cargo.toml`:
//! ```toml
//! datafusion-spark = { version = "X", features = ["core"] }
//! ```
//!
//! Then use the extension trait:
//! ```ignore

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would prefer to avoid ignore here if possible

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed it

//! use datafusion_core::execution::SessionStateBuilder;
//! use datafusion_spark::SessionStateBuilderSpark;
//!
//! // Create a SessionState with Apache Spark features enabled
//! // note: the order matters here, `with_spark_features` should be
//! // called after `with_default_features` to overwrite any existing functions
//! let state = SessionStateBuilder::new()
//! .with_default_features()
//! .with_spark_features()
//! .build();
//! ```

pub mod function;
pub mod planner;

#[cfg(feature = "core")]
mod session_state;

#[cfg(feature = "core")]
pub use session_state::SessionStateBuilderSpark;

use datafusion_catalog::TableFunction;
use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
Expand Down
96 changes: 96 additions & 0 deletions datafusion/spark/src/session_state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;
use std::sync::Arc;

use datafusion::execution::session_state::SessionStateBuilder;

use crate::planner::SparkFunctionPlanner;
use crate::{
all_default_aggregate_functions, all_default_scalar_functions,
all_default_table_functions, all_default_window_functions,
};

/// Extension trait for adding Apache Spark features to [`SessionStateBuilder`].
///
/// This trait provides a convenient way to register all Apache Spark-compatible
/// functions and planners with a DataFusion session.
pub trait SessionStateBuilderSpark {
/// Adds all expr_planners, scalar, aggregate, window and table functions
/// compatible with Apache Spark.
///
/// Note: This overwrites any previously registered items with the same name.
fn with_spark_features(self) -> Self;
}

impl SessionStateBuilderSpark for SessionStateBuilder {
fn with_spark_features(mut self) -> Self {
self.expr_planners()
.get_or_insert_with(Vec::new)
// planners are evaluated in order of insertion. Push Apache Spark function planner to the front
// to take precedence over others
.insert(0, Arc::new(SparkFunctionPlanner));

self.scalar_functions()
.get_or_insert_with(Vec::new)
.extend(all_default_scalar_functions());

self.aggregate_functions()
.get_or_insert_with(Vec::new)
.extend(all_default_aggregate_functions());

self.window_functions()
.get_or_insert_with(Vec::new)
.extend(all_default_window_functions());

self.table_functions()
.get_or_insert_with(HashMap::new)
.extend(
all_default_table_functions()
.into_iter()
.map(|f| (f.name().to_string(), f)),
);

self
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_session_state_with_spark_features() {
let state = SessionStateBuilder::new().with_spark_features().build();

assert!(
state.scalar_functions().contains_key("sha2"),
"Apache Spark scalar function 'sha2' should be registered"
);

assert!(
state.aggregate_functions().contains_key("try_sum"),
"Apache Spark aggregate function 'try_sum' should be registered"
);

assert!(
!state.expr_planners().is_empty(),
"Apache Spark expr planners should be registered"
);
}
}
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ bytes = { workspace = true, optional = true }
chrono = { workspace = true, optional = true }
clap = { version = "4.5.53", features = ["derive", "env"] }
datafusion = { workspace = true, default-features = true, features = ["avro"] }
datafusion-spark = { workspace = true, default-features = true }
datafusion-spark = { workspace = true, features = ["core"] }
datafusion-substrait = { workspace = true, default-features = true }
futures = { workspace = true }
half = { workspace = true, default-features = true }
Expand Down
16 changes: 5 additions & 11 deletions datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use datafusion::{
datasource::{MemTable, TableProvider, TableType},
prelude::{CsvReadOptions, SessionContext},
};
use datafusion_spark::SessionStateBuilderSpark;

use crate::is_spark_path;
use async_trait::async_trait;
Expand Down Expand Up @@ -84,21 +85,14 @@ impl TestContext {

let mut state_builder = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime);
.with_runtime_env(runtime)
.with_default_features();

if is_spark_path(relative_path) {
state_builder = state_builder.with_expr_planners(vec![Arc::new(
datafusion_spark::planner::SparkFunctionPlanner,
)]);
state_builder = state_builder.with_spark_features();
}

let mut state = state_builder.with_default_features().build();

if is_spark_path(relative_path) {
info!("Registering Spark functions");
datafusion_spark::register_all(&mut state)
.expect("Can not register Spark functions");
}
let state = state_builder.build();

let mut test_ctx = TestContext::new(SessionContext::new_with_state(state));

Expand Down