Skip to content

Commit

Permalink
Add SessionStateBuilder and extract out the registration of defaults (a…
Browse files Browse the repository at this point in the history
…pache#11403)

* Create a SessionStateBuilder and use it for creating anything but a basic SessionState.

* Updated new_from_existing to take a reference to the existing SessionState and clone it.

* Minor documentation update.

* SessionStateDefaults improvements.

* Reworked how SessionStateBuilder works from PR feedback.

* Bug fix for missing array_expressions cfg feature.

* Review feedback updates + doc fixes for SessionStateDefaults

* Cargo fmt update.
  • Loading branch information
Omega359 authored and xinlifoobar committed Jul 18, 2024
1 parent c5c8587 commit b8f66e4
Show file tree
Hide file tree
Showing 14 changed files with 884 additions and 232 deletions.
11 changes: 9 additions & 2 deletions datafusion-cli/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use datafusion::datasource::listing::{
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::context::SessionState;
use datafusion::execution::session_state::SessionStateBuilder;

use async_trait::async_trait;
use dirs::home_dir;
Expand Down Expand Up @@ -162,6 +163,7 @@ impl SchemaProvider for DynamicFileSchemaProvider {
.ok_or_else(|| plan_datafusion_err!("locking error"))?
.read()
.clone();
let mut builder = SessionStateBuilder::from(state.clone());
let optimized_name = substitute_tilde(name.to_owned());
let table_url = ListingTableUrl::parse(optimized_name.as_str())?;
let scheme = table_url.scheme();
Expand All @@ -178,13 +180,18 @@ impl SchemaProvider for DynamicFileSchemaProvider {
// to any command options so the only choice is to use an empty collection
match scheme {
"s3" | "oss" | "cos" => {
state = state.add_table_options_extension(AwsOptions::default());
if let Some(table_options) = builder.table_options() {
table_options.extensions.insert(AwsOptions::default())
}
}
"gs" | "gcs" => {
state = state.add_table_options_extension(GcpOptions::default())
if let Some(table_options) = builder.table_options() {
table_options.extensions.insert(GcpOptions::default())
}
}
_ => {}
};
state = builder.build();
let store = get_object_store(
&state,
table_url.scheme(),
Expand Down
9 changes: 4 additions & 5 deletions datafusion-examples/examples/custom_file_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use arrow::{
datatypes::UInt64Type,
};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::{
datasource::{
file_format::{
Expand All @@ -32,9 +33,9 @@ use datafusion::{
MemTable,
},
error::Result,
execution::{context::SessionState, runtime_env::RuntimeEnv},
execution::context::SessionState,
physical_plan::ExecutionPlan,
prelude::{SessionConfig, SessionContext},
prelude::SessionContext,
};
use datafusion_common::{GetExt, Statistics};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
Expand Down Expand Up @@ -176,9 +177,7 @@ impl GetExt for TSVFileFactory {
#[tokio::main]
async fn main() -> Result<()> {
// Create a new context with the default configuration
let config = SessionConfig::new();
let runtime = RuntimeEnv::default();
let mut state = SessionState::new_with_config_rt(config, Arc::new(runtime));
let mut state = SessionStateBuilder::new().with_default_features().build();

// Register the custom file format
let file_format = Arc::new(TSVFileFactory::new());
Expand Down
7 changes: 6 additions & 1 deletion datafusion/core/src/datasource/file_format/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ mod tests {
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_expr::{col, lit};

use crate::execution::session_state::SessionStateBuilder;
use chrono::DateTime;
use object_store::local::LocalFileSystem;
use object_store::path::Path;
Expand Down Expand Up @@ -814,7 +815,11 @@ mod tests {
let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new()).unwrap());
let mut cfg = SessionConfig::new();
cfg.options_mut().catalog.has_header = true;
let session_state = SessionState::new_with_config_rt(cfg, runtime);
let session_state = SessionStateBuilder::new()
.with_config(cfg)
.with_runtime_env(runtime)
.with_default_features()
.build();
let integration = LocalFileSystem::new_with_prefix(arrow_test_data()).unwrap();
let path = Path::from("csv/aggregate_test_100.csv");
let csv = CsvFormat::default().with_has_header(true);
Expand Down
25 changes: 19 additions & 6 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ use object_store::ObjectStore;
use parking_lot::RwLock;
use url::Url;

use crate::execution::session_state::SessionStateBuilder;
pub use datafusion_execution::config::SessionConfig;
pub use datafusion_execution::TaskContext;
pub use datafusion_expr::execution_props::ExecutionProps;
Expand Down Expand Up @@ -294,7 +295,11 @@ impl SessionContext {
/// all `SessionContext`'s should be configured with the
/// same `RuntimeEnv`.
pub fn new_with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
let state = SessionState::new_with_config_rt(config, runtime);
let state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.build();
Self::new_with_state(state)
}

Expand All @@ -315,7 +320,7 @@ impl SessionContext {
}

/// Creates a new `SessionContext` using the provided [`SessionState`]
#[deprecated(since = "32.0.0", note = "Use SessionState::new_with_state")]
#[deprecated(since = "32.0.0", note = "Use SessionContext::new_with_state")]
pub fn with_state(state: SessionState) -> Self {
Self::new_with_state(state)
}
Expand Down Expand Up @@ -1574,6 +1579,7 @@ mod tests {
use datafusion_common_runtime::SpawnedTask;

use crate::catalog::schema::SchemaProvider;
use crate::execution::session_state::SessionStateBuilder;
use crate::physical_planner::PhysicalPlanner;
use async_trait::async_trait;
use tempfile::TempDir;
Expand Down Expand Up @@ -1707,7 +1713,11 @@ mod tests {
.set_str("datafusion.catalog.location", url.as_str())
.set_str("datafusion.catalog.format", "CSV")
.set_str("datafusion.catalog.has_header", "true");
let session_state = SessionState::new_with_config_rt(cfg, runtime);
let session_state = SessionStateBuilder::new()
.with_config(cfg)
.with_runtime_env(runtime)
.with_default_features()
.build();
let ctx = SessionContext::new_with_state(session_state);
ctx.refresh_catalogs().await?;

Expand All @@ -1733,9 +1743,12 @@ mod tests {
#[tokio::test]
async fn custom_query_planner() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
let session_state =
SessionState::new_with_config_rt(SessionConfig::new(), runtime)
.with_query_planner(Arc::new(MyQueryPlanner {}));
let session_state = SessionStateBuilder::new()
.with_config(SessionConfig::new())
.with_runtime_env(runtime)
.with_default_features()
.with_query_planner(Arc::new(MyQueryPlanner {}))
.build();
let ctx = SessionContext::new_with_state(session_state);

let df = ctx.sql("SELECT 1").await?;
Expand Down
Loading

0 comments on commit b8f66e4

Please sign in to comment.