diff --git a/.github/workflows/pr_benchmarks.yml b/.github/workflows/pr_benchmarks.yml deleted file mode 100644 index 5827c42e85ae..000000000000 --- a/.github/workflows/pr_benchmarks.yml +++ /dev/null @@ -1,101 +0,0 @@ -# Runs the benchmark command on the PR and -# on the branch, posting the results as a comment back the PR -name: Benchmarks - -on: - issue_comment: - -jobs: - benchmark: - name: Run Benchmarks - runs-on: ubuntu-latest - if: github.event.issue.pull_request && contains(github.event.comment.body, '/benchmark') - steps: - - name: Dump GitHub context - env: - GITHUB_CONTEXT: ${{ toJSON(github) }} - run: echo "$GITHUB_CONTEXT" - - - name: Checkout PR changes - uses: actions/checkout@v4 - with: - ref: refs/pull/${{ github.event.issue.number }}/head - - - name: Setup test data - # Workaround for `the input device is not a TTY`, appropriated from https://github.com/actions/runner/issues/241 - shell: 'script -q -e -c "bash -e {0}"' - run: | - cd benchmarks - mkdir data - - # Setup the TPC-H data sets for scale factors 1 and 10 - ./bench.sh data tpch - ./bench.sh data tpch10 - - - name: Generate unique result names - run: | - echo "HEAD_LONG_SHA=$(git log -1 --format='%H')" >> "$GITHUB_ENV" - echo "HEAD_SHORT_SHA=$(git log -1 --format='%h' --abbrev=7)" >> "$GITHUB_ENV" - echo "BASE_SHORT_SHA=$(echo "${{ github.sha }}" | cut -c1-7)" >> "$GITHUB_ENV" - - - name: Benchmark PR changes - env: - RESULTS_NAME: ${{ env.HEAD_SHORT_SHA }} - run: | - cd benchmarks - - ./bench.sh run tpch - ./bench.sh run tpch_mem - ./bench.sh run tpch10 - - # For some reason this step doesn't seem to propagate the env var down into the script - if [ -d "results/HEAD" ]; then - echo "Moving results into ${{ env.HEAD_SHORT_SHA }}" - mv results/HEAD results/${{ env.HEAD_SHORT_SHA }} - fi - - - name: Checkout base commit - uses: actions/checkout@v4 - with: - ref: ${{ github.sha }} - clean: false - - - name: Benchmark baseline and generate comparison message - env: - RESULTS_NAME: ${{ env.BASE_SHORT_SHA }} - run: | - cd benchmarks - - ./bench.sh run tpch - ./bench.sh run tpch_mem - ./bench.sh run tpch10 - - echo ${{ github.event.issue.number }} > pr - - pip3 install rich - cat > message.md < - Benchmarks comparing ${{ github.sha }} (main) and ${{ env.HEAD_LONG_SHA }} (PR) - - \`\`\` - $(./bench.sh compare ${{ env.BASE_SHORT_SHA }} ${{ env.HEAD_SHORT_SHA }}) - \`\`\` - - - EOF - - cat message.md - - - name: Upload benchmark comparison message - uses: actions/upload-artifact@v4 - with: - name: message - path: benchmarks/message.md - - - name: Upload PR number - uses: actions/upload-artifact@v4 - with: - name: pr - path: benchmarks/pr diff --git a/ci/scripts/rust_example.sh b/ci/scripts/rust_example.sh index 675dc4e527d0..0415090665d2 100755 --- a/ci/scripts/rust_example.sh +++ b/ci/scripts/rust_example.sh @@ -29,5 +29,6 @@ do # Skip tests that rely on external storage and flight if [ ! 
-d $filename ]; then cargo run --example $example_name + cargo clean -p datafusion-examples fi done diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 28312fee79a7..5fc8dbcfdfb3 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -387,7 +387,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -757,9 +757,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "blake2" @@ -875,9 +875,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.99" +version = "1.0.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96c51067fd44124faa7f870b4b1c969379ad32b2ba805aa959430ceaa384f695" +checksum = "2755ff20a1d93490d26ba33a6f092a38a508398a5320df5d4b3014fcccce9410" dependencies = [ "jobserver", "libc", @@ -981,7 +981,7 @@ version = "7.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ - "strum 0.26.2", + "strum 0.26.3", "strum_macros 0.26.4", "unicode-width", ] @@ -1099,7 +1099,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" dependencies = [ "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -1262,7 +1262,7 @@ dependencies = [ "paste", "serde_json", "sqlparser", - "strum 0.26.2", + "strum 0.26.3", "strum_macros 0.26.4", ] @@ -1426,7 +1426,7 @@ dependencies = [ "log", "regex", "sqlparser", - "strum 0.26.2", + "strum 0.26.3", ] [[package]] @@ -1504,9 +1504,9 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "either" -version = "1.12.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "endian-type" @@ -1685,7 +1685,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -2253,9 +2253,9 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "libmimalloc-sys" -version = "0.1.38" +version = "0.1.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7bb23d733dfcc8af652a78b7bf232f0e967710d044732185e561e47c0336b6" +checksum = "23aa6811d3bd4deb8a84dde645f943476d13b248d818edcf8ce0b2f37f036b44" dependencies = [ "cc", "libc", @@ -2267,7 +2267,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "libc", ] @@ -2289,9 +2289,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = 
"a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lz4_flex" @@ -2331,9 +2331,9 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mimalloc" -version = "0.1.42" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9186d86b79b52f4a77af65604b51225e8db1d6ee7e3f41aec1e40829c71a176" +checksum = "68914350ae34959d83f732418d51e2427a794055d0b9529f48259ac07af65633" dependencies = [ "libmimalloc-sys", ] @@ -2406,9 +2406,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", @@ -2482,9 +2482,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.0" +version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "576dfe1fc8f9df304abb159d767a29d0476f7750fbf8aa7ad07816004a207434" +checksum = "081b846d1d56ddfc18fdf1a922e4f6e07a11768ea1b92dec44e42b72712ccfce" dependencies = [ "memchr", ] @@ -2698,7 +2698,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -2912,7 +2912,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", ] [[package]] @@ -3095,7 +3095,7 @@ version = "0.38.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", @@ -3264,7 +3264,7 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "core-foundation", "core-foundation-sys", "libc", @@ -3310,14 +3310,14 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] name = "serde_json" -version = "1.0.117" +version = "1.0.119" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +checksum = "e8eddb61f0697cc3989c5d64b452f5488e2b8a60fd7d5076a3045076ffef8cb0" dependencies = [ "itoa", "ryu", @@ -3445,7 +3445,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3474,9 +3474,9 @@ checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" [[package]] name = "strum" -version = "0.26.2" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" dependencies = [ "strum_macros 0.26.4", ] @@ -3491,7 +3491,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3504,14 +3504,14 @@ 
dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] name = "subtle" -version = "2.6.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d0208408ba0c3df17ed26eb06992cb1a1268d41b2c0e12e65203fbe3972cee5" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" @@ -3526,9 +3526,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.67" +version = "2.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff8655ed1d86f3af4ee3fd3263786bc14245ad17c4c7e85ba7187fb3ae028c90" +checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" dependencies = [ "proc-macro2", "quote", @@ -3591,7 +3591,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3646,9 +3646,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +checksum = "c55115c6fbe2d2bef26eb09ad74bde02d8255476fc0c7b515ef09fbb35742d82" dependencies = [ "tinyvec_macros", ] @@ -3686,7 +3686,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3783,7 +3783,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3828,7 +3828,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3907,9 +3907,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.8.0" +version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" +checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" dependencies = [ "getrandom", "serde", @@ -3982,7 +3982,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", "wasm-bindgen-shared", ] @@ -4016,7 +4016,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4281,7 +4281,7 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs new file mode 100644 index 000000000000..fe936418bce4 --- /dev/null +++ b/datafusion-examples/examples/custom_file_format.rs @@ -0,0 +1,234 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::{AsArray, RecordBatch, StringArray, UInt8Array}, + datatypes::UInt64Type, +}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::{ + datasource::{ + file_format::{ + csv::CsvFormatFactory, file_compression_type::FileCompressionType, + FileFormat, FileFormatFactory, + }, + physical_plan::{FileScanConfig, FileSinkConfig}, + MemTable, + }, + error::Result, + execution::{context::SessionState, runtime_env::RuntimeEnv}, + physical_plan::ExecutionPlan, + prelude::{SessionConfig, SessionContext}, +}; +use datafusion_common::{GetExt, Statistics}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use object_store::{ObjectMeta, ObjectStore}; +use tempfile::tempdir; + +/// Example of a custom file format that reads and writes TSV files. +/// +/// TSVFileFormatFactory is responsible for creating instances of TSVFileFormat. +/// The former, once registered with the SessionState, will then be used +/// to facilitate SQL operations on TSV files, such as `COPY TO` shown here. + +#[derive(Debug)] +/// Custom file format that reads and writes TSV files +/// +/// This file format is a wrapper around the CSV file format +/// for demonstration purposes. +struct TSVFileFormat { + csv_file_format: Arc, +} + +impl TSVFileFormat { + pub fn new(csv_file_format: Arc) -> Self { + Self { csv_file_format } + } +} + +#[async_trait::async_trait] +impl FileFormat for TSVFileFormat { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_ext(&self) -> String { + "tsv".to_string() + } + + fn get_ext_with_compression( + &self, + c: &FileCompressionType, + ) -> datafusion::error::Result { + if c == &FileCompressionType::UNCOMPRESSED { + Ok("tsv".to_string()) + } else { + todo!("Compression not supported") + } + } + + async fn infer_schema( + &self, + state: &SessionState, + store: &Arc, + objects: &[ObjectMeta], + ) -> Result { + self.csv_file_format + .infer_schema(state, store, objects) + .await + } + + async fn infer_stats( + &self, + state: &SessionState, + store: &Arc, + table_schema: SchemaRef, + object: &ObjectMeta, + ) -> Result { + self.csv_file_format + .infer_stats(state, store, table_schema, object) + .await + } + + async fn create_physical_plan( + &self, + state: &SessionState, + conf: FileScanConfig, + filters: Option<&Arc>, + ) -> Result> { + self.csv_file_format + .create_physical_plan(state, conf, filters) + .await + } + + async fn create_writer_physical_plan( + &self, + input: Arc, + state: &SessionState, + conf: FileSinkConfig, + order_requirements: Option>, + ) -> Result> { + self.csv_file_format + .create_writer_physical_plan(input, state, conf, order_requirements) + .await + } +} + +#[derive(Default)] +/// Factory for creating TSV file formats +/// +/// This factory is a wrapper around the CSV file format factory +/// for demonstration purposes. 
+pub struct TSVFileFactory { + csv_file_factory: CsvFormatFactory, +} + +impl TSVFileFactory { + pub fn new() -> Self { + Self { + csv_file_factory: CsvFormatFactory::new(), + } + } +} + +impl FileFormatFactory for TSVFileFactory { + fn create( + &self, + state: &SessionState, + format_options: &std::collections::HashMap, + ) -> Result> { + let mut new_options = format_options.clone(); + new_options.insert("format.delimiter".to_string(), "\t".to_string()); + + let csv_file_format = self.csv_file_factory.create(state, &new_options)?; + let tsv_file_format = Arc::new(TSVFileFormat::new(csv_file_format)); + + Ok(tsv_file_format) + } + + fn default(&self) -> std::sync::Arc { + todo!() + } +} + +impl GetExt for TSVFileFactory { + fn get_ext(&self) -> String { + "tsv".to_string() + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Create a new context with the default configuration + let config = SessionConfig::new(); + let runtime = RuntimeEnv::default(); + let mut state = SessionState::new_with_config_rt(config, Arc::new(runtime)); + + // Register the custom file format + let file_format = Arc::new(TSVFileFactory::new()); + state.register_file_format(file_format, true).unwrap(); + + // Create a new context with the custom file format + let ctx = SessionContext::new_with_state(state); + + let mem_table = create_mem_table(); + ctx.register_table("mem_table", mem_table).unwrap(); + + let temp_dir = tempdir().unwrap(); + let table_save_path = temp_dir.path().join("mem_table.tsv"); + + let d = ctx + .sql(&format!( + "COPY mem_table TO '{}' STORED AS TSV;", + table_save_path.display(), + )) + .await?; + + let results = d.collect().await?; + println!( + "Number of inserted rows: {:?}", + (results[0] + .column_by_name("count") + .unwrap() + .as_primitive::() + .value(0)) + ); + + Ok(()) +} + +// create a simple mem table +fn create_mem_table() -> Arc { + let fields = vec![ + Field::new("id", DataType::UInt8, false), + Field::new("data", DataType::Utf8, false), + ]; + let schema = Arc::new(Schema::new(fields)); + + let partitions = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt8Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["foo", "bar"])), + ], + ) + .unwrap(); + + Arc::new(MemTable::try_new(schema, vec![vec![partitions]]).unwrap()) +} diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 1ecdb0efd2c2..1d2a9589adfc 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -613,6 +613,9 @@ config_namespace! 
{ /// When set to true, the explain statement will print the partition sizes pub show_sizes: bool, default = true + + /// When set to true, the explain statement will print schema information + pub show_schema: bool, default = false } } diff --git a/datafusion/common/src/display/mod.rs b/datafusion/common/src/display/mod.rs index 4d1d48bf9fcc..2345c0e4c4fc 100644 --- a/datafusion/common/src/display/mod.rs +++ b/datafusion/common/src/display/mod.rs @@ -49,15 +49,19 @@ pub enum PlanType { InitialPhysicalPlan, /// The initial physical plan with stats, prepared for execution InitialPhysicalPlanWithStats, + /// The initial physical plan with schema, prepared for execution + InitialPhysicalPlanWithSchema, /// The ExecutionPlan which results from applying an optimizer pass OptimizedPhysicalPlan { /// The name of the optimizer which produced this plan optimizer_name: String, }, - /// The final, fully optimized physical which would be executed + /// The final, fully optimized physical plan which would be executed FinalPhysicalPlan, - /// The final with stats, fully optimized physical which would be executed + /// The final with stats, fully optimized physical plan which would be executed FinalPhysicalPlanWithStats, + /// The final with schema, fully optimized physical plan which would be executed + FinalPhysicalPlanWithSchema, } impl Display for PlanType { @@ -76,11 +80,17 @@ impl Display for PlanType { PlanType::InitialPhysicalPlanWithStats => { write!(f, "initial_physical_plan_with_stats") } + PlanType::InitialPhysicalPlanWithSchema => { + write!(f, "initial_physical_plan_with_schema") + } PlanType::OptimizedPhysicalPlan { optimizer_name } => { write!(f, "physical_plan after {optimizer_name}") } PlanType::FinalPhysicalPlan => write!(f, "physical_plan"), PlanType::FinalPhysicalPlanWithStats => write!(f, "physical_plan_with_stats"), + PlanType::FinalPhysicalPlanWithSchema => { + write!(f, "physical_plan_with_schema") + } } } } diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 5b9c4a223de6..bd2265c85003 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -982,6 +982,7 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(0)), DataType::UInt32 => ScalarValue::UInt32(Some(0)), DataType::UInt64 => ScalarValue::UInt64(Some(0)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))), DataType::Float32 => ScalarValue::Float32(Some(0.0)), DataType::Float64 => ScalarValue::Float64(Some(0.0)), DataType::Timestamp(TimeUnit::Second, tz) => { @@ -1035,6 +1036,7 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(1)), DataType::UInt32 => ScalarValue::UInt32(Some(1)), DataType::UInt64 => ScalarValue::UInt64(Some(1)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))), DataType::Float32 => ScalarValue::Float32(Some(1.0)), DataType::Float64 => ScalarValue::Float64(Some(1.0)), _ => { @@ -1053,6 +1055,7 @@ impl ScalarValue { DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)), DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)), DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))), DataType::Float32 => ScalarValue::Float32(Some(-1.0)), DataType::Float64 => ScalarValue::Float64(Some(-1.0)), _ => { @@ -1074,6 +1077,7 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(10)), DataType::UInt32 => ScalarValue::UInt32(Some(10)), DataType::UInt64 => 
ScalarValue::UInt64(Some(10)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))), DataType::Float32 => ScalarValue::Float32(Some(10.0)), DataType::Float64 => ScalarValue::Float64(Some(10.0)), _ => { @@ -1181,8 +1185,12 @@ impl ScalarValue { | ScalarValue::Int16(None) | ScalarValue::Int32(None) | ScalarValue::Int64(None) + | ScalarValue::Float16(None) | ScalarValue::Float32(None) | ScalarValue::Float64(None) => Ok(self.clone()), + ScalarValue::Float16(Some(v)) => { + Ok(ScalarValue::Float16(Some(f16::from_f32(-v.to_f32())))) + } ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))), ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))), ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))), @@ -1435,6 +1443,9 @@ impl ScalarValue { (Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _), (Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _), // TODO: we might want to look into supporting ceil/floor here for floats. + (Self::Float16(Some(l)), Self::Float16(Some(r))) => { + Some((f16::to_f32(*l) - f16::to_f32(*r)).abs().round() as _) + } (Self::Float32(Some(l)), Self::Float32(Some(r))) => { Some((l - r).abs().round() as _) } @@ -2452,6 +2463,7 @@ impl ScalarValue { DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean)?, DataType::Float64 => typed_cast!(array, index, Float64Array, Float64)?, DataType::Float32 => typed_cast!(array, index, Float32Array, Float32)?, + DataType::Float16 => typed_cast!(array, index, Float16Array, Float16)?, DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64)?, DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32)?, DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16)?, @@ -5635,7 +5647,6 @@ mod tests { } #[test] - #[should_panic(expected = "Can not run arithmetic negative on scalar value Float16")] fn f16_test_overflow() { // TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case let cases = [ @@ -5805,6 +5816,21 @@ mod tests { ScalarValue::UInt64(Some(10)), 5, ), + ( + ScalarValue::Float16(Some(f16::from_f32(1.1))), + ScalarValue::Float16(Some(f16::from_f32(1.9))), + 1, + ), + ( + ScalarValue::Float16(Some(f16::from_f32(-5.3))), + ScalarValue::Float16(Some(f16::from_f32(-9.2))), + 4, + ), + ( + ScalarValue::Float16(Some(f16::from_f32(-5.3))), + ScalarValue::Float16(Some(f16::from_f32(-9.7))), + 4, + ), ( ScalarValue::Float32(Some(1.0)), ScalarValue::Float32(Some(2.0)), @@ -5877,6 +5903,14 @@ mod tests { // Different type (ScalarValue::Int8(Some(1)), ScalarValue::Int16(Some(1))), (ScalarValue::Int8(Some(1)), ScalarValue::Float32(Some(1.0))), + ( + ScalarValue::Float16(Some(f16::from_f32(1.0))), + ScalarValue::Float32(Some(1.0)), + ), + ( + ScalarValue::Float16(Some(f16::from_f32(1.0))), + ScalarValue::Int32(Some(1)), + ), ( ScalarValue::Float64(Some(1.1)), ScalarValue::Float32(Some(2.2)), diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 89e69caffcef..5921d8a797ac 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -1152,7 +1152,8 @@ mod tests { use crate::physical_plan::metrics::MetricValue; use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::{Array, ArrayRef, StringArray}; - use arrow_array::Int64Array; + use arrow_array::types::Int32Type; + use arrow_array::{DictionaryArray, Int32Array, 
Int64Array}; use arrow_schema::{DataType, Field}; use async_trait::async_trait; use datafusion_common::cast::{ @@ -1161,6 +1162,7 @@ mod tests { }; use datafusion_common::config::ParquetOptions; use datafusion_common::ScalarValue; + use datafusion_common::ScalarValue::Utf8; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; @@ -1442,6 +1444,48 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_statistics_from_parquet_metadata_dictionary() -> Result<()> { + // Data for column c_dic: ["a", "b", "c", "d"] + let values = StringArray::from_iter_values(["a", "b", "c", "d"]); + let keys = Int32Array::from_iter_values([0, 1, 2, 3]); + let dic_array = + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + let c_dic: ArrayRef = Arc::new(dic_array); + + let batch1 = RecordBatch::try_from_iter(vec![("c_dic", c_dic)]).unwrap(); + + // Use store_parquet to write each batch to its own file + // . batch1 written into first file and includes: + // - column c_dic that has 4 rows with no null. Stats min and max of dictionary column is available. + let store = Arc::new(LocalFileSystem::new()) as _; + let (files, _file_names) = store_parquet(vec![batch1], false).await?; + + let state = SessionContext::new().state(); + let format = ParquetFormat::default(); + let schema = format.infer_schema(&state, &store, &files).await.unwrap(); + + // Fetch statistics for first file + let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; + let stats = statistics_from_parquet_meta(&pq_meta, schema.clone()).await?; + assert_eq!(stats.num_rows, Precision::Exact(4)); + + // column c_dic + let c_dic_stats = &stats.column_statistics[0]; + + assert_eq!(c_dic_stats.null_count, Precision::Exact(0)); + assert_eq!( + c_dic_stats.max_value, + Precision::Exact(Utf8(Some("d".into()))) + ); + assert_eq!( + c_dic_stats.min_value, + Precision::Exact(Utf8(Some("a".into()))) + ); + + Ok(()) + } + #[tokio::test] async fn test_statistics_from_parquet_metadata() -> Result<()> { // Data for column c1: ["Foo", null, "bar"] diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 67c517ddbc4f..fcecae27a52f 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -22,13 +22,13 @@ use arrow::datatypes::i256; use arrow::{array::ArrayRef, datatypes::DataType}; use arrow_array::{ - new_null_array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, - Decimal256Array, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, - Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, - StringArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + new_empty_array, new_null_array, BinaryArray, BooleanArray, Date32Array, Date64Array, + Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array, Float32Array, + Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, + LargeStringArray, StringArray, Time32MillisecondArray, Time32SecondArray, + Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, + 
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow_schema::{Field, FieldRef, Schema, TimeUnit}; use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; @@ -846,6 +846,9 @@ macro_rules! get_data_page_statistics { }) }).flatten().collect::>(), ))), + Some(DataType::Dictionary(_, value_type)) => { + [<$stat_type_prefix:lower _ page_statistics>](Some(value_type), $iterator) + }, Some(DataType::Timestamp(unit, timezone)) => { let iter = [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten(); Ok(match unit { @@ -873,6 +876,34 @@ macro_rules! get_data_page_statistics { Decimal128Array::from_iter([<$stat_type_prefix Decimal128DataPageStatsIterator>]::new($iterator).flatten()).with_precision_and_scale(*precision, *scale)?)), Some(DataType::Decimal256(precision, scale)) => Ok(Arc::new( Decimal256Array::from_iter([<$stat_type_prefix Decimal256DataPageStatsIterator>]::new($iterator).flatten()).with_precision_and_scale(*precision, *scale)?)), + Some(DataType::Time32(unit)) => { + Ok(match unit { + TimeUnit::Second => Arc::new(Time32SecondArray::from_iter( + [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator).flatten(), + )), + TimeUnit::Millisecond => Arc::new(Time32MillisecondArray::from_iter( + [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator).flatten(), + )), + _ => { + // don't know how to extract statistics, so return an empty array + new_empty_array(&DataType::Time32(unit.clone())) + } + }) + } + Some(DataType::Time64(unit)) => { + Ok(match unit { + TimeUnit::Microsecond => Arc::new(Time64MicrosecondArray::from_iter( + [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten(), + )), + TimeUnit::Nanosecond => Arc::new(Time64NanosecondArray::from_iter( + [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten(), + )), + _ => { + // don't know how to extract statistics, so return an empty array + new_empty_array(&DataType::Time64(unit.clone())) + } + }) + } _ => unimplemented!() } } diff --git a/datafusion/core/src/datasource/statistics.rs b/datafusion/core/src/datasource/statistics.rs index c67227f966a2..a243a1c3558f 100644 --- a/datafusion/core/src/datasource/statistics.rs +++ b/datafusion/core/src/datasource/statistics.rs @@ -20,6 +20,7 @@ use crate::arrow::datatypes::{Schema, SchemaRef}; use crate::error::Result; use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics}; +use arrow_schema::DataType; use datafusion_common::stats::Precision; use datafusion_common::ScalarValue; @@ -156,12 +157,16 @@ pub(crate) fn create_max_min_accs( let max_values: Vec> = schema .fields() .iter() - .map(|field| MaxAccumulator::try_new(field.data_type()).ok()) + .map(|field| { + MaxAccumulator::try_new(min_max_aggregate_data_type(field.data_type())).ok() + }) .collect(); let min_values: Vec> = schema .fields() .iter() - .map(|field| MinAccumulator::try_new(field.data_type()).ok()) + .map(|field| { + MinAccumulator::try_new(min_max_aggregate_data_type(field.data_type())).ok() + }) .collect(); (max_values, min_values) } @@ -218,6 +223,18 @@ pub(crate) fn get_col_stats( .collect() } +// Min/max aggregation can take Dictionary encode input but always produces unpacked +// (aka non Dictionary) output. We need to adjust the output data type to reflect this. 
+// The reason min/max aggregate produces unpacked output because there is only one +// min/max value per group; there is no needs to keep them Dictionary encode +fn min_max_aggregate_data_type(input_type: &DataType) -> &DataType { + if let DataType::Dictionary(_, value_type) = input_type { + value_type.as_ref() + } else { + input_type + } +} + /// If the given value is numerically greater than the original maximum value, /// return the new maximum value with appropriate exactness information. fn set_max_if_greater( diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 9ec0148d9122..012717a007c2 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1490,13 +1490,13 @@ impl SQLOptions { Default::default() } - /// Should DML data modification commands (e.g. `INSERT and COPY`) be run? Defaults to `true`. + /// Should DDL data definition commands (e.g. `CREATE TABLE`) be run? Defaults to `true`. pub fn with_allow_ddl(mut self, allow: bool) -> Self { self.allow_ddl = allow; self } - /// Should DML data modification commands (e.g. `INSERT and COPY`) be run? Defaults to `true` + /// Should DML data modification commands (e.g. `INSERT` and `COPY`) be run? Defaults to `true` pub fn with_allow_dml(mut self, allow: bool) -> Self { self.allow_dml = allow; self diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 56b51e792bae..cf45fef5169b 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -607,7 +607,7 @@ impl SessionState { } } - let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); + let query = self.build_sql_query_planner(&provider); query.statement_to_plan(statement) } @@ -660,8 +660,7 @@ impl SessionState { tables: HashMap::new(), }; - let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); - + let query = self.build_sql_query_planner(&provider); query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new()) } @@ -945,6 +944,31 @@ impl SessionState { let udtf = self.table_functions.remove(name); Ok(udtf.map(|x| x.function().clone())) } + + fn build_sql_query_planner<'a, S>(&self, provider: &'a S) -> SqlToRel<'a, S> + where + S: ContextProvider, + { + let query = SqlToRel::new_with_options(provider, self.get_parser_options()); + + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + { + let array_planner = + Arc::new(functions_array::planner::ArrayFunctionPlanner::default()) as _; + + let field_access_planner = + Arc::new(functions_array::planner::FieldAccessPlanner::default()) as _; + + query + .with_user_defined_planner(array_planner) + .with_user_defined_planner(field_access_planner) + } + #[cfg(not(feature = "array_expressions"))] + { + query + } + } } struct SessionContextProvider<'a> { diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 7cc9a0fb75d4..e1bde36bd6fe 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -31,12 +31,12 @@ pub mod limited_distinct_aggregation; pub mod optimizer; pub mod output_requirements; pub mod pipeline_checker; -mod projection_pushdown; +pub mod projection_pushdown; pub mod pruning; pub mod replace_with_order_preserving_variants; mod sort_pushdown; pub mod topk_aggregation; -mod update_aggr_exprs; +pub mod 
update_aggr_exprs; mod utils; #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs index 6a6ca815c510..1ad4179cefd8 100644 --- a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs +++ b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs @@ -52,6 +52,7 @@ use datafusion_physical_plan::{ pub struct OptimizeAggregateOrder {} impl OptimizeAggregateOrder { + #[allow(missing_docs)] pub fn new() -> Self { Self::default() } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 47efaac858ac..423eb8023f31 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1983,23 +1983,37 @@ impl DefaultPhysicalPlanner { .await { Ok(input) => { - // This plan will includes statistics if show_statistics is on + // Include statistics / schema if enabled stringified_plans.push( displayable(input.as_ref()) .set_show_statistics(config.show_statistics) + .set_show_schema(config.show_schema) .to_stringified(e.verbose, InitialPhysicalPlan), ); - // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose - if e.verbose && !config.show_statistics { - stringified_plans.push( - displayable(input.as_ref()) - .set_show_statistics(true) - .to_stringified( - e.verbose, - InitialPhysicalPlanWithStats, - ), - ); + // Show statistics + schema in verbose output even if not + // explicitly requested + if e.verbose { + if !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + .to_stringified( + e.verbose, + InitialPhysicalPlanWithStats, + ), + ); + } + if !config.show_schema { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_schema(true) + .to_stringified( + e.verbose, + InitialPhysicalPlanWithSchema, + ), + ); + } } let optimized_plan = self.optimize_internal( @@ -2011,6 +2025,7 @@ impl DefaultPhysicalPlanner { stringified_plans.push( displayable(plan) .set_show_statistics(config.show_statistics) + .set_show_schema(config.show_schema) .to_stringified(e.verbose, plan_type), ); }, @@ -2021,19 +2036,33 @@ impl DefaultPhysicalPlanner { stringified_plans.push( displayable(input.as_ref()) .set_show_statistics(config.show_statistics) + .set_show_schema(config.show_schema) .to_stringified(e.verbose, FinalPhysicalPlan), ); - // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose - if e.verbose && !config.show_statistics { - stringified_plans.push( - displayable(input.as_ref()) - .set_show_statistics(true) - .to_stringified( - e.verbose, - FinalPhysicalPlanWithStats, - ), - ); + // Show statistics + schema in verbose output even if not + // explicitly requested + if e.verbose { + if !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + .to_stringified( + e.verbose, + FinalPhysicalPlanWithStats, + ), + ); + } + if !config.show_schema { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_schema(true) + .to_stringified( + e.verbose, + FinalPhysicalPlanWithSchema, + ), + ); + } } } Err(DataFusionError::Context(optimizer_name, e)) => { diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index 47f079063d3c..ea83c1fa788d 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -1204,7 
+1204,7 @@ async fn test_time32_second_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4])), column_name: "second", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1231,7 +1231,7 @@ async fn test_time32_millisecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4])), column_name: "millisecond", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1264,7 +1264,7 @@ async fn test_time64_microsecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4])), column_name: "microsecond", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1297,7 +1297,7 @@ async fn test_time64_nanosecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4])), column_name: "nanosecond", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1752,7 +1752,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "string_dict_i32", - check: Check::RowGroup, + check: Check::Both, } .run(); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 3893ee750f59..c7e94d8e0b83 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1200,32 +1200,53 @@ impl Expr { } } - /// Recursively potentially multiple aliases from an expression. + /// Recursively removed potentially multiple aliases from an expression. /// - /// If the expression is not an alias, the expression is returned unchanged. - /// This method removes directly nested aliases, but not other nested - /// aliases. + /// This method removes nested aliases and returns [`Transformed`] + /// to signal if the expression was changed. /// /// # Example /// ``` /// # use datafusion_expr::col; /// // `foo as "bar"` is unaliased to `foo` /// let expr = col("foo").alias("bar"); - /// assert_eq!(expr.unalias_nested(), col("foo")); + /// assert_eq!(expr.unalias_nested().data, col("foo")); /// - /// // `foo as "bar" + baz` is not unaliased + /// // `foo as "bar" + baz` is unaliased /// let expr = col("foo").alias("bar") + col("baz"); - /// assert_eq!(expr.clone().unalias_nested(), expr); + /// assert_eq!(expr.clone().unalias_nested().data, col("foo") + col("baz")); /// /// // `foo as "bar" as "baz" is unalaised to foo /// let expr = col("foo").alias("bar").alias("baz"); - /// assert_eq!(expr.unalias_nested(), col("foo")); + /// assert_eq!(expr.unalias_nested().data, col("foo")); /// ``` - pub fn unalias_nested(self) -> Expr { - match self { - Expr::Alias(alias) => alias.expr.unalias_nested(), - _ => self, - } + pub fn unalias_nested(self) -> Transformed { + self.transform_down_up( + |expr| { + // f_down: skip subqueries. Check in f_down to avoid recursing into them + let recursion = if matches!( + expr, + Expr::Exists { .. 
} | Expr::ScalarSubquery(_) | Expr::InSubquery(_) + ) { + // subqueries could contain aliases so don't recurse into those + TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Continue + }; + Ok(Transformed::new(expr, false, recursion)) + }, + |expr| { + // f_up: unalias on up so we can remove nested aliases like + // `(x as foo) as bar` + if let Expr::Alias(Alias { expr, .. }) = expr { + Ok(Transformed::yes(*expr)) + } else { + Ok(Transformed::no(expr)) + } + }, + ) + // unreachable code: internal closure doesn't return err + .unwrap() } /// Return `self IN ` if `negated` is false, otherwise diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 217477fa1010..7a3c556b3802 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -43,6 +43,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// For example, concatenating arrays `a || b` is represented as /// `Operator::ArrowAt`, but can be implemented by calling a function /// `array_concat` from the `functions-array` crate. +// This is not used in datafusion internally, but it is still helpful for downstream project so don't remove it. pub trait FunctionRewrite { /// Return a human readable name for this rewrite fn name(&self) -> &str; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 89ee94f9f845..5f1d3c9d5c6b 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -48,6 +48,7 @@ pub mod function; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod logical_plan; +pub mod planner; pub mod registry; pub mod simplify; pub mod sort_properties; @@ -81,6 +82,7 @@ pub use partition_evaluator::PartitionEvaluator; pub use signature::{ ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, }; +pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d19d0f4a6621..a10897a2d6a2 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -882,8 +882,7 @@ impl LogicalPlan { } LogicalPlan::Filter { .. } => { assert_eq!(1, expr.len()); - let predicate = expr.pop().unwrap(); - let predicate = Filter::remove_aliases(predicate)?.data; + let predicate = expr.pop().unwrap().unalias_nested().data; Filter::try_new(predicate, Arc::new(inputs.swap_remove(0))) .map(LogicalPlan::Filter) @@ -2213,38 +2212,6 @@ impl Filter { } false } - - /// Remove aliases from a predicate for use in a `Filter` - /// - /// filter predicates should not contain aliased expressions so we remove - /// any aliases. - /// - /// before this logic was added we would have aliases within filters such as - /// for benchmark q6: - /// - /// ```sql - /// lineitem.l_shipdate >= Date32(\"8766\") - /// AND lineitem.l_shipdate < Date32(\"9131\") - /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= - /// Decimal128(Some(49999999999999),30,15) - /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= - /// Decimal128(Some(69999999999999),30,15) - /// AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - /// ``` - pub fn remove_aliases(predicate: Expr) -> Result> { - predicate.transform_down(|expr| { - match expr { - Expr::Exists { .. 
} | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { - // subqueries could contain aliases so we don't recurse into those - Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) - } - Expr::Alias(Alias { expr, .. }) => { - Ok(Transformed::new(*expr, true, TreeNodeRecursion::Jump)) - } - _ => Ok(Transformed::no(expr)), - } - }) - } } /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs new file mode 100644 index 000000000000..b8c0378f46aa --- /dev/null +++ b/datafusion/expr/src/planner.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ContextProvider`] and [`UserDefinedSQLPlanner`] APIs to customize SQL query planning + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion_common::{ + config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, + Result, TableReference, +}; +use datafusion_common::logical_type::TypeRelation; +use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF}; + +/// Provides the `SQL` query planner meta-data about tables and +/// functions referenced in SQL statements, without a direct dependency on other +/// DataFusion structures +pub trait ContextProvider { + /// Getter for a datasource + fn get_table_source(&self, name: TableReference) -> Result>; + + fn get_file_type(&self, _ext: &str) -> Result> { + not_impl_err!("Registered file types are not supported") + } + + /// Getter for a table function + fn get_table_function_source( + &self, + _name: &str, + _args: Vec, + ) -> Result> { + not_impl_err!("Table Functions are not supported") + } + + /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) + /// We don't directly implement this in the logical plan's ['SqlToRel`] + /// because the sql code needs access to a table that contains execution-related types that can't be a direct dependency + /// of the sql crate (namely, the `CteWorktable`). + /// The [`ContextProvider`] provides a way to "hide" this dependency. 
+ fn create_cte_work_table( + &self, + _name: &str, + _schema: SchemaRef, + ) -> Result> { + not_impl_err!("Recursive CTE is not implemented") + } + + /// Getter for a UDF description + fn get_function_meta(&self, name: &str) -> Option>; + /// Getter for a UDAF description + fn get_aggregate_meta(&self, name: &str) -> Option>; + /// Getter for a UDWF + fn get_window_meta(&self, name: &str) -> Option>; + /// Getter for system/user-defined variable type + fn get_variable_type(&self, variable_names: &[String]) -> Option; + + /// Get configuration options + fn options(&self) -> &ConfigOptions; + + /// Get all user defined scalar function names + fn udf_names(&self) -> Vec; + + /// Get all user defined aggregate function names + fn udaf_names(&self) -> Vec; + + /// Get all user defined window function names + fn udwf_names(&self) -> Vec; +} + +/// This trait allows users to customize the behavior of the SQL planner +pub trait UserDefinedSQLPlanner { + /// Plan the binary operation between two expressions, returns OriginalBinaryExpr if not possible + fn plan_binary_op( + &self, + expr: RawBinaryExpr, + _schema: &DFSchema, + ) -> Result> { + Ok(PlannerResult::Original(expr)) + } + + /// Plan the field access expression, returns OriginalFieldAccessExpr if not possible + fn plan_field_access( + &self, + expr: RawFieldAccessExpr, + _schema: &DFSchema, + ) -> Result> { + Ok(PlannerResult::Original(expr)) + } + + // Plan the array literal, returns OriginalArray if not possible + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result>> { + Ok(PlannerResult::Original(exprs)) + } +} + +/// An operator with two arguments to plan +/// +/// Note `left` and `right` are DataFusion [`Expr`]s but the `op` is the SQL AST +/// operator. +/// +/// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with +/// custom expressions. +#[derive(Debug, Clone)] +pub struct RawBinaryExpr { + pub op: sqlparser::ast::BinaryOperator, + pub left: Expr, + pub right: Expr, +} + +/// An expression with GetFieldAccess to plan +/// +/// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with +/// custom expressions. 
+#[derive(Debug, Clone)] +pub struct RawFieldAccessExpr { + pub field_access: GetFieldAccess, + pub expr: Expr, +} + +/// Result of planning a raw expr with [`UserDefinedSQLPlanner`] +#[derive(Debug, Clone)] +pub enum PlannerResult { + /// The raw expression was successfully planned as a new [`Expr`] + Planned(Expr), + /// The raw expression could not be planned, and is returned unmodified + Original(T), +} diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index b2fcb5717b3a..814127be806b 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -39,6 +39,7 @@ pub mod extract; pub mod flatten; pub mod length; pub mod make_array; +pub mod planner; pub mod position; pub mod range; pub mod remove; @@ -46,12 +47,10 @@ pub mod repeat; pub mod replace; pub mod resize; pub mod reverse; -pub mod rewrite; pub mod set_ops; pub mod sort; pub mod string; pub mod utils; - use datafusion_common::Result; use datafusion_execution::FunctionRegistry; use datafusion_expr::ScalarUDF; @@ -152,7 +151,6 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { } Ok(()) as Result<()> })?; - registry.register_function_rewrite(Arc::new(rewrite::ArrayFunctionRewriter {}))?; Ok(()) } diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs new file mode 100644 index 000000000000..5c464cc82844 --- /dev/null +++ b/datafusion/functions-array/src/planner.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
SQL planning extensions like [`ArrayFunctionPlanner`] and [`FieldAccessPlanner`] + +use datafusion_common::{utils::list_ndims, DFSchema, Result}; +use datafusion_common::logical_type::ExtensionType; +use datafusion_expr::{ + planner::{PlannerResult, RawBinaryExpr, RawFieldAccessExpr, UserDefinedSQLPlanner}, + sqlparser, AggregateFunction, Expr, ExprSchemable, GetFieldAccess, +}; +use datafusion_functions::expr_fn::get_field; + +use crate::{ + array_has::array_has_all, + expr_fn::{array_append, array_concat, array_prepend}, + extract::{array_element, array_slice}, + make_array::make_array, +}; + +#[derive(Default)] +pub struct ArrayFunctionPlanner {} + +impl UserDefinedSQLPlanner for ArrayFunctionPlanner { + fn plan_binary_op( + &self, + expr: RawBinaryExpr, + schema: &DFSchema, + ) -> Result> { + let RawBinaryExpr { op, left, right } = expr; + + if op == sqlparser::ast::BinaryOperator::StringConcat { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(left_type.physical()); + let right_list_ndims = list_ndims(right_type.physical()); + + // Rewrite string concat operator to function based on types + // if we get list || list then we rewrite it to array_concat() + // if we get list || non-list then we rewrite it to array_append() + // if we get non-list || list then we rewrite it to array_prepend() + // if we get string || string then we rewrite it to concat() + + // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. + // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. + if left_list_ndims + right_list_ndims == 0 { + // TODO: concat function ignore null, but string concat takes null into consideration + // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` + } else if left_list_ndims == right_list_ndims { + return Ok(PlannerResult::Planned(array_concat(vec![left, right]))); + } else if left_list_ndims > right_list_ndims { + return Ok(PlannerResult::Planned(array_append(left, right))); + } else if left_list_ndims < right_list_ndims { + return Ok(PlannerResult::Planned(array_prepend(left, right))); + } + } else if matches!( + op, + sqlparser::ast::BinaryOperator::AtArrow + | sqlparser::ast::BinaryOperator::ArrowAt + ) { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(left_type.physical()); + let right_list_ndims = list_ndims(right_type.physical()); + // if both are list + if left_list_ndims > 0 && right_list_ndims > 0 { + if op == sqlparser::ast::BinaryOperator::AtArrow { + // array1 @> array2 -> array_has_all(array1, array2) + return Ok(PlannerResult::Planned(array_has_all(left, right))); + } else { + // array1 <@ array2 -> array_has_all(array2, array1) + return Ok(PlannerResult::Planned(array_has_all(right, left))); + } + } + } + + Ok(PlannerResult::Original(RawBinaryExpr { op, left, right })) + } + + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result>> { + Ok(PlannerResult::Planned(make_array(exprs))) + } +} + +#[derive(Default)] +pub struct FieldAccessPlanner {} + +impl UserDefinedSQLPlanner for FieldAccessPlanner { + fn plan_field_access( + &self, + expr: RawFieldAccessExpr, + _schema: &DFSchema, + ) -> Result> { + let RawFieldAccessExpr { expr, field_access } = expr; + + match field_access { + // 
expr["field"] => get_field(expr, "field") + GetFieldAccess::NamedStructField { name } => { + Ok(PlannerResult::Planned(get_field(expr, name))) + } + // expr[idx] ==> array_element(expr, idx) + GetFieldAccess::ListIndex { key: index } => { + match expr { + // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) + Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { + Ok(PlannerResult::Planned(Expr::AggregateFunction( + datafusion_expr::expr::AggregateFunction::new( + AggregateFunction::NthValue, + agg_func + .args + .into_iter() + .chain(std::iter::once(*index)) + .collect(), + agg_func.distinct, + agg_func.filter, + agg_func.order_by, + agg_func.null_treatment, + ), + ))) + } + _ => Ok(PlannerResult::Planned(array_element(expr, *index))), + } + } + // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) + GetFieldAccess::ListRange { + start, + stop, + stride, + } => Ok(PlannerResult::Planned(array_slice( + expr, + *start, + *stop, + Some(*stride), + ))), + } + } +} + +fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { + agg_func.func_def + == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( + AggregateFunction::ArrayAgg, + ) +} diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs deleted file mode 100644 index 28bc2d5e4373..000000000000 --- a/datafusion/functions-array/src/rewrite.rs +++ /dev/null @@ -1,76 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Rewrites for using Array Functions - -use crate::array_has::array_has_all; -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::Transformed; -use datafusion_common::DFSchema; -use datafusion_common::Result; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::{BinaryExpr, Expr, Operator}; - -/// Rewrites expressions into function calls to array functions -pub(crate) struct ArrayFunctionRewriter {} - -impl FunctionRewrite for ArrayFunctionRewriter { - fn name(&self) -> &str { - "ArrayFunctionRewriter" - } - - fn rewrite( - &self, - expr: Expr, - _schema: &DFSchema, - _config: &ConfigOptions, - ) -> Result> { - let transformed = match expr { - // array1 @> array2 -> array_has_all(array1, array2) - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::AtArrow - && is_func(&left, "make_array") - && is_func(&right, "make_array") => - { - Transformed::yes(array_has_all(*left, *right)) - } - - // array1 <@ array2 -> array_has_all(array2, array1) - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::ArrowAt - && is_func(&left, "make_array") - && is_func(&right, "make_array") => - { - Transformed::yes(array_has_all(*right, *left)) - } - - _ => Transformed::no(expr), - }; - Ok(transformed) - } -} - -/// Returns true if expr is a function call to the specified named function. -/// Returns false otherwise. -fn is_func(expr: &Expr, func_name: &str) -> bool { - let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else { - return false; - }; - - func.name() == func_name -} diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 7e9dbebf1ebc..5b183b41cd20 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -345,9 +345,12 @@ impl CommonSubexprEliminate { self.try_unary_plan(expr, input, config)? .transform_data(|(mut new_expr, new_input)| { assert_eq!(new_expr.len(), 1); // passed in vec![predicate] - let new_predicate = new_expr.pop().unwrap(); - Ok(Filter::remove_aliases(new_predicate)? - .update_data(|new_predicate| (new_predicate, new_input))) + let new_predicate = new_expr + .pop() + .unwrap() + .unalias_nested() + .update_data(|new_predicate| (new_predicate, new_input)); + Ok(new_predicate) })? .map_data(|(new_predicate, new_input)| { Filter::try_new(new_predicate, Arc::new(new_input)) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 503a2fc54acf..9f5a5ced741e 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -600,7 +600,7 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { // * the current column is an expression "f" // // return the expression `d + e` (not `d + e` as f) - let input_expr = input.expr[idx].clone().unalias_nested(); + let input_expr = input.expr[idx].clone().unalias_nested().data; Ok(Transformed::yes(input_expr)) } // Unsupported type for consecutive projection merge analysis. 
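
The `UserDefinedSQLPlanner` hook introduced at the top of this section (implemented here by `ArrayFunctionPlanner` and `FieldAccessPlanner`) is also usable outside this crate. Below is a minimal, hypothetical sketch of a third-party planner, not code from this patch: `MyOpPlanner` is an invented name, and the method signature and `PlannerResult` variants are assumed to match what the diff above shows (the trait is assumed to provide default implementations for the methods not overridden, as implied by `FieldAccessPlanner` implementing only `plan_field_access`).

```rust
use datafusion_common::{DFSchema, Result};
use datafusion_expr::planner::{PlannerResult, RawBinaryExpr, UserDefinedSQLPlanner};
use datafusion_expr::sqlparser;

/// Hypothetical planner that intercepts the SQL `||` operator and leaves
/// everything else to the default planning logic.
#[derive(Default)]
struct MyOpPlanner {}

impl UserDefinedSQLPlanner for MyOpPlanner {
    fn plan_binary_op(
        &self,
        expr: RawBinaryExpr,
        _schema: &DFSchema,
    ) -> Result<PlannerResult<RawBinaryExpr>> {
        if expr.op == sqlparser::ast::BinaryOperator::StringConcat {
            // A real planner would build and return
            // PlannerResult::Planned(rewritten_expr) here, as
            // ArrayFunctionPlanner::plan_binary_op does above.
        }
        // Returning `Original` hands the expression back unchanged so other
        // planners (or the built-in fallback) can still handle it.
        Ok(PlannerResult::Original(expr))
    }
}
```

A planner like this would be registered through the `with_user_defined_planner` method added to `SqlToRel` later in this diff, e.g. `SqlToRel::new(&provider).with_user_defined_planner(Arc::new(MyOpPlanner::default()))`.
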
diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index 96c903180ed9..790e742c4221 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -145,36 +145,3 @@ pub fn compare_op_for_nested( Ok(BooleanArray::new(values, nulls)) } } - -#[cfg(test)] -mod tests { - use arrow::{ - array::{make_comparator, Array, BooleanArray, ListArray}, - buffer::NullBuffer, - compute::SortOptions, - datatypes::Int32Type, - }; - - #[test] - fn test123() { - let data = vec![ - Some(vec![Some(0), Some(1), Some(2)]), - None, - Some(vec![Some(3), None, Some(5)]), - Some(vec![Some(6), Some(7)]), - ]; - let a = ListArray::from_iter_primitive::(data); - let data = vec![ - Some(vec![Some(0), Some(1), Some(2)]), - None, - Some(vec![Some(3), None, Some(5)]), - Some(vec![Some(6), Some(7)]), - ]; - let b = ListArray::from_iter_primitive::(data); - let cmp = make_comparator(&a, &b, SortOptions::default()).unwrap(); - let len = a.len().min(b.len()); - let values = (0..len).map(|i| cmp(i, i).is_eq()).collect(); - let nulls = NullBuffer::union(a.nulls(), b.nulls()); - println!("res: {:?}", BooleanArray::new(values, nulls)); - } -} diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index ed85c80251d6..7f4ae5797d97 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -21,12 +21,14 @@ use std::fmt; use std::fmt::Formatter; -use super::{accept, ExecutionPlan, ExecutionPlanVisitor}; - use arrow_schema::SchemaRef; + use datafusion_common::display::{GraphvizBuilder, PlanType, StringifiedPlan}; +use datafusion_expr::display_schema; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; +use super::{accept, ExecutionPlan, ExecutionPlanVisitor}; + /// Options for controlling how each [`ExecutionPlan`] should format itself #[derive(Debug, Clone, Copy)] pub enum DisplayFormatType { @@ -37,12 +39,15 @@ pub enum DisplayFormatType { } /// Wraps an `ExecutionPlan` with various ways to display this plan +#[derive(Debug, Clone)] pub struct DisplayableExecutionPlan<'a> { inner: &'a dyn ExecutionPlan, /// How to show metrics show_metrics: ShowMetrics, /// If statistics should be displayed show_statistics: bool, + /// If schema should be displayed. See [`Self::set_show_schema`] + show_schema: bool, } impl<'a> DisplayableExecutionPlan<'a> { @@ -53,6 +58,7 @@ impl<'a> DisplayableExecutionPlan<'a> { inner, show_metrics: ShowMetrics::None, show_statistics: false, + show_schema: false, } } @@ -64,6 +70,7 @@ impl<'a> DisplayableExecutionPlan<'a> { inner, show_metrics: ShowMetrics::Aggregated, show_statistics: false, + show_schema: false, } } @@ -75,9 +82,19 @@ impl<'a> DisplayableExecutionPlan<'a> { inner, show_metrics: ShowMetrics::Full, show_statistics: false, + show_schema: false, } } + /// Enable display of schema + /// + /// If true, plans will be displayed with schema information at the end + /// of each line. 
The format is `schema=[[a:Int32;N, b:Int32;N, c:Int32;N]]` + pub fn set_show_schema(mut self, show_schema: bool) -> Self { + self.show_schema = show_schema; + self + } + /// Enable display of statistics pub fn set_show_statistics(mut self, show_statistics: bool) -> Self { self.show_statistics = show_statistics; @@ -105,6 +122,7 @@ impl<'a> DisplayableExecutionPlan<'a> { plan: &'a dyn ExecutionPlan, show_metrics: ShowMetrics, show_statistics: bool, + show_schema: bool, } impl<'a> fmt::Display for Wrapper<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -114,6 +132,7 @@ impl<'a> DisplayableExecutionPlan<'a> { indent: 0, show_metrics: self.show_metrics, show_statistics: self.show_statistics, + show_schema: self.show_schema, }; accept(self.plan, &mut visitor) } @@ -123,6 +142,7 @@ impl<'a> DisplayableExecutionPlan<'a> { plan: self.inner, show_metrics: self.show_metrics, show_statistics: self.show_statistics, + show_schema: self.show_schema, } } @@ -179,6 +199,7 @@ impl<'a> DisplayableExecutionPlan<'a> { plan: &'a dyn ExecutionPlan, show_metrics: ShowMetrics, show_statistics: bool, + show_schema: bool, } impl<'a> fmt::Display for Wrapper<'a> { @@ -189,6 +210,7 @@ impl<'a> DisplayableExecutionPlan<'a> { indent: 0, show_metrics: self.show_metrics, show_statistics: self.show_statistics, + show_schema: self.show_schema, }; visitor.pre_visit(self.plan)?; Ok(()) @@ -199,6 +221,7 @@ impl<'a> DisplayableExecutionPlan<'a> { plan: self.inner, show_metrics: self.show_metrics, show_statistics: self.show_statistics, + show_schema: self.show_schema, } } @@ -221,6 +244,14 @@ enum ShowMetrics { } /// Formats plans with a single line per node. +/// +/// # Example +/// +/// ```text +/// ProjectionExec: expr=[column1@0 + 2 as column1 + Int64(2)] +/// FilterExec: column1@0 = 5 +/// ValuesExec +/// ``` struct IndentVisitor<'a, 'b> { /// How to format each node t: DisplayFormatType, @@ -232,6 +263,8 @@ struct IndentVisitor<'a, 'b> { show_metrics: ShowMetrics, /// If statistics should be displayed show_statistics: bool, + /// If schema should be displayed + show_schema: bool, } impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { @@ -265,6 +298,13 @@ impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { let stats = plan.statistics().map_err(|_e| fmt::Error)?; write!(self.f, ", statistics=[{}]", stats)?; } + if self.show_schema { + write!( + self.f, + ", schema={}", + display_schema(plan.schema().as_ref()) + )?; + } writeln!(self.f)?; self.indent += 1; Ok(true) @@ -465,12 +505,13 @@ mod tests { use std::fmt::Write; use std::sync::Arc; - use super::DisplayableExecutionPlan; - use crate::{DisplayAs, ExecutionPlan, PlanProperties}; - use datafusion_common::{DataFusionError, Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + use crate::{DisplayAs, ExecutionPlan, PlanProperties}; + + use super::DisplayableExecutionPlan; + #[derive(Debug, Clone, Copy)] enum TestStatsExecPlan { Panic, diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 47901591c8c9..0bf66bc6e522 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -586,7 +586,7 @@ impl Debug for ExternalSorter { } } -pub(crate) fn sort_batch( +pub fn sort_batch( batch: &RecordBatch, expressions: &[PhysicalSortExpr], fetch: Option, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 004d7320e21b..7f4d6b9d927e 100644 --- 
a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -669,9 +669,11 @@ message PlanType { datafusion_common.EmptyMessage FinalLogicalPlan = 3; datafusion_common.EmptyMessage InitialPhysicalPlan = 4; datafusion_common.EmptyMessage InitialPhysicalPlanWithStats = 9; + datafusion_common.EmptyMessage InitialPhysicalPlanWithSchema = 11; OptimizedPhysicalPlanType OptimizedPhysicalPlan = 5; datafusion_common.EmptyMessage FinalPhysicalPlan = 6; datafusion_common.EmptyMessage FinalPhysicalPlanWithStats = 10; + datafusion_common.EmptyMessage FinalPhysicalPlanWithSchema = 12; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index ebfa783f8561..33cd634c4aad 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -16140,6 +16140,9 @@ impl serde::Serialize for PlanType { plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats(v) => { struct_ser.serialize_field("InitialPhysicalPlanWithStats", v)?; } + plan_type::PlanTypeEnum::InitialPhysicalPlanWithSchema(v) => { + struct_ser.serialize_field("InitialPhysicalPlanWithSchema", v)?; + } plan_type::PlanTypeEnum::OptimizedPhysicalPlan(v) => { struct_ser.serialize_field("OptimizedPhysicalPlan", v)?; } @@ -16149,6 +16152,9 @@ impl serde::Serialize for PlanType { plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats(v) => { struct_ser.serialize_field("FinalPhysicalPlanWithStats", v)?; } + plan_type::PlanTypeEnum::FinalPhysicalPlanWithSchema(v) => { + struct_ser.serialize_field("FinalPhysicalPlanWithSchema", v)?; + } } } struct_ser.end() @@ -16168,9 +16174,11 @@ impl<'de> serde::Deserialize<'de> for PlanType { "FinalLogicalPlan", "InitialPhysicalPlan", "InitialPhysicalPlanWithStats", + "InitialPhysicalPlanWithSchema", "OptimizedPhysicalPlan", "FinalPhysicalPlan", "FinalPhysicalPlanWithStats", + "FinalPhysicalPlanWithSchema", ]; #[allow(clippy::enum_variant_names)] @@ -16182,9 +16190,11 @@ impl<'de> serde::Deserialize<'de> for PlanType { FinalLogicalPlan, InitialPhysicalPlan, InitialPhysicalPlanWithStats, + InitialPhysicalPlanWithSchema, OptimizedPhysicalPlan, FinalPhysicalPlan, FinalPhysicalPlanWithStats, + FinalPhysicalPlanWithSchema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16213,9 +16223,11 @@ impl<'de> serde::Deserialize<'de> for PlanType { "FinalLogicalPlan" => Ok(GeneratedField::FinalLogicalPlan), "InitialPhysicalPlan" => Ok(GeneratedField::InitialPhysicalPlan), "InitialPhysicalPlanWithStats" => Ok(GeneratedField::InitialPhysicalPlanWithStats), + "InitialPhysicalPlanWithSchema" => Ok(GeneratedField::InitialPhysicalPlanWithSchema), "OptimizedPhysicalPlan" => Ok(GeneratedField::OptimizedPhysicalPlan), "FinalPhysicalPlan" => Ok(GeneratedField::FinalPhysicalPlan), "FinalPhysicalPlanWithStats" => Ok(GeneratedField::FinalPhysicalPlanWithStats), + "FinalPhysicalPlanWithSchema" => Ok(GeneratedField::FinalPhysicalPlanWithSchema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16285,6 +16297,13 @@ impl<'de> serde::Deserialize<'de> for PlanType { return Err(serde::de::Error::duplicate_field("InitialPhysicalPlanWithStats")); } plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats) +; + } + GeneratedField::InitialPhysicalPlanWithSchema => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialPhysicalPlanWithSchema")); + } + plan_type_enum__ = 
map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlanWithSchema) ; } GeneratedField::OptimizedPhysicalPlan => { @@ -16306,6 +16325,13 @@ impl<'de> serde::Deserialize<'de> for PlanType { return Err(serde::de::Error::duplicate_field("FinalPhysicalPlanWithStats")); } plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats) +; + } + GeneratedField::FinalPhysicalPlanWithSchema => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalPhysicalPlanWithSchema")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlanWithSchema) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 1a3514dbd4f7..83b8b738c4f4 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -947,7 +947,10 @@ pub struct OptimizedPhysicalPlanType { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PlanType { - #[prost(oneof = "plan_type::PlanTypeEnum", tags = "1, 7, 8, 2, 3, 4, 9, 5, 6, 10")] + #[prost( + oneof = "plan_type::PlanTypeEnum", + tags = "1, 7, 8, 2, 3, 4, 9, 11, 5, 6, 10, 12" + )] pub plan_type_enum: ::core::option::Option, } /// Nested message and enum types in `PlanType`. @@ -969,12 +972,16 @@ pub mod plan_type { InitialPhysicalPlan(super::super::datafusion_common::EmptyMessage), #[prost(message, tag = "9")] InitialPhysicalPlanWithStats(super::super::datafusion_common::EmptyMessage), + #[prost(message, tag = "11")] + InitialPhysicalPlanWithSchema(super::super::datafusion_common::EmptyMessage), #[prost(message, tag = "5")] OptimizedPhysicalPlan(super::OptimizedPhysicalPlanType), #[prost(message, tag = "6")] FinalPhysicalPlan(super::super::datafusion_common::EmptyMessage), #[prost(message, tag = "10")] FinalPhysicalPlanWithStats(super::super::datafusion_common::EmptyMessage), + #[prost(message, tag = "12")] + FinalPhysicalPlanWithSchema(super::super::datafusion_common::EmptyMessage), } } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2a06c0e0c877..0842b3f4cbc9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -37,6 +37,9 @@ use datafusion_expr::{ }; use datafusion_proto_common::{from_proto::FromOptionalField, FromProtoError as Error}; +use crate::protobuf::plan_type::PlanTypeEnum::{ + FinalPhysicalPlanWithSchema, InitialPhysicalPlanWithSchema, +}; use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ @@ -122,6 +125,7 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { FinalLogicalPlan(_) => PlanType::FinalLogicalPlan, InitialPhysicalPlan(_) => PlanType::InitialPhysicalPlan, InitialPhysicalPlanWithStats(_) => PlanType::InitialPhysicalPlanWithStats, + InitialPhysicalPlanWithSchema(_) => PlanType::InitialPhysicalPlanWithSchema, OptimizedPhysicalPlan(OptimizedPhysicalPlanType { optimizer_name }) => { PlanType::OptimizedPhysicalPlan { optimizer_name: optimizer_name.clone(), @@ -129,6 +133,7 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { } FinalPhysicalPlan(_) => PlanType::FinalPhysicalPlan, FinalPhysicalPlanWithStats(_) => PlanType::FinalPhysicalPlanWithStats, + FinalPhysicalPlanWithSchema(_) => PlanType::FinalPhysicalPlanWithSchema, }, plan: 
Arc::new(stringified_plan.plan.clone()), } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 3a1db1defdd9..ccc64119c8a1 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -34,9 +34,9 @@ use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, - InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, - OptimizedPhysicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithSchema, FinalPhysicalPlanWithStats, + InitialLogicalPlan, InitialPhysicalPlan, InitialPhysicalPlanWithSchema, + InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, }, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, @@ -96,9 +96,15 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { PlanType::InitialPhysicalPlanWithStats => Some(protobuf::PlanType { plan_type_enum: Some(InitialPhysicalPlanWithStats(EmptyMessage {})), }), + PlanType::InitialPhysicalPlanWithSchema => Some(protobuf::PlanType { + plan_type_enum: Some(InitialPhysicalPlanWithSchema(EmptyMessage {})), + }), PlanType::FinalPhysicalPlanWithStats => Some(protobuf::PlanType { plan_type_enum: Some(FinalPhysicalPlanWithStats(EmptyMessage {})), }), + PlanType::FinalPhysicalPlanWithSchema => Some(protobuf::PlanType { + plan_type_enum: Some(FinalPhysicalPlanWithSchema(EmptyMessage {})), + }), }, plan: stringified_plan.plan.to_string(), } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index b9a8f7d35245..a564cb12d924 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -17,7 +17,8 @@ use arrow_schema::{DataType, TimeUnit}; use datafusion_common::logical_type::signature::LogicalType; -use datafusion_common::utils::list_ndims; +use datafusion_expr::planner::PlannerResult; +use datafusion_expr::planner::RawFieldAccessExpr; use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value}; use datafusion_common::{ @@ -28,8 +29,8 @@ use datafusion_common::logical_type::{TypeRelation, ExtensionType}; use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, - GetFieldAccess, Like, Literal, Operator, TryCast, + lit, Between, BinaryExpr, Cast, Expr, ExprSchemable, GetFieldAccess, Like, Literal, + Operator, TryCast, }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; @@ -53,7 +54,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { enum StackEntry { SQLExpr(Box), - Operator(Operator), + Operator(sqlparser::ast::BinaryOperator), } // Virtual stack machine to convert SQLExpr to Expr @@ -70,7 +71,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::BinaryOp { left, op, right } => { // Note the order that we push the entries to the stack // is important. We want to visit the left node first. 
- let op = self.parse_sql_binary_op(op)?; stack.push(StackEntry::Operator(op)); stack.push(StackEntry::SQLExpr(right)); stack.push(StackEntry::SQLExpr(left)); @@ -101,63 +101,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn build_logical_expr( &self, - op: Operator, + op: sqlparser::ast::BinaryOperator, left: Expr, right: Expr, schema: &DFSchema, ) -> Result { - // Rewrite string concat operator to function based on types - // if we get list || list then we rewrite it to array_concat() - // if we get list || non-list then we rewrite it to array_append() - // if we get non-list || list then we rewrite it to array_prepend() - // if we get string || string then we rewrite it to concat() - if op == Operator::StringConcat { - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; - let left_list_ndims = list_ndims(&left_type.physical()); - let right_list_ndims = list_ndims(&right_type.physical()); - - // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. - // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. - if left_list_ndims + right_list_ndims == 0 { - // TODO: concat function ignore null, but string concat takes null into consideration - // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` - } else if left_list_ndims == right_list_ndims { - if let Some(udf) = self.context_provider.get_function_meta("array_concat") - { - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![left, right], - ))); - } else { - return internal_err!("array_concat not found"); - } - } else if left_list_ndims > right_list_ndims { - if let Some(udf) = self.context_provider.get_function_meta("array_append") - { - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![left, right], - ))); - } else { - return internal_err!("array_append not found"); + // try extension planers + let mut binary_expr = datafusion_expr::planner::RawBinaryExpr { op, left, right }; + for planner in self.planners.iter() { + match planner.plan_binary_op(binary_expr, schema)? { + PlannerResult::Planned(expr) => { + return Ok(expr); } - } else if left_list_ndims < right_list_ndims { - if let Some(udf) = - self.context_provider.get_function_meta("array_prepend") - { - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![left, right], - ))); - } else { - return internal_err!("array_append not found"); + PlannerResult::Original(expr) => { + binary_expr = expr; } } } + + let datafusion_expr::planner::RawBinaryExpr { op, left, right } = binary_expr; Ok(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), - op, + self.parse_sql_binary_op(op)?, Box::new(right), ))) } @@ -245,7 +210,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let expr = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - let get_field_access = match *subscript { + let field_access = match *subscript { Subscript::Index { index } => { // index can be a name, in which case it is a named field access match index { @@ -314,7 +279,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - self.plan_field_access(expr, get_field_access) + let mut field_access_expr = RawFieldAccessExpr { expr, field_access }; + for planner in self.planners.iter() { + match planner.plan_field_access(field_access_expr, schema)? 
{ + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(expr) => { + field_access_expr = expr; + } + } + } + + not_impl_err!("GetFieldAccess not supported by UserDefinedExtensionPlanners: {field_access_expr:?}") } SQLExpr::CompoundIdentifier(ids) => { @@ -649,36 +624,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - /// Simplifies an expression like `ARRAY_AGG(expr)[index]` to `NTH_VALUE(expr, index)` - /// - /// returns Some(Expr) if the expression was simplified, otherwise None - /// TODO: this should likely be done in ArrayAgg::simplify when it is moved to a UDAF - fn simplify_array_index_expr(expr: &Expr, index: &Expr) -> Option { - fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func_def - == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( - AggregateFunction::ArrayAgg, - ) - } - match expr { - Expr::AggregateFunction(agg_func) if is_array_agg(agg_func) => { - let mut new_args = agg_func.args.clone(); - new_args.push(index.clone()); - Some(Expr::AggregateFunction( - datafusion_expr::expr::AggregateFunction::new( - AggregateFunction::NthValue, - new_args, - agg_func.distinct, - agg_func.filter.clone(), - agg_func.order_by.clone(), - agg_func.null_treatment, - ), - )) - } - _ => None, - } - } - /// Parses a struct(..) expression fn parse_struct( &self, @@ -964,58 +909,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = vec![fullstr, substr]; Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } - - /// Given an expression and the field to access, creates a new expression for accessing that field - fn plan_field_access( - &self, - expr: Expr, - get_field_access: GetFieldAccess, - ) -> Result { - match get_field_access { - GetFieldAccess::NamedStructField { name } => { - if let Some(udf) = self.context_provider.get_function_meta("get_field") { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![expr, lit(name)], - ))) - } else { - internal_err!("get_field not found") - } - } - // expr[idx] ==> array_element(expr, idx) - GetFieldAccess::ListIndex { key } => { - // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) - if let Some(simplified) = Self::simplify_array_index_expr(&expr, &key) { - Ok(simplified) - } else if let Some(udf) = - self.context_provider.get_function_meta("array_element") - { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![expr, *key], - ))) - } else { - internal_err!("get_field not found") - } - } - // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) - GetFieldAccess::ListRange { - start, - stop, - stride, - } => { - if let Some(udf) = self.context_provider.get_function_meta("array_slice") - { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![expr, *start, *stop, *stride], - ))) - } else { - internal_err!("array_slice not found") - } - } - } - } } #[cfg(test)] diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 1b5196bb4c95..cfae655117ee 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -19,9 +19,10 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use datafusion_common::{ - not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, + internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{BinaryExpr, 
Placeholder, ScalarFunction}; +use datafusion_expr::expr::{BinaryExpr, Placeholder}; +use datafusion_expr::planner::PlannerResult; use datafusion_expr::{lit, Expr, Operator}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; @@ -130,6 +131,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } + // IMPORTANT: Keep sql_array_literal's function body small to prevent stack overflow + // This function is recursively called, potentially leading to deep call stacks. pub(super) fn sql_array_literal( &self, elements: Vec, @@ -142,13 +145,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect::>>()?; - if let Some(udf) = self.context_provider.get_function_meta("make_array") { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(udf, values))) - } else { - not_impl_err!( - "array_expression featrue is disable, So should implement make_array UDF by yourself" - ) + self.try_plan_array_literal(values, schema) + } + + fn try_plan_array_literal( + &self, + values: Vec, + schema: &DFSchema, + ) -> Result { + let mut exprs = values; + for planner in self.planners.iter() { + match planner.plan_array_literal(exprs, schema)? { + PlannerResult::Planned(expr) => { + return Ok(expr); + } + PlannerResult::Original(values) => exprs = values, + } } + + internal_err!("Expected a simplified result, but none was found") } /// Convert a SQL interval expression to a DataFusion logical plan diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 46c8818edcb9..74237b55b71a 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -21,17 +21,15 @@ use std::sync::Arc; use std::vec; use arrow_schema::*; -use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError, }; -use datafusion_expr::WindowUDF; +use datafusion_expr::planner::UserDefinedSQLPlanner; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias}; -use datafusion_common::config::ConfigOptions; use datafusion_common::TableReference; use datafusion_common::{ not_impl_err, plan_err, unqualified_field_not_found, DFSchema, DataFusionError, @@ -43,64 +41,11 @@ use datafusion_common::logical_type::TypeRelation; use datafusion_common::logical_type::schema::LogicalSchema; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::utils::find_column_exprs; -use datafusion_expr::TableSource; -use datafusion_expr::{col, AggregateUDF, Expr, ScalarUDF}; +use datafusion_expr::{col, Expr}; use crate::utils::make_decimal_type; -/// The ContextProvider trait allows the query planner to obtain meta-data about tables and -/// functions referenced in SQL statements -pub trait ContextProvider { - /// Getter for a datasource - fn get_table_source(&self, name: TableReference) -> Result>; - - fn get_file_type(&self, _ext: &str) -> Result> { - not_impl_err!("Registered file types are not supported") - } - - /// Getter for a table function - fn get_table_function_source( - &self, - _name: &str, - _args: Vec, - ) -> Result> { - not_impl_err!("Table Functions are not supported") - } - - /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) - /// We don't directly implement this in the logical plan's ['SqlToRel`] - /// because the sql 
code needs access to a table that contains execution-related types that can't be a direct dependency - /// of the sql crate (namely, the `CteWorktable`). - /// The [`ContextProvider`] provides a way to "hide" this dependency. - fn create_cte_work_table( - &self, - _name: &str, - _schema: SchemaRef, - ) -> Result> { - not_impl_err!("Recursive CTE is not implemented") - } - - /// Getter for a UDF description - fn get_function_meta(&self, name: &str) -> Option>; - /// Getter for a UDAF description - fn get_aggregate_meta(&self, name: &str) -> Option>; - /// Getter for a UDWF - fn get_window_meta(&self, name: &str) -> Option>; - /// Getter for system/user-defined variable type - fn get_variable_type(&self, variable_names: &[String]) -> Option; - - /// Get configuration options - fn options(&self) -> &ConfigOptions; - - /// Get all user defined scalar function names - fn udf_names(&self) -> Vec; - - /// Get all user defined aggregate function names - fn udaf_names(&self) -> Vec; - - /// Get all user defined window function names - fn udwf_names(&self) -> Vec; -} +pub use datafusion_expr::planner::ContextProvider; /// SQL parser options #[derive(Debug)] @@ -245,6 +190,8 @@ pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, pub(crate) normalizer: IdentNormalizer, + /// user defined planner extensions + pub(crate) planners: Vec>, } impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -253,13 +200,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Self::new_with_options(context_provider, ParserOptions::default()) } + /// add an user defined planner + pub fn with_user_defined_planner( + mut self, + planner: Arc, + ) -> Self { + self.planners.push(planner); + self + } + /// Create a new query planner pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self { let normalize = options.enable_ident_normalization; + SqlToRel { context_provider, options, normalizer: IdentNormalizer::new(normalize), + planners: vec![], } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 5ac9d9241071..e85614ec555b 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -1228,22 +1228,6 @@ fn select_binary_expr_nested() { quick_test(sql, expected); } -#[test] -fn select_at_arrow_operator() { - let sql = "SELECT left @> right from array"; - let expected = "Projection: array.left @> array.right\ - \n TableScan: array"; - quick_test(sql, expected); -} - -#[test] -fn select_arrow_at_operator() { - let sql = "SELECT left <@ right from array"; - let expected = "Projection: array.left <@ array.right\ - \n TableScan: array"; - quick_test(sql, expected); -} - #[test] fn select_wildcard_with_groupby() { quick_test( diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 3b1f0dfd6d89..28ef6fe9adb6 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -51,7 +51,7 @@ object_store = { workspace = true } postgres-protocol = { version = "0.6.4", optional = true } postgres-types = { version = "0.2.4", optional = true } rust_decimal = { version = "1.27.0" } -sqllogictest = "0.20.0" +sqllogictest = "0.21.0" sqlparser = { workspace = true } tempfile = { workspace = true } thiserror = { workspace = true } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 77d1a9da1f55..7917f1d78da8 100644 --- 
a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -6076,6 +6076,17 @@ select make_array(1,2,3) @> make_array(1,3), ---- true false true false false false true +# Make sure it is rewritten to function array_has_all() +query TT +explain select [1,2,3] @> [1,3]; +---- +logical_plan +01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3))) +02)--EmptyRelation +physical_plan +01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))] +02)--PlaceholderRowExec + # array containment operator with scalars #2 (arrow at) query BBBBBBB select make_array(1,3) <@ make_array(1,2,3), @@ -6088,6 +6099,17 @@ select make_array(1,3) <@ make_array(1,2,3), ---- true false true false false false true +# Make sure it is rewritten to function array_has_all() +query TT +explain select [1,3] <@ [1,2,3]; +---- +logical_plan +01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3))) +02)--EmptyRelation +physical_plan +01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))] +02)--PlaceholderRowExec + ### Array casting tests diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 96e73a591678..b850760b8734 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -241,6 +241,7 @@ logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true initial_physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +initial_physical_plan_with_schema CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, schema=[a:Int32;N, b:Int32;N, c:Int32;N] physical_plan after OutputRequirements 01)OutputRequirementExec 02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true @@ -259,6 +260,23 @@ physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +physical_plan_with_schema CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, schema=[a:Int32;N, b:Int32;N, c:Int32;N] + +### tests for EXPLAIN with display schema enabled + +statement ok +set datafusion.explain.show_schema = true; + +# test EXPLAIN VERBOSE +query TT +EXPLAIN SELECT a, b, c FROM simple_explain_test; +---- +logical_plan TableScan: simple_explain_test projection=[a, b, c] +physical_plan CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, schema=[a:Int32;N, b:Int32;N, c:Int32;N] + + +statement ok +set datafusion.explain.show_schema = false; ### tests for EXPLAIN with display statistics enabled @@ -297,6 +315,9 @@ EXPLAIN VERBOSE SELECT * FROM alltypes_plain limit 10; initial_physical_plan 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +initial_physical_plan_with_schema +01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] +02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] physical_plan after OutputRequirements 01)OutputRequirementExec, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] @@ -319,6 +340,9 @@ physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan_with_schema +01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] +02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, 
smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] statement ok @@ -334,6 +358,9 @@ initial_physical_plan initial_physical_plan_with_stats 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +initial_physical_plan_with_schema +01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] +02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] physical_plan after OutputRequirements 01)OutputRequirementExec 02)--GlobalLimitExec: skip=0, fetch=10 @@ -359,6 +386,9 @@ physical_plan physical_plan_with_stats 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan_with_schema +01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] +02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] statement ok diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index ee64f772917c..acd465a0c021 100644 --- 
a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -213,6 +213,7 @@ datafusion.execution.target_partitions 7 datafusion.execution.time_zone +00:00 datafusion.explain.logical_plan_only false datafusion.explain.physical_plan_only false +datafusion.explain.show_schema false datafusion.explain.show_sizes true datafusion.explain.show_statistics false datafusion.optimizer.allow_symmetric_joins_without_pruning true @@ -296,6 +297,7 @@ datafusion.execution.target_partitions 7 Number of partitions for query executio datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour datafusion.explain.logical_plan_only false When set to true, the explain statement will only print logical plans datafusion.explain.physical_plan_only false When set to true, the explain statement will only print physical plans +datafusion.explain.show_schema false When set to true, the explain statement will print schema information datafusion.explain.show_sizes true When set to true, the explain statement will print the partition sizes datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. 
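
The slt changes above exercise the new `datafusion.explain.show_schema` option (toggled in SQL with `SET datafusion.explain.show_schema = true;`), which is backed by the `set_show_schema` builder added to `DisplayableExecutionPlan` earlier in this diff. As a rough illustration only (the helper function name is made up and not part of this patch), a caller holding a physical plan could render it with per-node schemas like so:

```rust
use datafusion_physical_plan::{display::DisplayableExecutionPlan, ExecutionPlan};

/// Hypothetical helper: render a plan with one line per node, appending the
/// `schema=[...]` annotation enabled by `set_show_schema(true)`.
fn print_plan_with_schema(plan: &dyn ExecutionPlan) {
    let displayable = DisplayableExecutionPlan::new(plan).set_show_schema(true);
    // `indent(false)` produces the non-verbose, one-line-per-node form.
    println!("{}", displayable.indent(false));
}
```
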
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 079d9d950536..258d4dd6a7d3 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -23,7 +23,8 @@ use datafusion::common::logical_type::field::LogicalField; use datafusion::common::logical_type::schema::LogicalSchema; use datafusion::common::logical_type::{TypeRelation, ExtensionType}; use datafusion::common::{ - not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, + not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, + substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, }; use substrait::proto::expression::literal::IntervalDayToSecond; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; @@ -33,8 +34,7 @@ use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case, - EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF, - Values, + EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, Values, }; use datafusion::logical_expr::{ @@ -60,7 +60,7 @@ use substrait::proto::{ reference_segment::ReferenceType::StructField, window_function::bound as SubstraitBound, window_function::bound::Kind as BoundKind, window_function::Bound, - MaskExpression, RexType, + window_function::BoundsType, MaskExpression, RexType, }, extensions::simple_extension_declaration::MappingType, function_argument::ArgType, @@ -74,7 +74,6 @@ use substrait::proto::{ use substrait::proto::{FunctionArgument, SortField}; use datafusion::arrow::array::GenericListArray; -use datafusion::common::plan_err; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; use std::collections::HashMap; @@ -92,12 +91,6 @@ use crate::variation_const::{ UNSIGNED_INTEGER_TYPE_VARIATION_REF, }; -enum ScalarFunctionType { - Op(Operator), - Expr(BuiltinExprBuilder), - Udf(Arc), -} - pub fn name_to_op(name: &str) -> Result { match name { "equal" => Ok(Operator::Eq), @@ -131,28 +124,6 @@ pub fn name_to_op(name: &str) -> Result { } } -fn scalar_function_type_from_str( - ctx: &SessionContext, - name: &str, -) -> Result { - let s = ctx.state(); - let name = substrait_fun_name(name); - - if let Some(func) = s.scalar_functions().get(name) { - return Ok(ScalarFunctionType::Udf(func.to_owned())); - } - - if let Ok(op) = name_to_op(name) { - return Ok(ScalarFunctionType::Op(op)); - } - - if let Some(builder) = BuiltinExprBuilder::try_from_name(name) { - return Ok(ScalarFunctionType::Expr(builder)); - } - - not_impl_err!("Unsupported function name: {name:?}") -} - pub fn substrait_fun_name(name: &str) -> &str { let name = match name.rsplit_once(':') { // Since 0.32.0, Substrait requires the function names to be in a compound format @@ -975,7 +946,7 @@ pub async fn from_substrait_rex_vec( } /// Convert Substrait FunctionArguments to DataFusion Exprs -pub async fn from_substriat_func_args( +pub async fn from_substrait_func_args( ctx: &SessionContext, arguments: &Vec, input_schema: &DFSchema, @@ -987,9 +958,7 @@ pub async fn from_substriat_func_args( Some(ArgType::Value(e)) => { from_substrait_rex(ctx, e, input_schema, extensions).await } - _ => { - not_impl_err!("Aggregated function argument non-Value type not supported") - } + _ => 
not_impl_err!("Function argument non-Value type not supported"), }; args.push(arg_expr?.as_ref().clone()); } @@ -1006,18 +975,8 @@ pub async fn from_substrait_agg_func( order_by: Option>, distinct: bool, ) -> Result> { - let mut args: Vec = vec![]; - for arg in &f.arguments { - let arg_expr = match &arg.arg_type { - Some(ArgType::Value(e)) => { - from_substrait_rex(ctx, e, input_schema, extensions).await - } - _ => { - not_impl_err!("Aggregated function argument non-Value type not supported") - } - }; - args.push(arg_expr?.as_ref().clone()); - } + let args = + from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?; let Some(function_name) = extensions.get(&f.function_reference) else { return plan_err!( @@ -1025,14 +984,16 @@ pub async fn from_substrait_agg_func( f.function_reference ); }; - // function_name.split(':').next().unwrap_or(function_name); + let function_name = substrait_fun_name((**function_name).as_str()); // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { // deal with situation that count(*) got no arguments - if fun.name() == "count" && args.is_empty() { - args.push(Expr::Literal(ScalarValue::Int64(Some(1)))); - } + let args = if fun.name() == "count" && args.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)))] + } else { + args + }; Ok(Arc::new(Expr::AggregateFunction( expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), @@ -1044,7 +1005,7 @@ pub async fn from_substrait_agg_func( ))) } else { not_impl_err!( - "Aggregated function {} is not supported: function anchor = {:?}", + "Aggregate function {} is not supported: function anchor = {:?}", function_name, f.function_reference ) @@ -1148,84 +1109,40 @@ pub async fn from_substrait_rex( }))) } Some(RexType::ScalarFunction(f)) => { - let fn_name = extensions.get(&f.function_reference).ok_or_else(|| { - DataFusionError::NotImplemented(format!( - "Aggregated function not found: function reference = {:?}", + let Some(fn_name) = extensions.get(&f.function_reference) else { + return plan_err!( + "Scalar function not found: function reference = {:?}", f.function_reference - )) - })?; - - // Convert function arguments from Substrait to DataFusion - async fn decode_arguments( - ctx: &SessionContext, - input_schema: &DFSchema, - extensions: &HashMap, - function_args: &[FunctionArgument], - ) -> Result> { - let mut args = Vec::with_capacity(function_args.len()); - for arg in function_args { - let arg_expr = match &arg.arg_type { - Some(ArgType::Value(e)) => { - from_substrait_rex(ctx, e, input_schema, extensions).await - } - _ => not_impl_err!( - "Aggregated function argument non-Value type not supported" - ), - }?; - args.push(arg_expr.as_ref().clone()); - } - Ok(args) - } + ); + }; + let fn_name = substrait_fun_name(fn_name); - let fn_type = scalar_function_type_from_str(ctx, fn_name)?; - match fn_type { - ScalarFunctionType::Udf(fun) => { - let args = decode_arguments( - ctx, - input_schema, - extensions, - f.arguments.as_slice(), - ) + let args = + from_substrait_func_args(ctx, &f.arguments, input_schema, extensions) .await?; - Ok(Arc::new(Expr::ScalarFunction( - expr::ScalarFunction::new_udf(fun, args), - ))) - } - ScalarFunctionType::Op(op) => { - if f.arguments.len() != 2 { - return not_impl_err!( - "Expect two arguments for binary operator {op:?}" - ); - } - let lhs = &f.arguments[0].arg_type; - let rhs = &f.arguments[1].arg_type; - - match (lhs, rhs) { - (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { - 
Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { - left: Box::new( - from_substrait_rex(ctx, l, input_schema, extensions) - .await? - .as_ref() - .clone(), - ), - op, - right: Box::new( - from_substrait_rex(ctx, r, input_schema, extensions) - .await? - .as_ref() - .clone(), - ), - }))) - } - (l, r) => not_impl_err!( - "Invalid arguments for binary expression: {l:?} and {r:?}" - ), - } - } - ScalarFunctionType::Expr(builder) => { - builder.build(ctx, f, input_schema, extensions).await + + // try to first match the requested function into registered udfs, then built-in ops + // and finally built-in expressions + if let Some(func) = ctx.state().scalar_functions().get(fn_name) { + Ok(Arc::new(Expr::ScalarFunction( + expr::ScalarFunction::new_udf(func.to_owned(), args), + ))) + } else if let Ok(op) = name_to_op(fn_name) { + if args.len() != 2 { + return not_impl_err!( + "Expect two arguments for binary operator {op:?}" + ); } + + Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { + left: Box::new(args[0].to_owned()), + op, + right: Box::new(args[1].to_owned()), + }))) + } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { + builder.build(ctx, f, input_schema, extensions).await + } else { + not_impl_err!("Unsupported function name: {fn_name:?}") } } Some(RexType::Literal(lit)) => { @@ -1250,36 +1167,50 @@ pub async fn from_substrait_rex( None => substrait_err!("Cast expression without output type is not allowed"), }, Some(RexType::WindowFunction(window)) => { - let fun = match extensions.get(&window.function_reference) { - Some(function_name) => { - // check udaf - match ctx.udaf(function_name) { - Ok(udaf) => { - Ok(Some(WindowFunctionDefinition::AggregateUDF(udaf))) - } - Err(_) => Ok(find_df_window_func(function_name)), - } - } - None => not_impl_err!( - "Window function not found: function anchor = {:?}", - &window.function_reference - ), + let Some(fn_name) = extensions.get(&window.function_reference) else { + return plan_err!( + "Window function not found: function reference = {:?}", + window.function_reference + ); }; + let fn_name = substrait_fun_name(fn_name); + + // check udaf first, then built-in functions + let fun = match ctx.udaf(fn_name) { + Ok(udaf) => Ok(WindowFunctionDefinition::AggregateUDF(udaf)), + Err(_) => find_df_window_func(fn_name).ok_or_else(|| { + not_impl_datafusion_err!( + "Window function {} is not supported: function anchor = {:?}", + fn_name, + window.function_reference + ) + }), + }?; + let order_by = from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) .await?; - // Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units - // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary - // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row - // TODO: Consider the cases where window frame is specified in query and is different from default - let units = if order_by.is_empty() { - WindowFrameUnits::Rows - } else { - WindowFrameUnits::Range - }; + + let bound_units = + match BoundsType::try_from(window.bounds_type).map_err(|e| { + plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) + })? 
{ + BoundsType::Rows => WindowFrameUnits::Rows, + BoundsType::Range => WindowFrameUnits::Range, + BoundsType::Unspecified => { + // If the plan does not specify the bounds type, then we use a simple logic to determine the units + // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary + // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row + if order_by.is_empty() { + WindowFrameUnits::Rows + } else { + WindowFrameUnits::Range + } + } + }; Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction { - fun: fun?.unwrap(), - args: from_substriat_func_args( + fun, + args: from_substrait_func_args( ctx, &window.arguments, input_schema, @@ -1295,7 +1226,7 @@ pub async fn from_substrait_rex( .await?, order_by, window_frame: datafusion::logical_expr::WindowFrame::new_bounds( - units, + bound_units, from_substrait_bound(&window.lower_bound, true)?, from_substrait_bound(&window.upper_bound, false)?, ), @@ -1532,7 +1463,7 @@ fn from_substrait_struct_type( let field = Field::new( next_struct_field_name(i, dfs_names, name_idx)?, from_substrait_type(f, dfs_names, name_idx)?.physical().clone(), - is_substrait_type_nullable(f)?, + true, // We assume everything to be nullable since that's easier than ensuring it matches ); fields.push(field); } @@ -1576,47 +1507,6 @@ fn from_substrait_named_struct(base_schema: &NamedStruct) -> Result Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields?))?)) } -fn is_substrait_type_nullable(dtype: &Type) -> Result { - fn is_nullable(nullability: i32) -> bool { - nullability != substrait::proto::r#type::Nullability::Required as i32 - } - - let nullable = match dtype - .kind - .as_ref() - .ok_or_else(|| substrait_datafusion_err!("Type must contain Kind"))? 
- { - r#type::Kind::Bool(val) => is_nullable(val.nullability), - r#type::Kind::I8(val) => is_nullable(val.nullability), - r#type::Kind::I16(val) => is_nullable(val.nullability), - r#type::Kind::I32(val) => is_nullable(val.nullability), - r#type::Kind::I64(val) => is_nullable(val.nullability), - r#type::Kind::Fp32(val) => is_nullable(val.nullability), - r#type::Kind::Fp64(val) => is_nullable(val.nullability), - r#type::Kind::String(val) => is_nullable(val.nullability), - r#type::Kind::Binary(val) => is_nullable(val.nullability), - r#type::Kind::Timestamp(val) => is_nullable(val.nullability), - r#type::Kind::Date(val) => is_nullable(val.nullability), - r#type::Kind::Time(val) => is_nullable(val.nullability), - r#type::Kind::IntervalYear(val) => is_nullable(val.nullability), - r#type::Kind::IntervalDay(val) => is_nullable(val.nullability), - r#type::Kind::TimestampTz(val) => is_nullable(val.nullability), - r#type::Kind::Uuid(val) => is_nullable(val.nullability), - r#type::Kind::FixedChar(val) => is_nullable(val.nullability), - r#type::Kind::Varchar(val) => is_nullable(val.nullability), - r#type::Kind::FixedBinary(val) => is_nullable(val.nullability), - r#type::Kind::Decimal(val) => is_nullable(val.nullability), - r#type::Kind::PrecisionTimestamp(val) => is_nullable(val.nullability), - r#type::Kind::PrecisionTimestampTz(val) => is_nullable(val.nullability), - r#type::Kind::Struct(val) => is_nullable(val.nullability), - r#type::Kind::List(val) => is_nullable(val.nullability), - r#type::Kind::Map(val) => is_nullable(val.nullability), - r#type::Kind::UserDefined(val) => is_nullable(val.nullability), - r#type::Kind::UserDefinedTypeReference(_) => true, // not implemented, assume nullable - }; - Ok(nullable) -} - fn from_substrait_bound( bound: &Option, is_lower: bool, @@ -1796,8 +1686,9 @@ fn from_substrait_literal( for (i, field) in s.fields.iter().enumerate() { let name = next_struct_field_name(i, dfs_names, name_idx)?; let sv = from_substrait_literal(field, dfs_names, name_idx)?; - builder = builder - .with_scalar(Field::new(name, sv.data_type(), field.nullable), sv); + // We assume everything to be nullable, since Arrow's strict about things matching + // and it's hard to match otherwise. + builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); } builder.build()? 
} diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 1a728dc1efbc..940f25234c02 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2289,8 +2289,8 @@ mod test { ), )))?; - let c0 = Field::new("c0", DataType::Boolean, false); - let c1 = Field::new("c1", DataType::Int32, false); + let c0 = Field::new("c0", DataType::Boolean, true); + let c1 = Field::new("c1", DataType::Int32, true); let c2 = Field::new("c2", DataType::Utf8, true); round_trip_literal( ScalarStructBuilder::new() @@ -2363,7 +2363,7 @@ mod test { round_trip_type(DataType::Struct( vec![ Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, false), + Field::new("c1", DataType::Utf8, true), ] .into(), ))?; diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 94572e098b2c..6492febc938e 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -28,7 +28,7 @@ mod tests { use substrait::proto::Plan; #[tokio::test] - async fn function_compound_signature() -> Result<()> { + async fn scalar_function_compound_signature() -> Result<()> { // DataFusion currently produces Substrait that refers to functions only by their name. // However, the Substrait spec requires that functions be identified by their compound signature. // This test confirms that DataFusion is able to consume plans following the spec, even though @@ -39,7 +39,7 @@ mod tests { // File generated with substrait-java's Isthmus: // ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)" - let proto = read_json("tests/testdata/select_not_bool.substrait.json"); + let proto = read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); let plan = from_substrait_plan(&ctx, &proto).await?; @@ -51,13 +51,41 @@ mod tests { Ok(()) } + // Aggregate function compound signature is tested through TPCH plans + + #[tokio::test] + async fn window_function_compound_signature() -> Result<()> { + // DataFusion currently produces Substrait that refers to functions only by their name. + // However, the Substrait spec requires that functions be identified by their compound signature. + // This test confirms that DataFusion is able to consume plans following the spec, even though + // we don't yet produce such plans. + // Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests. 
+ + let ctx = create_context().await?; + + // File generated with substrait-java's Isthmus: + // ./isthmus-cli/build/graal/isthmus "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" -c "create table data (d int, part int, ord int)" + let proto = read_json("tests/testdata/test_plans/select_window.substrait.json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + + assert_eq!( + format!("{:?}", plan), + "Projection: sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\ + \n WindowAggr: windowExpr=[[sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: DATA projection=[a, b, c, d, e, f]" + ); + Ok(()) + } + #[tokio::test] async fn non_nullable_lists() -> Result<()> { // DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable. // That's because implementing the non-nullability consistently is non-trivial. // This test confirms that reading a plan with non-nullable lists works as expected. let ctx = create_context().await?; - let proto = read_json("tests/testdata/non_nullable_lists.substrait.json"); + let proto = + read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json"); let plan = from_substrait_plan(&ctx, &proto).await?; diff --git a/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json b/datafusion/substrait/tests/testdata/test_plans/non_nullable_lists.substrait.json similarity index 100% rename from datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json rename to datafusion/substrait/tests/testdata/test_plans/non_nullable_lists.substrait.json diff --git a/datafusion/substrait/tests/testdata/select_not_bool.substrait.json b/datafusion/substrait/tests/testdata/test_plans/select_not_bool.substrait.json similarity index 100% rename from datafusion/substrait/tests/testdata/select_not_bool.substrait.json rename to datafusion/substrait/tests/testdata/test_plans/select_not_bool.substrait.json diff --git a/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json b/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json new file mode 100644 index 000000000000..3082c4258f83 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json @@ -0,0 +1,153 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "sum:i32" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 3 + ] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "D", + "PART", + "ORD" + ], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "DATA" + ] + } + } + }, + "expressions": [ + { + "windowFunction": { + "functionReference": 0, + "partitions": [ + { + "selection": { + 
"directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + ], + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + } + ], + "upperBound": { + "unbounded": { + } + }, + "lowerBound": { + "preceding": { + "offset": "1" + } + }, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "args": [], + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + } + ], + "invocation": "AGGREGATION_INVOCATION_ALL", + "options": [], + "boundsType": "BOUNDS_TYPE_ROWS" + } + } + ] + } + }, + "names": [ + "LEAD_EXPR" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 0f0aa8460448..303caef57700 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -111,6 +111,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | | datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | | datafusion.explain.show_sizes | true | When set to true, the explain statement will print the partition sizes | +| datafusion.explain.show_schema | false | When set to true, the explain statement will print schema information | | datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | | datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | | datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. |