diff --git a/.asf.yaml b/.asf.yaml
index 5fe94dc04af5..aab8c1e6df2d 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -50,6 +50,9 @@ github:
       main:
         required_pull_request_reviews:
           required_approving_review_count: 1
+  pull_requests:
+    # enable updating head branches of pull requests
+    allow_update_branch: true
 
   # publishes the content of the `asf-site` branch to
   # https://datafusion.apache.org/
diff --git a/.github/workflows/docs_pr.yaml b/.github/workflows/docs_pr.yaml
index d3c901c5b71b..8d11cdf9d39b 100644
--- a/.github/workflows/docs_pr.yaml
+++ b/.github/workflows/docs_pr.yaml
@@ -34,26 +34,7 @@ on:
   workflow_dispatch:
 
 jobs:
-  # Run doc tests
-  linux-test-doc:
-    name: cargo doctest (amd64)
-    runs-on: ubuntu-latest
-    container:
-      image: amd64/rust
-    steps:
-      - uses: actions/checkout@v4
-        with:
-          submodules: true
-          fetch-depth: 1
-      - name: Setup Rust toolchain
-        uses: ./.github/actions/setup-builder
-        with:
-          rust-version: stable
-      - name: Run doctests (embedded rust examples)
-        run: cargo test --doc --features avro,json
-      - name: Verify Working Directory Clean
-        run: git diff --exit-code
-
+  # Test doc build
   linux-test-doc-build:
     name: Test doc build
diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml
index d80fdb75d932..0ccecfc44fd6 100644
--- a/.github/workflows/extended.yml
+++ b/.github/workflows/extended.yml
@@ -101,7 +101,17 @@ jobs:
       - name: Run tests (excluding doctests)
         env:
           RUST_BACKTRACE: 1
-        run: cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --workspace --lib --tests --bins --features avro,json,backtrace,extended_tests,recursive_protection
+        run: |
+          cargo test \
+            --profile ci \
+            --exclude datafusion-examples \
+            --exclude datafusion-benchmarks \
+            --exclude datafusion-cli \
+            --workspace \
+            --lib \
+            --tests \
+            --bins \
+            --features avro,json,backtrace,extended_tests,recursive_protection
       - name: Verify Working Directory Clean
         run: git diff --exit-code
       - name: Cleanup
diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/labeler.yml
similarity index 96%
rename from .github/workflows/dev_pr.yml
rename to .github/workflows/labeler.yml
index 11c14c5c2fee..8b251552d3b2 100644
--- a/.github/workflows/dev_pr.yml
+++ b/.github/workflows/labeler.yml
@@ -49,7 +49,7 @@ jobs:
         uses: actions/labeler@v5.0.0
         with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
-          configuration-path: .github/workflows/dev_pr/labeler.yml
+          configuration-path: .github/workflows/labeler/labeler-config.yml
           sync-labels: true
 
   # TODO: Enable this when eps1lon/actions-label-merge-conflict is available.
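The consolidated extended-test invocation above can also be reproduced locally before pushing; a sketch, assuming submodules are checked out and a stable Rust toolchain is installed:

    cargo test --profile ci \
      --exclude datafusion-examples \
      --exclude datafusion-benchmarks \
      --exclude datafusion-cli \
      --workspace --lib --tests --bins \
      --features avro,json,backtrace,extended_tests,recursive_protection

The doctest job dropped from docs_pr.yaml ran `cargo test --doc --features avro,json`; that command still works locally when doctest coverage is wanted ahead of CI.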
diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/labeler/labeler-config.yml
similarity index 95%
rename from .github/workflows/dev_pr/labeler.yml
rename to .github/workflows/labeler/labeler-config.yml
index da93e6541855..e40813072521 100644
--- a/.github/workflows/dev_pr/labeler.yml
+++ b/.github/workflows/labeler/labeler-config.yml
@@ -41,7 +41,7 @@ physical-expr:
 
 physical-plan:
 - changed-files:
-  - any-glob-to-any-file: [datafusion/physical-plan/**/*']
+  - any-glob-to-any-file: ['datafusion/physical-plan/**/*']
 
 catalog:
 
@@ -77,6 +77,10 @@ proto:
 - changed-files:
   - any-glob-to-any-file: ['datafusion/proto/**/*', 'datafusion/proto-common/**/*']
 
+spark:
+- changed-files:
+  - any-glob-to-any-file: ['datafusion/spark/**/*']
+
 substrait:
 - changed-files:
   - any-glob-to-any-file: ['datafusion/substrait/**/*']
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 3fa8ce080474..0734431a387b 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -401,8 +401,8 @@ jobs:
       - name: Run tests with headless mode
         working-directory: ./datafusion/wasmtest
         run: |
-          wasm-pack test --headless --firefox
-          wasm-pack test --headless --chrome --chromedriver $CHROMEWEBDRIVER/chromedriver
+          RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-pack test --headless --firefox
+          RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-pack test --headless --chrome --chromedriver $CHROMEWEBDRIVER/chromedriver
 
   # verify that the benchmark queries return the correct results
   verify-benchmark-results:
@@ -476,6 +476,28 @@ jobs:
       POSTGRES_HOST: postgres
       POSTGRES_PORT: ${{ job.services.postgres.ports[5432] }}
 
+  sqllogictest-substrait:
+    name: "Run sqllogictest in Substrait round-trip mode"
+    needs: linux-build-lib
+    runs-on: ubuntu-latest
+    container:
+      image: amd64/rust
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: true
+          fetch-depth: 1
+      - name: Setup Rust toolchain
+        uses: ./.github/actions/setup-builder
+        with:
+          rust-version: stable
+      - name: Run sqllogictest
+        # TODO: Right now several tests are failing in Substrait round-trip mode, so this
+        # command cannot be run for all the .slt files. Run it for just one that works (limit.slt)
+        # until most of the tickets in https://github.com/apache/datafusion/issues/16248 are addressed
+        # and this command can be run without filters.
+ run: cargo test --test sqllogictests -- --substrait-round-trip limit.slt + # Temporarily commenting out the Windows flow, the reason is enormously slow running build # Waiting for new Windows 2025 github runner # Details: https://github.com/apache/datafusion/issues/13726 @@ -693,6 +715,11 @@ jobs: # If you encounter an error, run './dev/update_function_docs.sh' and commit ./dev/update_function_docs.sh git diff --exit-code + - name: Check if runtime_configs.md has been modified + run: | + # If you encounter an error, run './dev/update_runtime_config_docs.sh' and commit + ./dev/update_runtime_config_docs.sh + git diff --exit-code # Verify MSRV for the crates which are directly used by other projects: # - datafusion diff --git a/Cargo.lock b/Cargo.lock index 0c526e69fdcd..92bfd48c5142 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,23 +77,23 @@ version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "once_cell", "version_check", ] [[package]] name = "ahash" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", "const-random", - "getrandom 0.2.15", + "getrandom 0.3.3", "once_cell", "version_check", - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -199,9 +199,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.95" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "apache-avro" @@ -223,8 +223,8 @@ dependencies = [ "serde_bytes", "serde_json", "snap", - "strum 0.26.3", - "strum_macros 0.26.4", + "strum", + "strum_macros", "thiserror 1.0.69", "typed-builder", "uuid", @@ -246,9 +246,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3095aaf545942ff5abd46654534f15b03a90fba78299d661e045e5d587222f0d" +checksum = "b1bb018b6960c87fd9d025009820406f74e83281185a8bdcb44880d2aa5c9a87" dependencies = [ "arrow-arith", "arrow-array", @@ -265,14 +265,14 @@ dependencies = [ "arrow-string", "half", "pyo3", - "rand 0.9.0", + "rand 0.9.1", ] [[package]] name = "arrow-arith" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00752064ff47cee746e816ddb8450520c3a52cbad1e256f6fa861a35f86c45e7" +checksum = "44de76b51473aa888ecd6ad93ceb262fb8d40d1f1154a4df2f069b3590aa7575" dependencies = [ "arrow-array", "arrow-buffer", @@ -284,26 +284,26 @@ dependencies = [ [[package]] name = "arrow-array" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cebfe926794fbc1f49ddd0cdaf898956ca9f6e79541efce62dabccfd81380472" +checksum = "29ed77e22744475a9a53d00026cf8e166fe73cf42d89c4c4ae63607ee1cfcc3f" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", "chrono-tz", "half", - "hashbrown 0.15.2", + "hashbrown 0.15.3", "num", ] [[package]] name = "arrow-buffer" 
-version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0303c7ec4cf1a2c60310fc4d6bbc3350cd051a17bf9e9c0a8e47b4db79277824" +checksum = "b0391c96eb58bf7389171d1e103112d3fc3e5625ca6b372d606f2688f1ea4cce" dependencies = [ "bytes", "half", @@ -312,9 +312,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "335f769c5a218ea823d3760a743feba1ef7857cba114c01399a891c2fff34285" +checksum = "f39e1d774ece9292697fcbe06b5584401b26bd34be1bec25c33edae65c2420ff" dependencies = [ "arrow-array", "arrow-buffer", @@ -333,9 +333,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "510db7dfbb4d5761826516cc611d97b3a68835d0ece95b034a052601109c0b1b" +checksum = "9055c972a07bf12c2a827debfd34f88d3b93da1941d36e1d9fee85eebe38a12a" dependencies = [ "arrow-array", "arrow-cast", @@ -349,9 +349,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8affacf3351a24039ea24adab06f316ded523b6f8c3dbe28fbac5f18743451b" +checksum = "cf75ac27a08c7f48b88e5c923f267e980f27070147ab74615ad85b5c5f90473d" dependencies = [ "arrow-buffer", "arrow-schema", @@ -361,9 +361,9 @@ dependencies = [ [[package]] name = "arrow-flight" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2e0fad280f41a918d53ba48288a246ff04202d463b3b380fbc0edecdcb52cfd" +checksum = "91efc67a4f5a438833dd76ef674745c80f6f6b9a428a3b440cbfbf74e32867e6" dependencies = [ "arrow-arith", "arrow-array", @@ -388,9 +388,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69880a9e6934d9cba2b8630dd08a3463a91db8693b16b499d54026b6137af284" +checksum = "a222f0d93772bd058d1268f4c28ea421a603d66f7979479048c429292fac7b2e" dependencies = [ "arrow-array", "arrow-buffer", @@ -402,9 +402,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8dafd17a05449e31e0114d740530e0ada7379d7cb9c338fd65b09a8130960b0" +checksum = "9085342bbca0f75e8cb70513c0807cc7351f1fbf5cb98192a67d5e3044acb033" dependencies = [ "arrow-array", "arrow-buffer", @@ -424,9 +424,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "895644523af4e17502d42c3cb6b27cb820f0cb77954c22d75c23a85247c849e1" +checksum = "ab2f1065a5cad7b9efa9e22ce5747ce826aa3855766755d4904535123ef431e7" dependencies = [ "arrow-array", "arrow-buffer", @@ -437,9 +437,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9be8a2a4e5e7d9c822b2b8095ecd77010576d824f654d347817640acfc97d229" +checksum = "3703a0e3e92d23c3f756df73d2dc9476873f873a76ae63ef9d3de17fda83b2d8" dependencies = [ "arrow-array", "arrow-buffer", @@ -450,21 +450,22 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7450c76ab7c5a6805be3440dc2e2096010da58f7cab301fdc996a4ee3ee74e49" +checksum = "73a47aa0c771b5381de2b7f16998d351a6f4eb839f1e13d48353e17e873d969b" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "serde", + "serde_json", ] [[package]] name = "arrow-select" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa5f5a93c75f46ef48e4001535e7b6c922eeb0aa20b73cf58d09e13d057490d8" +checksum = "24b7b85575702b23b85272b01bc1c25a01c9b9852305e5d0078c79ba25d995d4" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow-array", "arrow-buffer", "arrow-data", @@ -474,9 +475,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e7005d858d84b56428ba2a98a107fe88c0132c61793cf6b8232a1f9bfc0452b" +checksum = "9260fddf1cdf2799ace2b4c2fc0356a9789fa7551e0953e35435536fecefebbd" dependencies = [ "arrow-array", "arrow-buffer", @@ -503,9 +504,9 @@ dependencies = [ [[package]] name = "assert_cmd" -version = "2.0.16" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc1835b7f27878de8525dc71410b5a31cdcc5f230aed5ba5df968e09c201b23d" +checksum = "2bd389a4b2970a01282ee455294913c0a43724daedcd1a24c3eb0ec1c1320b66" dependencies = [ "anstyle", "bstr", @@ -551,7 +552,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -573,7 +574,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -584,7 +585,7 @@ checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -610,9 +611,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.6.1" +version = "1.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c39646d1a6b51240a1a23bb57ea4eebede7e16fbc237fdc876980233dcecb4f" +checksum = "02a18fd934af6ae7ca52410d4548b98eb895aab0f1ea417d168d85db1434a141" dependencies = [ "aws-credential-types", "aws-runtime", @@ -629,7 +630,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 1.2.0", + "http 1.3.1", "ring", "time", "tokio", @@ -640,9 +641,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.2" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4471bef4c22a06d2c7a1b6492493d3fdf24a805323109d6874f9c94d5906ac14" +checksum = "687bc16bc431a8533fe0097c7f0182874767f920989d7260950172ae8e3c4465" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -652,9 +653,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.12.6" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dabb68eb3a7aa08b46fddfd59a3d55c978243557a90ab804769f7e20e67d2b01" +checksum = "93fcc8f365936c834db5514fc45aee5b1202d677e6b40e48468aaaa8183ca8c7" dependencies = [ "aws-lc-sys", "zeroize", @@ -662,9 +663,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.27.0" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bbe221bbf523b625a4dd8585c7f38166e31167ec2ca98051dbcb4c3b6e825d2" +checksum = 
"61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079" dependencies = [ "bindgen", "cc", @@ -675,9 +676,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.5.6" +version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0aff45ffe35196e593ea3b9dd65b320e51e2dda95aff4390bc459e461d09c6ad" +checksum = "6c4063282c69991e57faab9e5cb21ae557e59f5b0fb285c196335243df8dc25c" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -691,7 +692,6 @@ dependencies = [ "fastrand", "http 0.2.12", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "tracing", @@ -700,9 +700,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.63.0" +version = "1.70.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1cb45b83b53b5cd55ee33fd9fd8a70750255a3f286e4dca20e882052f2b256f" +checksum = "83447efb7179d8e2ad2afb15ceb9c113debbc2ecdf109150e338e2e28b86190b" dependencies = [ "aws-credential-types", "aws-runtime", @@ -716,16 +716,15 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-ssooidc" -version = "1.64.0" +version = "1.71.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d4d9bc075ea6238778ed3951b65d3cde8c3864282d64fdcd19f2a90c0609f1" +checksum = "c5f9bfbbda5e2b9fe330de098f14558ee8b38346408efe9f2e9cee82dc1636a4" dependencies = [ "aws-credential-types", "aws-runtime", @@ -739,16 +738,15 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sts" -version = "1.64.0" +version = "1.71.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819ccba087f403890fee4825eeab460e64c59345667d2b83a12cf544b581e3a7" +checksum = "e17b984a66491ec08b4f4097af8911251db79296b3e4a763060b45805746264f" dependencies = [ "aws-credential-types", "aws-runtime", @@ -763,16 +761,15 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sigv4" -version = "1.3.0" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d03c3c05ff80d54ff860fe38c726f6f494c639ae975203a101335f223386db" +checksum = "3734aecf9ff79aa401a6ca099d076535ab465ff76b46440cf567c8e70b65dc13" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -783,8 +780,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.2.0", - "once_cell", + "http 1.3.1", "percent-encoding", "sha2", "time", @@ -804,9 +800,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.62.0" +version = "0.62.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5949124d11e538ca21142d1fba61ab0a2a2c1bc3ed323cdb3e4b878bfb83166" +checksum = "99335bec6cdc50a346fda1437f9fefe33abf8c99060739a546a16457f2862ca9" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", @@ -814,9 +810,8 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "pin-utils", @@ -825,15 +820,15 @@ dependencies = [ [[package]] name = "aws-smithy-http-client" -version = "1.0.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0497ef5d53065b7cd6a35e9c1654bd1fefeae5c52900d91d1b188b0af0f29324" +checksum = "7e44697a9bded898dcd0b1cb997430d949b87f4f8940d91023ae9062bf218250" dependencies = [ 
"aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", "h2", - "http 1.2.0", + "http 1.3.1", "hyper", "hyper-rustls", "hyper-util", @@ -857,12 +852,11 @@ dependencies = [ [[package]] name = "aws-smithy-observability" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445d065e76bc1ef54963db400319f1dd3ebb3e0a74af20f7f7630625b0cc7cc0" +checksum = "9364d5989ac4dd918e5cc4c4bdcc61c9be17dcd2586ea7f69e348fc7c6cab393" dependencies = [ "aws-smithy-runtime-api", - "once_cell", ] [[package]] @@ -877,9 +871,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.8.1" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0152749e17ce4d1b47c7747bdfec09dac1ccafdcbc741ebf9daa2a373356730f" +checksum = "14302f06d1d5b7d333fd819943075b13d27c7700b414f574c3c35859bfb55d5e" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -890,10 +884,9 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", - "once_cell", "pin-project-lite", "pin-utils", "tokio", @@ -902,15 +895,15 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.7.4" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3da37cf5d57011cb1753456518ec76e31691f1f474b73934a284eb2a1c76510f" +checksum = "a1e5d9e3a80a18afa109391fb5ad09c3daf887b516c6fd805a157c6ea7994a57" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -919,15 +912,15 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836155caafba616c0ff9b07944324785de2ab016141c3550bd1c07882f8cee8f" +checksum = "40076bd09fadbc12d5e026ae080d0930defa606856186e31d83ccc6a255eeaf3" dependencies = [ "base64-simd", "bytes", "bytes-utils", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -951,9 +944,9 @@ dependencies = [ [[package]] name = "aws-types" -version = "1.3.6" +version = "1.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3873f8deed8927ce8d04487630dc9ff73193bab64742a61d050e57a68dec4125" +checksum = "8a322fec39e4df22777ed3ad8ea868ac2f94cd15e1a55f6ee8d8d6305057689a" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -973,7 +966,7 @@ dependencies = [ "axum-core", "bytes", "futures-util", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "itoa", @@ -999,7 +992,7 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "mime", @@ -1012,9 +1005,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.74" +version = "0.3.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" dependencies = [ "addr2line", "cfg-if", @@ -1067,10 +1060,10 @@ version = "0.69.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "lazy_static", "lazycell", 
"log", @@ -1080,7 +1073,7 @@ dependencies = [ "regex", "rustc-hash 1.1.0", "shlex", - "syn 2.0.100", + "syn 2.0.101", "which", ] @@ -1092,9 +1085,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.8.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" [[package]] name = "bitvec" @@ -1119,9 +1112,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.1" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389a099b34312839e16420d499a9cad9650541715937ffbdd40d36f49e77eeb3" +checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" dependencies = [ "arrayref", "arrayvec", @@ -1152,7 +1145,7 @@ dependencies = [ "futures-util", "hex", "home", - "http 1.2.0", + "http 1.3.1", "http-body-util", "hyper", "hyper-named-pipe", @@ -1191,9 +1184,9 @@ dependencies = [ [[package]] name = "borsh" -version = "1.5.5" +version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5430e3be710b68d984d1391c854eb431a9d548640711faa54eecb1df93db91cc" +checksum = "ad8646f98db542e39fc66e68a20b2144f6a732636df7c2354e74645faaa433ce" dependencies = [ "borsh-derive", "cfg_aliases", @@ -1201,22 +1194,22 @@ dependencies = [ [[package]] name = "borsh-derive" -version = "1.5.5" +version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8b668d39970baad5356d7c83a86fee3a539e6f93bf6764c97368243e17a0487" +checksum = "fdd1d3c0c2f5833f22386f252fe8ed005c7f59fdcddeef025c01b4c3b9fd9ac3" dependencies = [ "once_cell", "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "brotli" -version = "7.0.0" +version = "8.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" +checksum = "9991eea70ea4f293524138648e41ee89b0b2b12ddef3b255effa43c8056e0e0d" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1225,9 +1218,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "4.0.2" +version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74fa05ad7d803d413eb8380983b092cbbaf9a85f151b871360e7b00cd7060b37" +checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1235,9 +1228,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.11.3" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" +checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", "regex-automata", @@ -1331,9 +1324,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.14" +version = "1.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "5f4ac86a9e5bc1e2b3449ab9d7d3a6a405e3d1bb28d7b9be8614f55846ae3766" dependencies = [ "jobserver", "libc", @@ -1363,9 +1356,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" 
-version = "0.4.40" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c" +checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ "android-tzdata", "iana-time-zone", @@ -1389,9 +1382,9 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7" +checksum = "8f10f8c9340e31fc120ff885fcdb54a0b48e474bbd77cab557f0c30a3e569402" dependencies = [ "parse-zoneinfo", "phf_codegen", @@ -1432,7 +1425,7 @@ checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" dependencies = [ "glob", "libc", - "libloading 0.8.6", + "libloading 0.8.7", ] [[package]] @@ -1448,9 +1441,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.35" +version = "4.5.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8aa86934b44c19c50f87cc2790e19f54f7a67aedb64101c2e1a2e5ecfb73944" +checksum = "fd60e63e9be68e5fb56422e397cf9baddded06dae1d2e523401542383bc72a9f" dependencies = [ "clap_builder", "clap_derive", @@ -1458,9 +1451,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.35" +version = "4.5.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2414dbb2dd0695280da6ea9261e327479e9d37b0630f6b53ba2a11c60c679fd9" +checksum = "89cc6392a1f72bbeb820d71f32108f61fdaf18bc526e1d23954168a67759ef51" dependencies = [ "anstream", "anstyle", @@ -1477,7 +1470,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -1522,9 +1515,9 @@ dependencies = [ [[package]] name = "console" -version = "0.15.10" +version = "0.15.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" dependencies = [ "encode_unicode", "libc", @@ -1558,7 +1551,7 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "once_cell", "tiny-keccak", ] @@ -1642,7 +1635,7 @@ dependencies = [ "anes", "cast", "ciborium", - "clap 4.5.35", + "clap 4.5.39", "criterion-plot", "futures", "is-terminal", @@ -1744,19 +1737,25 @@ dependencies = [ [[package]] name = "ctor" -version = "0.2.9" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a2785755761f3ddc1492979ce1e48d2c00d09311c39e4466429188f3dd6501" +checksum = "a4735f265ba6a1188052ca32d461028a7d1125868be18e287e756019da7607b5" dependencies = [ - "quote", - "syn 2.0.100", + "ctor-proc-macro", + "dtor", ] +[[package]] +name = "ctor-proc-macro" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f211af61d8efdd104f96e57adf5e426ba1bc3ed7a4ead616e15e5881fd79c4d" + [[package]] name = "darling" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ "darling_core", "darling_macro", @@ -1764,27 +1763,27 @@ 
dependencies = [ [[package]] name = "darling_core" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "darling_macro" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ "darling_core", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -1809,7 +1808,7 @@ dependencies = [ [[package]] name = "datafusion" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "arrow-ipc", @@ -1860,7 +1859,7 @@ dependencies = [ "parking_lot", "parquet", "paste", - "rand 0.8.5", + "rand 0.9.1", "rand_distr", "regex", "rstest", @@ -1879,7 +1878,7 @@ dependencies = [ [[package]] name = "datafusion-benchmarks" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "datafusion", @@ -1891,7 +1890,7 @@ dependencies = [ "mimalloc", "object_store", "parquet", - "rand 0.8.5", + "rand 0.9.1", "serde", "serde_json", "snmalloc-rs", @@ -1903,7 +1902,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "async-trait", @@ -1927,7 +1926,7 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "async-trait", @@ -1950,14 +1949,14 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "assert_cmd", "async-trait", "aws-config", "aws-credential-types", - "clap 4.5.35", + "clap 4.5.39", "ctor", "datafusion", "dirs", @@ -1979,9 +1978,9 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "47.0.0" +version = "48.0.1" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "apache-avro", "arrow", "arrow-ipc", @@ -1997,7 +1996,7 @@ dependencies = [ "parquet", "paste", "pyo3", - "rand 0.8.5", + "rand 0.9.1", "recursive", "sqlparser", "tokio", @@ -2006,7 +2005,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "47.0.0" +version = "48.0.1" dependencies = [ "futures", "log", @@ -2015,7 +2014,7 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "async-compression", @@ -2039,7 +2038,7 @@ dependencies = [ "log", "object_store", "parquet", - "rand 0.8.5", + "rand 0.9.1", "tempfile", "tokio", "tokio-util", @@ -2050,7 +2049,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-avro" -version = "47.0.0" +version = "48.0.1" dependencies = [ "apache-avro", "arrow", @@ -2075,7 +2074,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "async-trait", @@ -2098,7 +2097,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "async-trait", @@ -2121,7 +2120,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-parquet" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "async-trait", @@ -2145,17 +2144,17 @@ dependencies = [ 
"object_store", "parking_lot", "parquet", - "rand 0.8.5", + "rand 0.9.1", "tokio", ] [[package]] name = "datafusion-doc" -version = "47.0.0" +version = "48.0.1" [[package]] name = "datafusion-examples" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "arrow-flight", @@ -2185,7 +2184,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "chrono", @@ -2193,17 +2192,18 @@ dependencies = [ "datafusion-common", "datafusion-expr", "futures", + "insta", "log", "object_store", "parking_lot", - "rand 0.8.5", + "rand 0.9.1", "tempfile", "url", ] [[package]] name = "datafusion-expr" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "chrono", @@ -2216,6 +2216,7 @@ dependencies = [ "datafusion-physical-expr-common", "env_logger", "indexmap 2.9.0", + "insta", "paste", "recursive", "serde_json", @@ -2224,7 +2225,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "datafusion-common", @@ -2235,7 +2236,7 @@ dependencies = [ [[package]] name = "datafusion-ffi" -version = "47.0.0" +version = "48.0.1" dependencies = [ "abi_stable", "arrow", @@ -2243,7 +2244,9 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-functions-aggregate-common", "datafusion-proto", + "datafusion-proto-common", "doc-comment", "futures", "log", @@ -2254,7 +2257,7 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "arrow-buffer", @@ -2273,7 +2276,7 @@ dependencies = [ "itertools 0.14.0", "log", "md-5", - "rand 0.8.5", + "rand 0.9.1", "regex", "sha2", "tokio", @@ -2283,9 +2286,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "47.0.0" +version = "48.0.1" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "criterion", "datafusion-common", @@ -2299,25 +2302,25 @@ dependencies = [ "half", "log", "paste", - "rand 0.8.5", + "rand 0.9.1", ] [[package]] name = "datafusion-functions-aggregate-common" -version = "47.0.0" +version = "48.0.1" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "criterion", "datafusion-common", "datafusion-expr-common", "datafusion-physical-expr-common", - "rand 0.8.5", + "rand 0.9.1", ] [[package]] name = "datafusion-functions-nested" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "arrow-ord", @@ -2333,12 +2336,12 @@ dependencies = [ "itertools 0.14.0", "log", "paste", - "rand 0.8.5", + "rand 0.9.1", ] [[package]] name = "datafusion-functions-table" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "async-trait", @@ -2352,7 +2355,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "datafusion-common", @@ -2368,7 +2371,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "47.0.0" +version = "48.0.1" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2376,20 +2379,21 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "47.0.0" +version = "48.0.1" dependencies = [ "datafusion-expr", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "datafusion-optimizer" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "async-trait", "chrono", + "criterion", "ctor", "datafusion-common", "datafusion-expr", @@ -2410,9 +2414,9 @@ dependencies = [ [[package]] name = 
"datafusion-physical-expr" -version = "47.0.0" +version = "48.0.1" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "criterion", "datafusion-common", @@ -2428,16 +2432,16 @@ dependencies = [ "itertools 0.14.0", "log", "paste", - "petgraph", - "rand 0.8.5", + "petgraph 0.8.1", + "rand 0.9.1", "rstest", ] [[package]] name = "datafusion-physical-expr-common" -version = "47.0.0" +version = "48.0.1" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "datafusion-common", "datafusion-expr-common", @@ -2447,7 +2451,7 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "datafusion-common", @@ -2466,9 +2470,9 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "47.0.0" +version = "48.0.1" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "arrow-ord", "arrow-schema", @@ -2493,7 +2497,7 @@ dependencies = [ "log", "parking_lot", "pin-project-lite", - "rand 0.8.5", + "rand 0.9.1", "rstest", "rstest_reuse", "tempfile", @@ -2502,7 +2506,7 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "chrono", @@ -2519,13 +2523,12 @@ dependencies = [ "prost", "serde", "serde_json", - "strum 0.27.1", "tokio", ] [[package]] name = "datafusion-proto-common" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "datafusion-common", @@ -2538,7 +2541,7 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "async-trait", @@ -2558,9 +2561,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "datafusion-spark" +version = "48.0.1" +dependencies = [ + "arrow", + "datafusion-catalog", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions", + "datafusion-macros", + "log", +] + [[package]] name = "datafusion-sql" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "bigdecimal", @@ -2584,15 +2601,17 @@ dependencies = [ [[package]] name = "datafusion-sqllogictest" -version = "47.0.0" +version = "48.0.1" dependencies = [ "arrow", "async-trait", "bigdecimal", "bytes", "chrono", - "clap 4.5.35", + "clap 4.5.39", "datafusion", + "datafusion-spark", + "datafusion-substrait", "env_logger", "futures", "half", @@ -2615,7 +2634,7 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = "47.0.0" +version = "48.0.1" dependencies = [ "async-recursion", "async-trait", @@ -2635,7 +2654,7 @@ dependencies = [ [[package]] name = "datafusion-wasmtest" -version = "47.0.0" +version = "48.0.1" dependencies = [ "chrono", "console_error_panic_hook", @@ -2646,7 +2665,7 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-plan", "datafusion-sql", - "getrandom 0.2.15", + "getrandom 0.3.3", "insta", "object_store", "tokio", @@ -2657,9 +2676,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.11" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" dependencies = [ "powerfmt", "serde", @@ -2711,7 +2730,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -2722,15 +2741,30 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" 
[[package]] name = "docker_credential" -version = "1.3.1" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31951f49556e34d90ed28342e1df7e1cb7a229c4cab0aecc627b5d91edd41d07" +checksum = "1d89dfcba45b4afad7450a99b39e751590463e45c04728cf555d36bb66940de8" dependencies = [ "base64 0.21.7", "serde", "serde_json", ] +[[package]] +name = "dtor" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97cbdf2ad6846025e8e25df05171abfb30e3ababa12ee0a0e44b9bbe570633a8" +dependencies = [ + "dtor-proc-macro", +] + +[[package]] +name = "dtor-proc-macro" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7454e41ff9012c00d53cf7f475c5e3afa3b91b7c90568495495e8d9bf47a1055" + [[package]] name = "dunce" version = "1.0.5" @@ -2739,9 +2773,9 @@ checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "dyn-clone" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" +checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" [[package]] name = "educe" @@ -2752,14 +2786,14 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "either" -version = "1.13.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "encode_unicode" @@ -2790,7 +2824,7 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -2805,9 +2839,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.7" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3716d7a920fb4fac5d84e9d4bce8ceb321e9414b4409da61b07b75c1e3d0697" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" dependencies = [ "anstream", "anstyle", @@ -2824,9 +2858,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.10" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" dependencies = [ "libc", "windows-sys 0.59.0", @@ -2834,9 +2868,9 @@ dependencies = [ [[package]] name = "error-code" -version = "3.3.1" +version = "3.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" +checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" [[package]] name = "escape8259" @@ -2846,13 +2880,13 @@ checksum = "5692dd7b5a1978a5aeb0ce83b7655c58ca8efdcb79d21036ea249da95afec2c6" [[package]] name = "etcetera" -version = "0.8.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +checksum = "26c7b13d0780cb82722fd59f6f57f925e143427e4a75313a6c77243bf5326ae6" dependencies = [ "cfg-if", "home", - "windows-sys 0.48.0", + "windows-sys 0.59.0", 
] [[package]] @@ -2869,13 +2903,13 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "fd-lock" -version = "4.0.2" +version = "4.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e5768da2206272c81ef0b5e951a41862938a6070da63bcea197899942d3b947" +checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" dependencies = [ "cfg-if", - "rustix 0.38.44", - "windows-sys 0.52.0", + "rustix 1.0.7", + "windows-sys 0.59.0", ] [[package]] @@ -2932,7 +2966,7 @@ version = "25.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1045398c1bfd89168b5fd3f1fc11f6e70b34f6f66300c87d44d3de849463abf1" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "rustc_version", ] @@ -2964,9 +2998,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foldhash" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" [[package]] name = "form_urlencoded" @@ -3054,7 +3088,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -3130,9 +3164,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "js-sys", @@ -3143,14 +3177,16 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.3.1" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", + "js-sys", "libc", - "wasi 0.13.3+wasi-0.2.2", - "windows-targets 0.52.6", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", + "wasm-bindgen", ] [[package]] @@ -3180,16 +3216,16 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.8" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2" +checksum = "a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.2.0", + "http 1.3.1", "indexmap 2.9.0", "slab", "tokio", @@ -3199,9 +3235,9 @@ dependencies = [ [[package]] name = "half" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ "cfg-if", "crunchy", @@ -3223,15 +3259,15 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "allocator-api2", ] [[package]] name = "hashbrown" -version = "0.15.2" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" dependencies = [ "allocator-api2", "equivalent", @@ -3255,9 +3291,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" -version = "0.4.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +checksum = "f154ce46856750ed433c8649605bf7ed2de3bc35fd9d2a9f30cddd873c80cb08" [[package]] name = "hex" @@ -3296,9 +3332,9 @@ dependencies = [ [[package]] name = "http" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -3323,27 +3359,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.2.0", + "http 1.3.1", ] [[package]] name = "http-body-util" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", - "futures-util", - "http 1.2.0", + "futures-core", + "http 1.3.1", "http-body 1.0.1", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "httpdate" @@ -3353,9 +3389,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "humantime" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f" [[package]] name = "hyper" @@ -3367,7 +3403,7 @@ dependencies = [ "futures-channel", "futures-util", "h2", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "httparse", "httpdate", @@ -3400,7 +3436,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", - "http 1.2.0", + "http 1.3.1", "hyper", "hyper-util", "rustls", @@ -3426,16 +3462,17 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.10" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +checksum = "cf9f1e950e0d9d1d3c47184416723cf29c0d1f93bd8cccf37e4beb6b44f31710" dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "hyper", + "libc", "pin-project-lite", "socket2", "tokio", @@ -3460,16 +3497,17 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.61" +version = "0.1.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", + "log", "wasm-bindgen", - "windows-core 0.52.0", + "windows-core", ] [[package]] @@ -3483,21 +3521,22 @@ dependencies = [ [[package]] name = "icu_collections" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" dependencies = [ "displaydoc", + "potential_utf", "yoke", "zerofrom", "zerovec", ] [[package]] -name = "icu_locid" -version = "1.5.0" +name = "icu_locale_core" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" dependencies = [ "displaydoc", "litemap", @@ -3506,31 +3545,11 @@ dependencies = [ "zerovec", ] -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" - [[package]] name = "icu_normalizer" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" dependencies = [ "displaydoc", "icu_collections", @@ -3538,67 +3557,54 @@ dependencies = [ "icu_properties", "icu_provider", "smallvec", - "utf16_iter", - "utf8_iter", - "write16", "zerovec", ] [[package]] name = "icu_normalizer_data" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" [[package]] name = "icu_properties" -version = "1.5.1" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" dependencies = [ "displaydoc", "icu_collections", - "icu_locid_transform", + "icu_locale_core", "icu_properties_data", "icu_provider", - "tinystr", + "potential_utf", + "zerotrie", "zerovec", ] [[package]] name = "icu_properties_data" -version = "1.5.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" [[package]] name = "icu_provider" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +checksum = 
"03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" dependencies = [ "displaydoc", - "icu_locid", - "icu_provider_macros", + "icu_locale_core", "stable_deref_trait", "tinystr", "writeable", "yoke", "zerofrom", + "zerotrie", "zerovec", ] -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", -] - [[package]] name = "ident_case" version = "1.0.1" @@ -3618,9 +3624,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" dependencies = [ "icu_normalizer", "icu_properties", @@ -3644,7 +3650,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown 0.15.3", "serde", ] @@ -3663,21 +3669,19 @@ dependencies = [ [[package]] name = "indoc" -version = "2.0.5" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" [[package]] name = "insta" -version = "1.42.2" +version = "1.43.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50259abbaa67d11d2bcafc7ba1d094ed7a0c70e3ce893f0d0997f73558cb3084" +checksum = "154934ea70c58054b556dd430b99a98c2a7ff5309ac9891597e339b5c28f4371" dependencies = [ "console", "globset", - "linked-hash-map", "once_cell", - "pin-project", "regex", "serde", "similar", @@ -3709,9 +3713,9 @@ checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is-terminal" -version = "0.4.15" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ "hermit-abi", "libc", @@ -3733,6 +3737,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.13.0" @@ -3753,15 +3766,15 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jiff" -version = "0.2.4" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d699bc6dfc879fb1bf9bdff0d4c56f0884fc6f0d0eb0fba397a6d00cd9a6b85e" +checksum = "a194df1107f33c79f4f93d02c80798520551949d59dfad22b6157048a88cca93" dependencies = [ "jiff-static", "log", @@ -3772,21 +3785,22 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.4" +version = "0.2.14" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d16e75759ee0aa64c57a56acbf43916987b20c77373cb7e808979e02b93c9f9" +checksum = "6c6e1db7ed32c6c71b759497fae34bf7933636f75a251b9e736555da426f6442" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ + "getrandom 0.3.3", "libc", ] @@ -3878,9 +3892,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.171" +version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] name = "libflate" @@ -3918,19 +3932,19 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +checksum = "6a793df0d7afeac54f95b471d3af7f0d4fb975699f972341a4b76988d49cdf0c" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.53.0", ] [[package]] name = "libm" -version = "0.2.11" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libmimalloc-sys" @@ -3948,9 +3962,9 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "libc", - "redox_syscall 0.5.8", + "redox_syscall 0.5.12", ] [[package]] @@ -3961,7 +3975,7 @@ checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" dependencies = [ "anstream", "anstyle", - "clap 4.5.35", + "clap 4.5.39", "escape8259", ] @@ -3974,12 +3988,6 @@ dependencies = [ "zlib-rs", ] -[[package]] -name = "linked-hash-map" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" - [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -3988,15 +3996,15 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "linux-raw-sys" -version = "0.9.2" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db9c683daf087dc577b7506e9695b3d556a9f3849903fa28186283afd6809e9" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" [[package]] name = "litemap" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "lock_api" @@ -4014,6 +4022,12 @@ version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "lz4_flex" version = "0.11.3" @@ -4118,9 +4132,9 @@ dependencies = [ [[package]] name = "multimap" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" [[package]] name = "nibble_vec" @@ -4133,11 +4147,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.29.0" +version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "cfg-if", "cfg_aliases", "libc", @@ -4267,11 +4281,21 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "objc2-core-foundation" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daeaf60f25471d26948a1c2f840e3f7d86f4109e3af4e8e4b5cd70c39690d925" +checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", +] + +[[package]] +name = "objc2-io-kit" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71c1c64d6120e51cd86033f67176b1cb66780c2efe34dec55176f77befd93c0a" +dependencies = [ + "libc", + "objc2-core-foundation", ] [[package]] @@ -4285,9 +4309,9 @@ dependencies = [ [[package]] name = "object_store" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9ce831b09395f933addbc56d894d889e4b226eba304d4e7adbab591e26daf1e" +checksum = "d94ac16b433c0ccf75326388c893d2835ab7457ea35ab8ba5d745c053ef5fa16" dependencies = [ "async-trait", "base64 0.22.1", @@ -4295,7 +4319,7 @@ dependencies = [ "chrono", "form_urlencoded", "futures", - "http 1.2.0", + "http 1.3.1", "http-body-util", "humantime", "hyper", @@ -4304,7 +4328,7 @@ dependencies = [ "parking_lot", "percent-encoding", "quick-xml", - "rand 0.8.5", + "rand 0.9.1", "reqwest", "ring", "rustls-pemfile", @@ -4316,19 +4340,21 @@ dependencies = [ "tracing", "url", "walkdir", + "wasm-bindgen-futures", + "web-time", ] [[package]] name = "once_cell" -version = "1.20.3" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "oorandom" -version = "11.1.4" +version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openssl-probe" @@ -4365,9 +4391,9 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "owo-colors" -version = "4.1.0" +version = "4.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb37767f6569cd834a413442455e0f066d0d522de8630436e2a1761d9726ba56" +checksum = "26995317201fa17f3656c36716aed4a7c81743a9634ac4c99c0eeda495db0cec" [[package]] name = 
"parking_lot" @@ -4387,18 +4413,18 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.8", + "redox_syscall 0.5.12", "smallvec", "windows-targets 0.52.6", ] [[package]] name = "parquet" -version = "55.0.0" +version = "55.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd31a8290ac5b19f09ad77ee7a1e6a541f1be7674ad410547d5f1eef6eef4a9c" +checksum = "be7b2d778f6b841d37083ebdf32e33a524acde1266b5884a8ca29bf00dfa1231" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow-array", "arrow-buffer", "arrow-cast", @@ -4413,7 +4439,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.15.2", + "hashbrown 0.15.3", "lz4_flex", "num", "num-bigint", @@ -4450,7 +4476,7 @@ dependencies = [ "regex", "regex-syntax", "structmeta", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -4521,6 +4547,18 @@ dependencies = [ "indexmap 2.9.0", ] +[[package]] +name = "petgraph" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a98c6720655620a521dcc722d0ad66cd8afd5d86e34a89ef691c50b7b24de06" +dependencies = [ + "fixedbitset", + "hashbrown 0.15.3", + "indexmap 2.9.0", + "serde", +] + [[package]] name = "phf" version = "0.11.3" @@ -4561,22 +4599,22 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.9" +version = "1.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfe2e71e1471fe07709406bf725f710b02927c9c54b2b5b2ec0e8087d97c327d" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.9" +version = "1.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6e859e6e5bd50440ab63c47e3ebabc90f26251f7c73c3d3e837b74a1cc3fa67" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -4593,9 +4631,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "plotters" @@ -4627,9 +4665,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" [[package]] name = "portable-atomic-util" @@ -4649,7 +4687,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -4665,7 +4703,7 @@ dependencies = [ "hmac", "md-5", "memchr", - "rand 0.9.0", + "rand 0.9.1", "sha2", "stringprep", ] @@ -4683,6 +4721,15 @@ dependencies = [ "postgres-protocol", ] +[[package]] +name = "potential_utf" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" +dependencies = [ + "zerovec", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -4691,11 +4738,11 @@ checksum = 
"439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -4730,19 +4777,19 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.31" +version = "0.2.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb" +checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6" dependencies = [ "proc-macro2", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "proc-macro-crate" -version = "3.2.0" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +checksum = "edce586971a4dfaa28950c6f18ed55e0406c1ab88bbce2c6f6293a7aaba73d35" dependencies = [ "toml_edit", ] @@ -4773,9 +4820,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.93" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -4801,12 +4848,12 @@ dependencies = [ "log", "multimap", "once_cell", - "petgraph", + "petgraph 0.7.1", "prettyplease", "prost", "prost-types", "regex", - "syn 2.0.100", + "syn 2.0.101", "tempfile", ] @@ -4820,7 +4867,7 @@ dependencies = [ "itertools 0.14.0", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -4843,9 +4890,9 @@ dependencies = [ [[package]] name = "psm" -version = "0.1.25" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f58e5423e24c18cc840e1c98370b3993c6649cd1678b4d24318bcf0a083cbe88" +checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" dependencies = [ "cc", ] @@ -4872,9 +4919,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17da310086b068fbdcefbba30aeb3721d5bb9af8db4987d6735b2183ca567229" +checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219" dependencies = [ "cfg-if", "indoc", @@ -4890,9 +4937,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e27165889bd793000a098bb966adc4300c312497ea25cf7a690a9f0ac5aa5fc1" +checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999" dependencies = [ "once_cell", "target-lexicon", @@ -4900,9 +4947,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05280526e1dbf6b420062f3ef228b78c0c54ba94e157f5cb724a609d0f2faabc" +checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33" dependencies = [ "libc", "pyo3-build-config", @@ -4910,27 +4957,27 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5c3ce5686aa4d3f63359a5100c62a127c9f15e8398e5fdeb5deef1fed5cd5f44" +checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "pyo3-macros-backend" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4cf6faa0cbfb0ed08e89beb8103ae9724eb4750e3a78084ba4017cbe94f3855" +checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a" dependencies = [ "heck 0.5.0", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -4941,9 +4988,9 @@ checksum = "5a651516ddc9168ebd67b24afd085a718be02f8858fe406591b013d101ce2f40" [[package]] name = "quick-xml" -version = "0.37.2" +version = "0.37.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "165859e9e55f79d67b96c5d96f4e88b6f2695a1972849c15a6a3f5c59fc2c003" +checksum = "331e97a1af0bf59823e6eadffe373d7b27f485be8748f71471c662c1f269b7fb" dependencies = [ "memchr", "serde", @@ -4951,11 +4998,12 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.6" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" +checksum = "626214629cda6781b6dc1d316ba307189c85ba657213ce642d9c77670f8202c8" dependencies = [ "bytes", + "cfg_aliases", "pin-project-lite", "quinn-proto", "quinn-udp", @@ -4965,17 +5013,19 @@ dependencies = [ "thiserror 2.0.12", "tokio", "tracing", + "web-time", ] [[package]] name = "quinn-proto" -version = "0.11.9" +version = "0.11.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" +checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" dependencies = [ "bytes", - "getrandom 0.2.15", - "rand 0.8.5", + "getrandom 0.3.3", + "lru-slab", + "rand 0.9.1", "ring", "rustc-hash 2.1.1", "rustls", @@ -4989,9 +5039,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.10" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944" +checksum = "ee4e529991f949c5e25755532370b8af5d114acae52326361d68d47af64aa842" dependencies = [ "cfg_aliases", "libc", @@ -5010,6 +5060,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" + [[package]] name = "radium" version = "0.7.0" @@ -5039,13 +5095,12 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" dependencies = [ "rand_chacha 0.9.0", - "rand_core 0.9.1", - "zerocopy 0.8.18", + "rand_core 0.9.3", ] [[package]] @@ -5065,7 +5120,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.1", + "rand_core 0.9.3", ] [[package]] @@ -5074,27 +5129,26 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum 
= "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", ] [[package]] name = "rand_core" -version = "0.9.1" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a88e0da7a2c97baa202165137c158d0a2e824ac465d13d81046727b34cb247d3" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.1", - "zerocopy 0.8.18", + "getrandom 0.3.3", ] [[package]] name = "rand_distr" -version = "0.4.3" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand 0.8.5", + "rand 0.9.1", ] [[package]] @@ -5134,7 +5188,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5148,11 +5202,11 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.8" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" +checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", ] [[package]] @@ -5161,7 +5215,7 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "libredox", "thiserror 2.0.12", ] @@ -5207,7 +5261,7 @@ version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ef7fa9ed0256d64a688a3747d0fef7a88851c18a5e1d57f115f38ec2e09366" dependencies = [ - "hashbrown 0.15.2", + "hashbrown 0.15.3", "memchr", ] @@ -5237,16 +5291,16 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.12" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" +checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" dependencies = [ "base64 0.22.1", "bytes", "futures-core", "futures-util", "h2", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -5283,13 +5337,13 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.13" +version = "0.17.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "libc", "untrusted", "windows-sys 0.52.0", @@ -5332,9 +5386,9 @@ checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" [[package]] name = "rstest" -version = "0.24.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03e905296805ab93e13c1ec3a03f4b6c4f35e9498a3d5fa96dc626d22c03cd89" +checksum = "6fc39292f8613e913f7df8fa892b8944ceb47c247b78e1b1ae2f09e019be789d" dependencies = [ "futures-timer", "futures-util", @@ -5344,9 +5398,9 @@ dependencies = [ [[package]] name = "rstest_macros" -version = "0.24.0" +version = 
"0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef0053bbffce09062bee4bcc499b0fbe7a57b879f1efe088d6d8d4c7adcdef9b" +checksum = "1f168d99749d307be9de54d23fd226628d99768225ef08f6ffb52e0182a27746" dependencies = [ "cfg-if", "glob", @@ -5356,7 +5410,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.100", + "syn 2.0.101", "unicode-ident", ] @@ -5368,7 +5422,7 @@ checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14" dependencies = [ "quote", "rand 0.8.5", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5421,7 +5475,7 @@ version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "errno", "libc", "linux-raw-sys 0.4.15", @@ -5430,22 +5484,22 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.2" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7178faa4b75a30e269c71e61c353ce2748cf3d76f0c44c393f4e60abf49b825" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "errno", "libc", - "linux-raw-sys 0.9.2", + "linux-raw-sys 0.9.4", "windows-sys 0.59.0", ] [[package]] name = "rustls" -version = "0.23.23" +version = "0.23.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" +checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ "aws-lc-rs", "once_cell", @@ -5479,18 +5533,19 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" dependencies = [ "web-time", + "zeroize", ] [[package]] name = "rustls-webpki" -version = "0.102.8" +version = "0.103.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" dependencies = [ "aws-lc-rs", "ring", @@ -5500,17 +5555,17 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" [[package]] name = "rustyline" -version = "15.0.0" +version = "16.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ee1e066dc922e513bda599c6ccb5f3bb2b0ea5870a579448f2622993f0a9a2f" +checksum = "62fd9ca5ebc709e8535e8ef7c658eb51457987e48c98ead2be482172accc408d" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "cfg-if", "clipboard-win", "fd-lock", @@ -5528,9 +5583,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "same-file" @@ -5571,7 +5626,7 @@ dependencies = [ "proc-macro2", 
"quote", "serde_derive_internals", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5592,7 +5647,7 @@ version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "core-foundation", "core-foundation-sys", "libc", @@ -5620,9 +5675,9 @@ dependencies = [ [[package]] name = "seq-macro" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" @@ -5635,9 +5690,9 @@ dependencies = [ [[package]] name = "serde_bytes" -version = "0.11.15" +version = "0.11.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a" +checksum = "8437fd221bde2d4ca316d61b90e337e9e702b3820b87d63caa9ba6c02bd06d96" dependencies = [ "serde", ] @@ -5650,7 +5705,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5661,7 +5716,7 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5678,13 +5733,13 @@ dependencies = [ [[package]] name = "serde_repr" -version = "0.1.19" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5696,7 +5751,7 @@ dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5738,7 +5793,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5756,9 +5811,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.8" +version = "0.10.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", "cpufeatures", @@ -5782,9 +5837,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.2" +version = "1.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" dependencies = [ "libc", ] @@ -5818,9 +5873,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.14.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" [[package]] name = "snap" @@ -5848,9 +5903,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.8" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" 
+checksum = "4f5fd57c80058a56cf5c777ab8a126398ece8e442983605d280a44ce79d0edef" dependencies = [ "libc", "windows-sys 0.52.0", @@ -5858,9 +5913,9 @@ dependencies = [ [[package]] name = "sqllogictest" -version = "0.28.0" +version = "0.28.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b2f0b80fc250ed3fdd82fc88c0ada5ad62ee1ed5314ac5474acfa52082f518" +checksum = "94181af64007792bd1ab6d22023fbe86c2ccc50c1031b5bac554b5d057597e7b" dependencies = [ "async-trait", "educe", @@ -5900,7 +5955,7 @@ checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5911,9 +5966,9 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stacker" -version = "0.1.18" +version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d08feb8f695b465baed819b03c128dc23f57a694510ab1f06c77f763975685e" +checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" dependencies = [ "cc", "cfg-if", @@ -5954,7 +6009,7 @@ dependencies = [ "proc-macro2", "quote", "structmeta-derive", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5965,7 +6020,7 @@ checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5998,15 +6053,6 @@ version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" -[[package]] -name = "strum" -version = "0.27.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32" -dependencies = [ - "strum_macros 0.27.1", -] - [[package]] name = "strum_macros" version = "0.26.4" @@ -6017,27 +6063,14 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.100", -] - -[[package]] -name = "strum_macros" -version = "0.27.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8" -dependencies = [ - "heck 0.5.0", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "subst" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33e7942675ea19db01ef8cf15a1e6443007208e6c74568bd64162da26d40160d" +checksum = "0a9a86e5144f63c2d18334698269a8bfae6eece345c70b64821ea5b35054ec99" dependencies = [ "memchr", "unicode-width 0.1.14", @@ -6045,9 +6078,9 @@ dependencies = [ [[package]] name = "substrait" -version = "0.55.0" +version = "0.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3a359aeb711c1e1944c0c4178bbb2d679d39237ac5bfe28f7e0506e522e5ce6" +checksum = "13de2e20128f2a018dab1cfa30be83ae069219a65968c6f89df66ad124de2397" dependencies = [ "heck 0.5.0", "pbjson", @@ -6064,7 +6097,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "syn 2.0.100", + "syn 2.0.101", "typify", "walkdir", ] @@ -6088,9 +6121,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.100" +version = "2.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" dependencies = [ "proc-macro2", "quote", @@ 
-6108,25 +6141,26 @@ dependencies = [ [[package]] name = "synstructure" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "sysinfo" -version = "0.34.2" +version = "0.35.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4b93974b3d3aeaa036504b8eefd4c039dced109171c1ae973f1dc63b2c7e4b2" +checksum = "3c3ffa3e4ff2b324a57f7aeb3c349656c7b127c3c189520251a648102a92496e" dependencies = [ "libc", "memchr", "ntapi", "objc2-core-foundation", + "objc2-io-kit", "windows", ] @@ -6144,14 +6178,14 @@ checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" [[package]] name = "tempfile" -version = "3.19.1" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ "fastrand", - "getrandom 0.3.1", + "getrandom 0.3.3", "once_cell", - "rustix 1.0.2", + "rustix 1.0.7", "windows-sys 0.59.0", ] @@ -6169,14 +6203,14 @@ dependencies = [ "chrono-tz", "datafusion-common", "env_logger", - "rand 0.8.5", + "rand 0.9.1", ] [[package]] name = "testcontainers" -version = "0.23.3" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59a4f01f39bb10fc2a5ab23eb0d888b1e2bb168c157f61a1b98e6c501c639c74" +checksum = "23bb7577dca13ad86a78e8271ef5d322f37229ec83b8d98da6d996c588a1ddb1" dependencies = [ "async-trait", "bollard", @@ -6203,9 +6237,9 @@ dependencies = [ [[package]] name = "testcontainers-modules" -version = "0.11.6" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d43ed4e8f58424c3a2c6c56dbea6643c3c23e8666a34df13c54f0a184e6c707" +checksum = "eac95cde96549fc19c6bf19ef34cc42bd56e264c1cb97e700e21555be0ecf9e2" dependencies = [ "testcontainers", ] @@ -6245,7 +6279,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -6256,7 +6290,7 @@ checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -6282,9 +6316,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.37" +version = "0.3.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" dependencies = [ "deranged", "itoa", @@ -6297,15 +6331,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" [[package]] name = "time-macros" -version = "0.2.19" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" +checksum = 
"3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" dependencies = [ "num-conv", "time-core", @@ -6322,9 +6356,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.7.6" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" dependencies = [ "displaydoc", "zerovec", @@ -6342,9 +6376,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" +checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" dependencies = [ "tinyvec_macros", ] @@ -6357,9 +6391,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.44.1" +version = "1.45.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f382da615b842244d4b8738c82ed1275e6c5dd90c459a30941cd07080b06c91a" +checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" dependencies = [ "backtrace", "bytes", @@ -6381,7 +6415,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -6403,7 +6437,7 @@ dependencies = [ "pin-project-lite", "postgres-protocol", "postgres-types", - "rand 0.9.0", + "rand 0.9.1", "socket2", "tokio", "tokio-util", @@ -6412,9 +6446,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.1" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" +checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" dependencies = [ "rustls", "tokio", @@ -6448,9 +6482,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.14" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" dependencies = [ "bytes", "futures-core", @@ -6461,15 +6495,15 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" [[package]] name = "toml_edit" -version = "0.22.24" +version = "0.22.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e" dependencies = [ "indexmap 2.9.0", "toml_datetime", @@ -6488,7 +6522,7 @@ dependencies = [ "base64 0.22.1", "bytes", "h2", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -6572,7 +6606,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -6670,7 +6704,7 @@ checksum = "f9534daa9fd3ed0bd911d462a37f172228077e7abf18c18a5f67199d959205f8" dependencies = [ "proc-macro2", "quote", - "syn 
2.0.100", + "syn 2.0.101", ] [[package]] @@ -6681,9 +6715,9 @@ checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" [[package]] name = "typify" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e03ba3643450cfd95a1aca2e1938fef63c1c1994489337998aff4ad771f21ef8" +checksum = "fcc5bec3cdff70fd542e579aa2e52967833e543a25fae0d14579043d2e868a50" dependencies = [ "typify-impl", "typify-macro", @@ -6691,9 +6725,9 @@ dependencies = [ [[package]] name = "typify-impl" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bce48219a2f3154aaa2c56cbf027728b24a3c8fe0a47ed6399781de2b3f3eeaf" +checksum = "b52a67305054e1da6f3d99ad94875dcd0c7c49adbd17b4b64f0eefb7ae5bf8ab" dependencies = [ "heck 0.5.0", "log", @@ -6704,16 +6738,16 @@ dependencies = [ "semver", "serde", "serde_json", - "syn 2.0.100", + "syn 2.0.101", "thiserror 2.0.12", "unicode-ident", ] [[package]] name = "typify-macro" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68b5780d745920ed73c5b7447496a9b5c42ed2681a9b70859377aec423ecf02b" +checksum = "0ff5799be156e4f635c348c6051d165e1c59997827155133351a8c4d333d9841" dependencies = [ "proc-macro2", "quote", @@ -6722,7 +6756,7 @@ dependencies = [ "serde", "serde_json", "serde_tokenstream", - "syn 2.0.100", + "syn 2.0.101", "typify-impl", ] @@ -6734,9 +6768,9 @@ checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" [[package]] name = "unicode-ident" -version = "1.0.16" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unicode-normalization" @@ -6773,9 +6807,9 @@ checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" [[package]] name = "unindent" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" [[package]] name = "unsafe-libyaml" @@ -6807,12 +6841,6 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" -[[package]] -name = "utf16_iter" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - [[package]] name = "utf8_iter" version = "1.0.4" @@ -6827,11 +6855,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ - "getrandom 0.3.1", + "getrandom 0.3.3", "js-sys", "serde", "wasm-bindgen", @@ -6891,9 +6919,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasi" -version = "0.13.3+wasi-0.2.2" +version = "0.14.2+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" dependencies = [ "wit-bindgen-rt", ] @@ -6926,7 +6954,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", "wasm-bindgen-shared", ] @@ -6961,7 +6989,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -6996,7 +7024,7 @@ checksum = "17d5042cc5fa009658f9a7333ef24291b1291a25b6382dd68862a7f3b969f69b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -7046,11 +7074,11 @@ dependencies = [ [[package]] name = "whoami" -version = "1.5.2" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "372d5b87f58ec45c384ba03563b03544dc5fadc3983e434b286913f5b4a9bb6d" +checksum = "6994d13118ab492c3c80c1f81928718159254c53c472bf9ce36f8dae4add02a7" dependencies = [ - "redox_syscall 0.5.8", + "redox_syscall 0.5.12", "wasite", "web-sys", ] @@ -7088,55 +7116,70 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.57.0" +version = "0.61.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143" +checksum = "c5ee8f3d025738cb02bad7868bbb5f8a6327501e870bf51f1b455b0a2454a419" dependencies = [ - "windows-core 0.57.0", - "windows-targets 0.52.6", + "windows-collections", + "windows-core", + "windows-future", + "windows-link", + "windows-numerics", ] [[package]] -name = "windows-core" -version = "0.52.0" +name = "windows-collections" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" dependencies = [ - "windows-targets 0.52.6", + "windows-core", ] [[package]] name = "windows-core" -version = "0.57.0" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement", "windows-interface", - "windows-result 0.1.2", - "windows-targets 0.52.6", + "windows-link", + "windows-result", + "windows-strings 0.4.2", +] + +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core", + "windows-link", + "windows-threading", ] [[package]] name = "windows-implement" -version = "0.57.0" +version = "0.60.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "windows-interface" -version = "0.57.0" +version = "0.59.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" 
dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -7146,51 +7189,51 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" [[package]] -name = "windows-registry" +name = "windows-numerics" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" dependencies = [ - "windows-result 0.2.0", - "windows-strings", - "windows-targets 0.52.6", + "windows-core", + "windows-link", ] [[package]] -name = "windows-result" -version = "0.1.2" +name = "windows-registry" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" +checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" dependencies = [ - "windows-targets 0.52.6", + "windows-result", + "windows-strings 0.3.1", + "windows-targets 0.53.0", ] [[package]] name = "windows-result" -version = "0.2.0" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-targets 0.52.6", + "windows-link", ] [[package]] name = "windows-strings" -version = "0.1.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" dependencies = [ - "windows-result 0.2.0", - "windows-targets 0.52.6", + "windows-link", ] [[package]] -name = "windows-sys" -version = "0.48.0" +name = "windows-strings" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-targets 0.48.5", + "windows-link", ] [[package]] @@ -7211,21 +7254,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", -] - [[package]] name = "windows-targets" version = "0.52.6" @@ -7235,7 +7263,7 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm", + "windows_i686_gnullvm 0.52.6", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", @@ -7243,10 +7271,29 @@ dependencies = [ ] [[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" +name = "windows-targets" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +dependencies = [ + 
"windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", +] + +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link", +] [[package]] name = "windows_aarch64_gnullvm" @@ -7255,10 +7302,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" +name = "windows_aarch64_gnullvm" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" [[package]] name = "windows_aarch64_msvc" @@ -7267,10 +7314,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] -name = "windows_i686_gnu" -version = "0.48.5" +name = "windows_aarch64_msvc" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" [[package]] name = "windows_i686_gnu" @@ -7278,6 +7325,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" @@ -7285,10 +7338,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] -name = "windows_i686_msvc" -version = "0.48.5" +name = "windows_i686_gnullvm" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" [[package]] name = "windows_i686_msvc" @@ -7297,10 +7350,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" +name = "windows_i686_msvc" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" [[package]] name = "windows_x86_64_gnu" @@ -7309,10 +7362,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" +name = "windows_x86_64_gnu" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" [[package]] name = "windows_x86_64_gnullvm" @@ -7321,10 +7374,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" +name = "windows_x86_64_gnullvm" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" [[package]] name = "windows_x86_64_msvc" @@ -7332,35 +7385,35 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "winnow" -version = "0.7.2" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603" +checksum = "c06928c8748d81b05c9be96aad92e1b6ff01833332f281e8cfca3be4b35fc9ec" dependencies = [ "memchr", ] [[package]] name = "wit-bindgen-rt" -version = "0.33.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", ] -[[package]] -name = "write16" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" - [[package]] name = "writeable" -version = "0.5.5" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "wyz" @@ -7373,13 +7426,12 @@ dependencies = [ [[package]] name = "xattr" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e105d177a3871454f754b33bb0ee637ecaaac997446375fd3e5d43a2ed00c909" +checksum = "0d65cbf2f12c15564212d48f4e3dfb87923d25d611f2aed18f4cb23f0413d89e" dependencies = [ "libc", - "linux-raw-sys 0.4.15", - "rustix 0.38.44", + "rustix 1.0.7", ] [[package]] @@ -7399,9 +7451,9 @@ dependencies = [ [[package]] name = "yoke" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" dependencies = [ "serde", "stable_deref_trait", @@ -7411,75 +7463,54 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", "synstructure", 
] [[package]] name = "zerocopy" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" -dependencies = [ - "byteorder", - "zerocopy-derive 0.7.35", -] - -[[package]] -name = "zerocopy" -version = "0.8.18" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79386d31a42a4996e3336b0919ddb90f81112af416270cff95b5f5af22b839c2" +checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" dependencies = [ - "zerocopy-derive 0.8.18", + "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.35" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", -] - -[[package]] -name = "zerocopy-derive" -version = "0.8.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76331675d372f91bf8d17e13afbd5fe639200b73d01f0fc748bb059f9cca2db7" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "zerofrom" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", "synstructure", ] @@ -7489,11 +7520,22 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + [[package]] name = "zerovec" -version = "0.10.4" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428" dependencies = [ "yoke", "zerofrom", @@ -7502,13 +7544,13 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.10.3" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -7528,18 +7570,18 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "7.2.1" +version = "7.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" dependencies = [ 
"zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.13+zstd.1.5.6" +version = "2.0.15+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index 4f8cfa8baa87..366701bdae2e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ members = [ "datafusion/proto-common", "datafusion/proto-common/gen", "datafusion/session", + "datafusion/spark", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", @@ -75,7 +76,7 @@ repository = "https://github.com/apache/datafusion" # Define Minimum Supported Rust Version (MSRV) rust-version = "1.82.0" # Define DataFusion version -version = "47.0.0" +version = "48.0.1" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -87,12 +88,12 @@ ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } apache-avro = { version = "0.17", default-features = false } -arrow = { version = "55.0.0", features = [ +arrow = { version = "55.1.0", features = [ "prettyprint", "chrono-tz", ] } arrow-buffer = { version = "55.0.0", default-features = false } -arrow-flight = { version = "55.0.0", features = [ +arrow-flight = { version = "55.1.0", features = [ "flight-sql-experimental", ] } arrow-ipc = { version = "55.0.0", default-features = false, features = [ @@ -103,53 +104,54 @@ arrow-schema = { version = "55.0.0", default-features = false } async-trait = "0.1.88" bigdecimal = "0.4.8" bytes = "1.10" -chrono = { version = "0.4.38", default-features = false } +chrono = { version = "0.4.41", default-features = false } criterion = "0.5.1" -ctor = "0.2.9" +ctor = "0.4.0" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "47.0.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "47.0.0" } -datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "47.0.0" } -datafusion-common = { path = "datafusion/common", version = "47.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "47.0.0" } -datafusion-datasource = { path = "datafusion/datasource", version = "47.0.0", default-features = false } -datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "47.0.0", default-features = false } -datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "47.0.0", default-features = false } -datafusion-datasource-json = { path = "datafusion/datasource-json", version = "47.0.0", default-features = false } -datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "47.0.0", default-features = false } -datafusion-doc = { path = "datafusion/doc", version = "47.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "47.0.0" } -datafusion-expr = { path = "datafusion/expr", version = "47.0.0" } -datafusion-expr-common = { path = "datafusion/expr-common", version = "47.0.0" } -datafusion-ffi = { path = "datafusion/ffi", version = "47.0.0" } -datafusion-functions = { path = "datafusion/functions", version = "47.0.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "47.0.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = 
"47.0.0" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "47.0.0" } -datafusion-functions-table = { path = "datafusion/functions-table", version = "47.0.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "47.0.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "47.0.0" } -datafusion-macros = { path = "datafusion/macros", version = "47.0.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "47.0.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "47.0.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "47.0.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "47.0.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "47.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "47.0.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "47.0.0" } -datafusion-session = { path = "datafusion/session", version = "47.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "47.0.0" } +datafusion = { path = "datafusion/core", version = "48.0.1", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "48.0.1" } +datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "48.0.1" } +datafusion-common = { path = "datafusion/common", version = "48.0.1", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "48.0.1" } +datafusion-datasource = { path = "datafusion/datasource", version = "48.0.1", default-features = false } +datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "48.0.1", default-features = false } +datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "48.0.1", default-features = false } +datafusion-datasource-json = { path = "datafusion/datasource-json", version = "48.0.1", default-features = false } +datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "48.0.1", default-features = false } +datafusion-doc = { path = "datafusion/doc", version = "48.0.1" } +datafusion-execution = { path = "datafusion/execution", version = "48.0.1" } +datafusion-expr = { path = "datafusion/expr", version = "48.0.1" } +datafusion-expr-common = { path = "datafusion/expr-common", version = "48.0.1" } +datafusion-ffi = { path = "datafusion/ffi", version = "48.0.1" } +datafusion-functions = { path = "datafusion/functions", version = "48.0.1" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "48.0.1" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "48.0.1" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "48.0.1" } +datafusion-functions-table = { path = "datafusion/functions-table", version = "48.0.1" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "48.0.1" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "48.0.1" } +datafusion-macros = { path = "datafusion/macros", version = "48.0.1" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "48.0.1", default-features = false } +datafusion-physical-expr = { path = 
"datafusion/physical-expr", version = "48.0.1", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "48.0.1", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "48.0.1" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "48.0.1" } +datafusion-proto = { path = "datafusion/proto", version = "48.0.1" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "48.0.1" } +datafusion-session = { path = "datafusion/session", version = "48.0.1" } +datafusion-spark = { path = "datafusion/spark", version = "48.0.1" } +datafusion-sql = { path = "datafusion/sql", version = "48.0.1" } doc-comment = "0.3" env_logger = "0.11" futures = "0.3" -half = { version = "2.5.0", default-features = false } +half = { version = "2.6.0", default-features = false } hashbrown = { version = "0.14.5", features = ["raw"] } indexmap = "2.9.0" itertools = "0.14" log = "^0.4" object_store = { version = "0.12.0", default-features = false } parking_lot = "0.12" -parquet = { version = "55.0.0", default-features = false, features = [ +parquet = { version = "55.1.0", default-features = false, features = [ "arrow", "async", "object_store", @@ -157,16 +159,16 @@ parquet = { version = "55.0.0", default-features = false, features = [ pbjson = { version = "0.7.0" } pbjson-types = "0.7" # Should match arrow-flight's version of prost. -insta = { version = "1.41.1", features = ["glob", "filters"] } +insta = { version = "1.43.1", features = ["glob", "filters"] } prost = "0.13.1" -rand = "0.8.5" +rand = "0.9" recursive = "0.1.1" regex = "1.8" -rstest = "0.24.0" +rstest = "0.25.0" serde_json = "1" sqlparser = { version = "0.55.0", features = ["visitor"] } tempfile = "3" -tokio = { version = "1.44", features = ["macros", "rt", "sync"] } +tokio = { version = "1.45", features = ["macros", "rt", "sync"] } url = "2.5.4" [profile.release] @@ -209,7 +211,10 @@ strip = false # Detects large stack-allocated futures that may cause stack overflow crashes (see threshold in clippy.toml) large_futures = "warn" used_underscore_binding = "warn" +or_fun_call = "warn" +unnecessary_lazy_evaluations = "warn" +uninlined_format_args = "warn" [workspace.lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] } +unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)", "cfg(tarpaulin_include)"] } unused_qualifications = "deny" diff --git a/NOTICE.txt b/NOTICE.txt index 21be1a20d554..7f3c80d606c0 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -1,5 +1,5 @@ Apache DataFusion -Copyright 2019-2024 The Apache Software Foundation +Copyright 2019-2025 The Apache Software Foundation This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). \ No newline at end of file +The Apache Software Foundation (http://www.apache.org/). 
diff --git a/README.md b/README.md index 158033d40599..c142d8f366b2 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ [![Open Issues][open-issues-badge]][open-issues-url] [![Discord chat][discord-badge]][discord-url] [![Linkedin][linkedin-badge]][linkedin-url] +![Crates.io MSRV][msrv-badge] [crates-badge]: https://img.shields.io/crates/v/datafusion.svg [crates-url]: https://crates.io/crates/datafusion @@ -40,6 +41,7 @@ [open-issues-url]: https://github.com/apache/datafusion/issues [linkedin-badge]: https://img.shields.io/badge/Follow-Linkedin-blue [linkedin-url]: https://www.linkedin.com/company/apache-datafusion/ +[msrv-badge]: https://img.shields.io/crates/msrv/datafusion?label=Min%20Rust%20Version [Website](https://datafusion.apache.org/) | [API Docs](https://docs.rs/datafusion/latest/datafusion/) | @@ -133,20 +135,6 @@ Optional features: [apache avro]: https://avro.apache.org/ [apache parquet]: https://parquet.apache.org/ -## Rust Version Compatibility Policy - -The Rust toolchain releases are tracked at [Rust Versions](https://releases.rs) and follow -[semantic versioning](https://semver.org/). A Rust toolchain release can be identified -by a version string like `1.80.0`, or more generally `major.minor.patch`. - -DataFusion's supports the last 4 stable Rust minor versions released and any such versions released within the last 4 months. - -For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0` DataFusion will support 1.78.0, which is 3 minor versions prior to the most minor recent `1.81`. - -Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes preceding other policies. - -DataFusion enforces MSRV policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) - ## DataFusion API Evolution and Deprecation Guidelines Public methods in Apache DataFusion evolve over time: while we try to maintain a diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 063f4dac22d8..f9c198597b74 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -51,7 +51,7 @@ snmalloc-rs = { version = "0.3", optional = true } structopt = { version = "0.3", default-features = false } test-utils = { path = "../test-utils/", version = "0.1.0" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -tokio-util = { version = "0.7.14" } +tokio-util = { version = "0.7.15" } [dev-dependencies] datafusion-proto = { workspace = true } diff --git a/benchmarks/README.md b/benchmarks/README.md index 86b2e1b3b958..b19b3385afc8 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -23,7 +23,6 @@ This crate contains benchmarks based on popular public data sets and open source benchmark suites, to help with performance and scalability testing of DataFusion. - ## Other engines The benchmarks measure changes to DataFusion itself, rather than @@ -31,11 +30,11 @@ its performance against other engines. For competitive benchmarking, DataFusion is included in the benchmark setups for several popular benchmarks that compare performance with other engines. 
For example: -* [ClickBench] scripts are in the [ClickBench repo](https://github.com/ClickHouse/ClickBench/tree/main/datafusion) -* [H2o.ai `db-benchmark`] scripts are in [db-benchmark](https://github.com/apache/datafusion/tree/main/benchmarks/src/h2o.rs) +- [ClickBench] scripts are in the [ClickBench repo](https://github.com/ClickHouse/ClickBench/tree/main/datafusion) +- [H2o.ai `db-benchmark`] scripts are in [db-benchmark](https://github.com/apache/datafusion/tree/main/benchmarks/src/h2o.rs) -[ClickBench]: https://github.com/ClickHouse/ClickBench/tree/main -[H2o.ai `db-benchmark`]: https://github.com/h2oai/db-benchmark +[clickbench]: https://github.com/ClickHouse/ClickBench/tree/main +[h2o.ai `db-benchmark`]: https://github.com/h2oai/db-benchmark # Running the benchmarks @@ -65,31 +64,54 @@ Create / download a specific dataset (TPCH) ```shell ./bench.sh data tpch ``` + Data is placed in the `data` subdirectory. ## Running benchmarks Run benchmark for TPC-H dataset + ```shell ./bench.sh run tpch ``` + or for TPC-H dataset scale 10 + ```shell ./bench.sh run tpch10 ``` To run for specific query, for example Q21 + ```shell ./bench.sh run tpch10 21 ``` -## Select join algorithm +## Benchmark with modified configurations + +### Select join algorithm + The benchmark runs with `prefer_hash_join == true` by default, which enforces HASH join algorithm. To run TPCH benchmarks with join other than HASH: + ```shell PREFER_HASH_JOIN=false ./bench.sh run tpch ``` +### Configure with environment variables + +Any [datafusion options](https://datafusion.apache.org/user-guide/configs.html) that are provided environment variables are +also considered by the benchmarks. +The following configuration runs the TPCH benchmark with datafusion configured to _not_ repartition join keys. + +```shell +DATAFUSION_OPTIMIZER_REPARTITION_JOINS=false ./bench.sh run tpch +``` + +You might want to adjust the results location to avoid overwriting previous results. +Environment configuration that was picked up by datafusion is logged at `info` level. +To verify that datafusion picked up your configuration, run the benchmarks with `RUST_LOG=info` or higher. + ## Comparing performance of main and a branch ```shell @@ -407,7 +429,7 @@ logs. Example -dfbench parquet-filter --path ./data --scale-factor 1.0 +dfbench parquet-filter --path ./data --scale-factor 1.0 generates the synthetic dataset at `./data/logs.parquet`. The size of the dataset can be controlled through the `size_factor` @@ -439,6 +461,7 @@ Iteration 2 returned 1781686 rows in 1947 ms ``` ## Sort + Test performance of sorting large datasets This test sorts a a synthetic dataset generated during the @@ -462,22 +485,27 @@ Additionally, an optional `--limit` flag is available for the sort benchmark. Wh See [`sort_tpch.rs`](src/sort_tpch.rs) for more details. ### Sort TPCH Benchmark Example Runs + 1. Run all queries with default setting: + ```bash cargo run --release --bin dfbench -- sort-tpch -p './datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' ``` 2. Run a specific query: + ```bash cargo run --release --bin dfbench -- sort-tpch -p './datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' --query 2 ``` 3. Run all queries as TopK queries on presorted data: + ```bash cargo run --release --bin dfbench -- sort-tpch --sorted --limit 10 -p './datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' ``` 4. 
Run all queries with `bench.sh` script: + ```bash ./bench.sh run sort_tpch ``` @@ -515,59 +543,78 @@ External aggregation benchmarks run several aggregation queries with different m This benchmark is inspired by [DuckDB's external aggregation paper](https://hannes.muehleisen.org/publications/icde2024-out-of-core-kuiper-boncz-muehleisen.pdf), specifically Section VI. ### External Aggregation Example Runs + 1. Run all queries with predefined memory limits: + ```bash # Under 'benchmarks/' directory cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' ``` 2. Run a query with specific memory limit: + ```bash cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' --query 1 --memory-limit 30M ``` 3. Run all queries with `bench.sh` script: + ```bash ./bench.sh data external_aggr ./bench.sh run external_aggr ``` +## h2o.ai benchmarks + +The h2o.ai benchmarks are a set of performance tests for groupby and join operations. Beyond the standard h2o benchmark, there is also an extended benchmark for window functions. These benchmarks use synthetic data with configurable sizes (small: 1e7 rows, medium: 1e8 rows, big: 1e9 rows) to evaluate DataFusion's performance across different data scales. + +Reference: + +- [H2O AI Benchmark](https://duckdb.org/2023/04/14/h2oai.html) +- [Extended window benchmark](https://duckdb.org/2024/06/26/benchmarks-over-time.html#window-functions-benchmark) -## h2o benchmarks for groupby +### h2o benchmarks for groupby + +#### Generate data for h2o benchmarks -### Generate data for h2o benchmarks There are three options for generating data for h2o benchmarks: `small`, `medium`, and `big`. The data is generated in the `data` directory. 1. Generate small data (1e7 rows) + ```bash ./bench.sh data h2o_small ``` - 2. Generate medium data (1e8 rows) + ```bash ./bench.sh data h2o_medium ``` - 3. Generate large data (1e9 rows) + ```bash ./bench.sh data h2o_big ``` -### Run h2o benchmarks +#### Run h2o benchmarks + There are three options for running h2o benchmarks: `small`, `medium`, and `big`. + 1. Run small data benchmark + ```bash ./bench.sh run h2o_small ``` 2. Run medium data benchmark + ```bash ./bench.sh run h2o_medium ``` 3. Run large data benchmark + ```bash ./bench.sh run h2o_big ``` @@ -575,53 +622,53 @@ There are three options for running h2o benchmarks: `small`, `medium`, and `big` 4. Run a specific query with a specific data path For example, to run query 1 with the small data generated above: + ```bash cargo run --release --bin dfbench -- h2o --path ./benchmarks/data/h2o/G1_1e7_1e7_100_0.csv --query 1 ``` -## h2o benchmarks for join +### h2o benchmarks for join -### Generate data for h2o benchmarks There are three options for generating data for h2o benchmarks: `small`, `medium`, and `big`. The data is generated in the `data` directory. -1. Generate small data (4 table files, the largest is 1e7 rows) +Here is an example to generate the `small` dataset and run the benchmark. To run other +dataset size configurations, adjust the command from the previous example accordingly. + ```bash +# Generate small data (4 table files, the largest is 1e7 rows) ./bench.sh data h2o_small_join + +# Run the benchmark +./bench.sh run h2o_small_join ``` +To run a specific query with specific join data paths, note that the data paths must include 4 table files. -2.
Generate medium data (4 table files, the largest is 1e8 rows) -```bash -./bench.sh data h2o_medium_join -``` +For example, to run query 1 with the small data generated above: -3. Generate large data (4 table files, the largest is 1e9 rows) ```bash -./bench.sh data h2o_big_join +cargo run --release --bin dfbench -- h2o --join-paths ./benchmarks/data/h2o/J1_1e7_NA_0.csv,./benchmarks/data/h2o/J1_1e7_1e1_0.csv,./benchmarks/data/h2o/J1_1e7_1e4_0.csv,./benchmarks/data/h2o/J1_1e7_1e7_NA.csv --queries-path ./benchmarks/queries/h2o/join.sql --query 1 ``` -### Run h2o benchmarks -There are three options for running h2o benchmarks: `small`, `medium`, and `big`. -1. Run small data benchmark -```bash -./bench.sh run h2o_small_join -``` +### Extended h2o benchmarks for window -2. Run medium data benchmark -```bash -./bench.sh run h2o_medium_join -``` +This benchmark extends the h2o benchmark suite to evaluate window function performance. The h2o window benchmark uses the same dataset as the h2o join benchmark. There are three options for generating data for h2o benchmarks: `small`, `medium`, and `big`. + +Here is an example to generate the `small` dataset and run the benchmark. To run other +dataset size configurations, adjust the command from the previous example accordingly. -3. Run large data benchmark ```bash -./bench.sh run h2o_big_join +# Generate small data +./bench.sh data h2o_small_window + +# Run the benchmark +./bench.sh run h2o_small_window ``` -4. Run a specific query with a specific join data paths, the data paths are including 4 table files. +To run a specific query with specific window data paths, note that the data paths must include 4 table files (the same as the h2o-join dataset). For example, to run query 1 with the small data generated above: + ```bash -cargo run --release --bin dfbench -- h2o --join-paths ./benchmarks/data/h2o/J1_1e7_NA_0.csv,./benchmarks/data/h2o/J1_1e7_1e1_0.csv,./benchmarks/data/h2o/J1_1e7_1e4_0.csv,./benchmarks/data/h2o/J1_1e7_1e7_NA.csv --queries-path ./benchmarks/queries/h2o/join.sql --query 1 +cargo run --release --bin dfbench -- h2o --join-paths ./benchmarks/data/h2o/J1_1e7_NA_0.csv,./benchmarks/data/h2o/J1_1e7_1e1_0.csv,./benchmarks/data/h2o/J1_1e7_1e4_0.csv,./benchmarks/data/h2o/J1_1e7_1e7_NA.csv --queries-path ./benchmarks/queries/h2o/window.sql --query 1 ``` -[1]: http://www.tpc.org/tpch/ -[2]: https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 5d3ad3446ddb..7fafb751b65a 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -28,6 +28,12 @@ set -e # https://stackoverflow.com/questions/59895/how-do-i-get-the-directory-where-a-bash-script-is-located-from-within-the-script SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +# Execute a command and also print it, for debugging purposes +debug_run() { + set -x + "$@" + set +x +} # Set Defaults COMMAND= @@ -87,6 +93,9 @@ h2o_big: h2oai benchmark with large dataset (1e9 rows) for groupb h2o_small_join: h2oai benchmark with small dataset (1e7 rows) for join, default file format is csv h2o_medium_join: h2oai benchmark with medium dataset (1e8 rows) for join, default file format is csv h2o_big_join: h2oai benchmark with large dataset (1e9 rows) for join, default file format is csv +h2o_small_window: Extended h2oai benchmark with small dataset (1e7 rows) for window, default file format is csv +h2o_medium_window: Extended h2oai benchmark with medium dataset (1e8 rows) for window, default file format is csv +h2o_big_window: Extended h2oai
benchmark with large dataset (1e9 rows) for window, default file format is csv imdb: Join Order Benchmark (JOB) using the IMDB dataset converted to parquet ********** @@ -98,6 +107,7 @@ DATAFUSION_DIR directory to use (default $DATAFUSION_DIR) RESULTS_NAME folder where the benchmark files are stored PREFER_HASH_JOIN Prefer hash join algorithm (default true) VENV_PATH Python venv to use for compare and venv commands (default ./venv, override by /bin/activate) +DATAFUSION_* Set the given datafusion configuration " exit 1 } @@ -204,6 +214,16 @@ main() { h2o_big_join) data_h2o_join "BIG" "CSV" ;; + # h2o window benchmark uses the same data as the h2o join + h2o_small_window) + data_h2o_join "SMALL" "CSV" + ;; + h2o_medium_window) + data_h2o_join "MEDIUM" "CSV" + ;; + h2o_big_window) + data_h2o_join "BIG" "CSV" + ;; external_aggr) # same data as for tpch data_tpch "1" @@ -314,6 +334,15 @@ main() { h2o_big_join) run_h2o_join "BIG" "CSV" "join" ;; + h2o_small_window) + run_h2o_window "SMALL" "CSV" "window" + ;; + h2o_medium_window) + run_h2o_window "MEDIUM" "CSV" "window" + ;; + h2o_big_window) + run_h2o_window "BIG" "CSV" "window" + ;; external_aggr) run_external_aggr ;; @@ -412,10 +441,7 @@ run_tpch() { echo "Running tpch benchmark..." # Optional query filter to run specific query QUERY=$([ -n "$ARG3" ] && echo "--query $ARG3" || echo "") - # debug the target command - set -x - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" $QUERY - set +x + debug_run $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" $QUERY } # Runs the tpch in memory @@ -432,11 +458,8 @@ run_tpch_mem() { echo "Running tpch_mem benchmark..." # Optional query filter to run specific query QUERY=$([ -n "$ARG3" ] && echo "--query $ARG3" || echo "") - # debug the target command - set -x # -m means in memory - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" $QUERY - set +x + debug_run $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" $QUERY } # Runs the cancellation benchmark @@ -444,7 +467,7 @@ run_cancellation() { RESULTS_FILE="${RESULTS_DIR}/cancellation.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running cancellation benchmark..." - $CARGO_COMMAND --bin dfbench -- cancellation --iterations 5 --path "${DATA_DIR}/cancellation" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- cancellation --iterations 5 --path "${DATA_DIR}/cancellation" -o "${RESULTS_FILE}" } # Runs the parquet filter benchmark @@ -452,7 +475,7 @@ run_parquet() { RESULTS_FILE="${RESULTS_DIR}/parquet.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running parquet filter benchmark..." - $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } # Runs the sort benchmark @@ -460,7 +483,7 @@ run_sort() { RESULTS_FILE="${RESULTS_DIR}/sort.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running sort benchmark..." 
- $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } @@ -514,7 +537,7 @@ run_clickbench_1() { RESULTS_FILE="${RESULTS_DIR}/clickbench_1.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" } # Runs the clickbench benchmark with the partitioned parquet files @@ -522,7 +545,7 @@ run_clickbench_partitioned() { RESULTS_FILE="${RESULTS_DIR}/clickbench_partitioned.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (partitioned, 100 files) benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" } # Runs the clickbench "extended" benchmark with a single large parquet file @@ -530,7 +553,7 @@ run_clickbench_extended() { RESULTS_FILE="${RESULTS_DIR}/clickbench_extended.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) extended benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o "${RESULTS_FILE}" } # Downloads the csv.gz files IMDB datasets from Peter Boncz's homepage(one of the JOB paper authors) @@ -645,7 +668,7 @@ run_imdb() { RESULTS_FILE="${RESULTS_DIR}/imdb.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running imdb benchmark..." 
- $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" } data_h2o() { @@ -800,6 +823,7 @@ data_h2o_join() { deactivate } +# Runner for h2o groupby benchmark run_h2o() { # Default values for size and data format SIZE=${1:-"SMALL"} @@ -835,14 +859,15 @@ run_h2o() { QUERY_FILE="${SCRIPT_DIR}/queries/h2o/${RUN_Type}.sql" # Run the benchmark using the dynamically constructed file path and query file - $CARGO_COMMAND --bin dfbench -- h2o \ + debug_run $CARGO_COMMAND --bin dfbench -- h2o \ --iterations 3 \ --path "${H2O_DIR}/${FILE_NAME}" \ --queries-path "${QUERY_FILE}" \ -o "${RESULTS_FILE}" } -run_h2o_join() { +# Utility function to run h2o join/window benchmark +h2o_runner() { # Default values for size and data format SIZE=${1:-"SMALL"} DATA_FORMAT=${2:-"CSV"} @@ -851,10 +876,10 @@ # Data directory and results file path H2O_DIR="${DATA_DIR}/h2o" - RESULTS_FILE="${RESULTS_DIR}/h2o_join.json" + RESULTS_FILE="${RESULTS_DIR}/h2o_${RUN_Type}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" - echo "Running h2o join benchmark..." + echo "Running h2o ${RUN_Type} benchmark..." # Set the file name based on the size case "$SIZE" in @@ -882,16 +907,26 @@ ;; esac - # Set the query file name based on the RUN_Type + # Set the query file name based on the RUN_Type QUERY_FILE="${SCRIPT_DIR}/queries/h2o/${RUN_Type}.sql" - $CARGO_COMMAND --bin dfbench -- h2o \ + debug_run $CARGO_COMMAND --bin dfbench -- h2o \ --iterations 3 \ --join-paths "${H2O_DIR}/${X_TABLE_FILE_NAME},${H2O_DIR}/${SMALL_TABLE_FILE_NAME},${H2O_DIR}/${MEDIUM_TABLE_FILE_NAME},${H2O_DIR}/${LARGE_TABLE_FILE_NAME}" \ --queries-path "${QUERY_FILE}" \ -o "${RESULTS_FILE}" } +# Runner for h2o join benchmark +run_h2o_join() { + h2o_runner "$1" "$2" "join" +} + +# Runner for h2o window benchmark +run_h2o_window() { + h2o_runner "$1" "$2" "window" +} + # Runs the external aggregation benchmark run_external_aggr() { # Use TPC-H SF1 dataset @@ -905,7 +940,7 @@ # number-of-partitions), and by default `--partitions` is set to number of # CPU cores, we set a constant number of partitions to prevent this # benchmark to fail on some machines. - $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" } # Runs the sort integration benchmark @@ -915,7 +950,7 @@ run_sort_tpch() { echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running sort tpch benchmark..."
- $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" } diff --git a/benchmarks/compare.py b/benchmarks/compare.py index 4b609c744d50..0dd067ca9c34 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -47,6 +47,7 @@ class QueryRun: query: int iterations: List[QueryResult] start_time: int + success: bool = True @classmethod def load_from(cls, data: Dict[str, Any]) -> QueryRun: @@ -54,6 +55,7 @@ def load_from(cls, data: Dict[str, Any]) -> QueryRun: query=data["query"], iterations=[QueryResult(**iteration) for iteration in data["iterations"]], start_time=data["start_time"], + success=data["success"], ) @property @@ -125,11 +127,26 @@ def compare( faster_count = 0 slower_count = 0 no_change_count = 0 + failure_count = 0 total_baseline_time = 0 total_comparison_time = 0 for baseline_result, comparison_result in zip(baseline.queries, comparison.queries): assert baseline_result.query == comparison_result.query + + base_failed = not baseline_result.success + comp_failed = not comparison_result.success + # If a query fails, its execution time is excluded from the performance comparison + if base_failed or comp_failed: + change_text = "incomparable" + failure_count += 1 + table.add_row( + f"Q{baseline_result.query}", + "FAIL" if base_failed else f"{baseline_result.execution_time:.2f}ms", + "FAIL" if comp_failed else f"{comparison_result.execution_time:.2f}ms", + change_text, + ) + continue total_baseline_time += baseline_result.execution_time total_comparison_time += comparison_result.execution_time @@ -156,8 +173,12 @@ def compare( console.print(table) # Calculate averages - avg_baseline_time = total_baseline_time / len(baseline.queries) - avg_comparison_time = total_comparison_time / len(comparison.queries) + avg_baseline_time = 0.0 + avg_comparison_time = 0.0 + if len(baseline.queries) - failure_count > 0: + avg_baseline_time = total_baseline_time / (len(baseline.queries) - failure_count) + if len(comparison.queries) - failure_count > 0: + avg_comparison_time = total_comparison_time / (len(comparison.queries) - failure_count) # Summary table summary_table = Table(show_header=True, header_style="bold magenta") @@ -171,6 +192,7 @@ def compare( summary_table.add_row("Queries Faster", str(faster_count)) summary_table.add_row("Queries Slower", str(slower_count)) summary_table.add_row("Queries with No Change", str(no_change_count)) + summary_table.add_row("Queries with Failure", str(failure_count)) console.print(summary_table) diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md index 2032427e1ef2..e5acd8f348a4 100644 --- a/benchmarks/queries/clickbench/README.md +++ b/benchmarks/queries/clickbench/README.md @@ -5,12 +5,13 @@ This directory contains queries for the ClickBench benchmark https://benchmark.c ClickBench is focused on aggregation and filtering performance (though it has no Joins) ## Files: -* `queries.sql` - Actual ClickBench queries, downloaded from the [ClickBench repository] -* `extended.sql` - "Extended" DataFusion specific queries. -[ClickBench repository]: https://github.com/ClickHouse/ClickBench/blob/main/datafusion/queries.sql +- `queries.sql` - Actual ClickBench queries, downloaded from the [ClickBench repository] +- `extended.sql` - "Extended" DataFusion specific queries. 
-## "Extended" Queries +[clickbench repository]: https://github.com/ClickHouse/ClickBench/blob/main/datafusion/queries.sql + +## "Extended" Queries The "extended" queries are not part of the official ClickBench benchmark. Instead they are used to test other DataFusion features that are not covered by @@ -25,7 +26,7 @@ the standard benchmark. Each description below is for the corresponding line in distinct string columns. ```sql -SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; ``` @@ -35,7 +36,6 @@ FROM hits; **Important Query Properties**: multiple `COUNT DISTINCT`s. All three are small strings (length either 1 or 2). - ```sql SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; @@ -43,21 +43,20 @@ FROM hits; ### Q2: Top 10 analysis -**Question**: "Find the top 10 "browser country" by number of distinct "social network"s, -including the distinct counts of "hit color", "browser language", +**Question**: "Find the top 10 "browser country" by number of distinct "social network"s, +including the distinct counts of "hit color", "browser language", and "social action"." **Important Query Properties**: GROUP BY short, string, multiple `COUNT DISTINCT`s. There are several small strings (length either 1 or 2). ```sql SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") -FROM hits -GROUP BY 1 -ORDER BY 2 DESC +FROM hits +GROUP BY 1 +ORDER BY 2 DESC LIMIT 10; ``` - ### Q3: What is the income distribution for users in specific regions **Question**: "What regions and social networks have the highest variance of parameter price?" 
@@ -65,17 +64,17 @@ LIMIT 10; **Important Query Properties**: STDDEV and VAR aggregation functions, GROUP BY multiple small ints ```sql -SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") -FROM 'hits.parquet' -GROUP BY "SocialSourceNetworkID", "RegionID" +SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") +FROM 'hits.parquet' +GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL -ORDER BY s DESC +ORDER BY s DESC LIMIT 10; ``` ### Q4: Response start time distribution analysis (median) -**Question**: Find the WatchIDs with the highest median "ResponseStartTiming" without Java enabled +**Question**: Find the WatchIDs with the highest median "ResponseStartTiming" without Java enabled **Important Query Properties**: MEDIAN, functions, high cardinality grouping that skips intermediate aggregation @@ -102,17 +101,16 @@ Results look like +-------------+---------------------+---+------+------+------+ ``` - ### Q5: Response start time distribution analysis (p95) -**Question**: Find the WatchIDs with the highest p95 "ResponseStartTiming" without Java enabled +**Question**: Find the WatchIDs with the highest p95 "ResponseStartTiming" without Java enabled **Important Query Properties**: APPROX_PERCENTILE_CONT, functions, high cardinality grouping that skips intermediate aggregation Note this query is somewhat synthetic as "WatchID" is almost unique (there are a few duplicates) ```sql -SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT("ResponseStartTiming", 0.95) tp95, MAX("ResponseStartTiming") tmax +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY "ResponseStartTiming") tp95, MAX("ResponseStartTiming") tmax FROM 'hits.parquet' WHERE "JavaEnable" = 0 -- filters to 32M of 100M rows GROUP BY "ClientIP", "WatchID" @@ -122,6 +120,7 @@ LIMIT 10; ``` Results look like + ``` +-------------+---------------------+---+------+------+------+ | ClientIP | WatchID | c | tmin | tp95 | tmax | @@ -132,6 +131,7 @@ Results look like ``` ### Q6: How many social shares meet complex multi-stage filtering criteria? + **Question**: What is the count of sharing actions from iPhone mobile users on specific social networks, within common timezones, participating in seasonal campaigns, with high screen resolutions and closely matched UTM parameters? 
**Important Query Properties**: Simple filter with high-selectivity, Costly string matching, A large number of filters with high overhead are positioned relatively later in the process @@ -150,20 +150,89 @@ WHERE -- Stage 3: Heavy computations (expensive) AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL -- Find campaign-specific referrers - AND CASE - WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' - THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT - ELSE 0 + AND CASE + WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' + THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT + ELSE 0 END > 1920 -- Extract and validate resolution parameter - AND levenshtein("UTMSource", "UTMCampaign") < 3 -- Verify UTM parameter similarity + AND levenshtein(CAST("UTMSource" AS STRING), CAST("UTMCampaign" AS STRING)) < 3 -- Verify UTM parameter similarity ``` + Result is empty,Since it has already been filtered by `"SocialAction" = 'share'`. +### Q7: Device Resolution and Refresh Behavior Analysis + +**Question**: Identify the top 10 WatchIDs with the highest resolution range (min/max "ResolutionWidth") and total refresh count ("IsRefresh") in descending WatchID order + +**Important Query Properties**: Primitive aggregation functions, group by single primitive column, high cardinality grouping + +```sql +SELECT "WatchID", MIN("ResolutionWidth") as wmin, MAX("ResolutionWidth") as wmax, SUM("IsRefresh") as srefresh +FROM hits +GROUP BY "WatchID" +ORDER BY "WatchID" DESC +LIMIT 10; +``` + +Results look like + +``` ++---------------------+------+------+----------+ +| WatchID | wmin | wmax | srefresh | ++---------------------+------+------+----------+ +| 9223372033328793741 | 1368 | 1368 | 0 | +| 9223371941779979288 | 1479 | 1479 | 0 | +| 9223371906781104763 | 1638 | 1638 | 0 | +| 9223371803397398692 | 1990 | 1990 | 0 | +| 9223371799215233959 | 1638 | 1638 | 0 | +| 9223371785975219972 | 0 | 0 | 0 | +| 9223371776706839366 | 1368 | 1368 | 0 | +| 9223371740707848038 | 1750 | 1750 | 0 | +| 9223371715190479830 | 1368 | 1368 | 0 | +| 9223371620124912624 | 1828 | 1828 | 0 | ++---------------------+------+------+----------+ +``` + +### Q8: Average Latency and Response Time Analysis + +**Question**: Which combinations of operating system, region, and user agent exhibit the highest average latency? For each of these combinations, also report the average response time. 
+ +**Important Query Properties**: Multiple average of Duration, high cardinality grouping + +```sql +SELECT "RegionID", "UserAgent", "OS", AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ResponseStartTiming")) as avg_response_time, AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ConnectTiming")) as avg_latency +FROM hits +GROUP BY "RegionID", "UserAgent", "OS" +ORDER BY avg_latency DESC +LIMIT 10; +``` + +Results look like + +``` ++----------+-----------+-----+------------------------------------------+------------------------------------------+ +| RegionID | UserAgent | OS | avg_response_time | avg_latency | ++----------+-----------+-----+------------------------------------------+------------------------------------------+ +| 22934 | 5 | 126 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 22735 | 82 | 74 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 21687 | 32 | 49 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 18518 | 82 | 77 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 14006 | 7 | 126 | 0 days 7 hours 58 mins 20.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 9803 | 82 | 77 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 107108 | 82 | 77 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 111626 | 7 | 44 | 0 days 7 hours 23 mins 12.500000000 secs | 0 days 8 hours 0 mins 47.000000000 secs | +| 17716 | 56 | 44 | 0 days 6 hours 48 mins 44.500000000 secs | 0 days 7 hours 35 mins 47.000000000 secs | +| 13631 | 82 | 45 | 0 days 7 hours 23 mins 1.000000000 secs | 0 days 7 hours 23 mins 1.000000000 secs | ++----------+-----------+-----+------------------------------------------+------------------------------------------+ +10 row(s) fetched. +Elapsed 30.195 seconds. 
+``` ## Data Notes Here are some interesting statistics about the data used in the queries Max length of `"SearchPhrase"` is 1113 characters + ```sql > select min(length("SearchPhrase")) as "SearchPhrase_len_min", max(length("SearchPhrase")) "SearchPhrase_len_max" from 'hits.parquet' limit 10; +----------------------+----------------------+ @@ -173,8 +242,8 @@ Max length of `"SearchPhrase"` is 1113 characters +----------------------+----------------------+ ``` - Here is the schema of the data + ```sql > describe 'hits.parquet'; +-----------------------+-----------+-------------+ diff --git a/benchmarks/queries/clickbench/extended.sql b/benchmarks/queries/clickbench/extended.sql index ef3a409c9c02..93c39efe4f8e 100644 --- a/benchmarks/queries/clickbench/extended.sql +++ b/benchmarks/queries/clickbench/extended.sql @@ -3,5 +3,7 @@ SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTI SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10; SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10; -SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT("ResponseStartTiming", 0.95) tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; -SELECT COUNT(*) AS ShareCount FROM hits WHERE "IsMobile" = 1 AND "MobilePhoneModel" LIKE 'iPhone%' AND "SocialAction" = 'share' AND "SocialSourceNetworkID" IN (5, 12) AND "ClientTimeZone" BETWEEN -5 AND 5 AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL AND CASE WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT ELSE 0 END > 1920 AND levenshtein("UTMSource", "UTMCampaign") < 3; \ No newline at end of file +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY "ResponseStartTiming") tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; +SELECT COUNT(*) AS ShareCount FROM hits WHERE "IsMobile" = 1 AND "MobilePhoneModel" LIKE 'iPhone%' AND "SocialAction" = 'share' AND "SocialSourceNetworkID" IN (5, 12) AND "ClientTimeZone" BETWEEN -5 AND 5 AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL AND CASE WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT ELSE 0 END > 1920 AND levenshtein(CAST("UTMSource" AS STRING), CAST("UTMCampaign" AS STRING)) < 3; +SELECT "WatchID", MIN("ResolutionWidth") as wmin, MAX("ResolutionWidth") as wmax, SUM("IsRefresh") as srefresh FROM hits GROUP BY "WatchID" ORDER BY "WatchID" DESC LIMIT 10; +SELECT "RegionID", "UserAgent", "OS", AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ResponseStartTiming")) as avg_response_time, 
AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ConnectTiming")) as avg_latency FROM hits GROUP BY "RegionID", "UserAgent", "OS" ORDER BY avg_latency DESC limit 10; \ No newline at end of file diff --git a/benchmarks/queries/h2o/groupby.sql b/benchmarks/queries/h2o/groupby.sql index c2101ef8ada2..4fae7a13810d 100644 --- a/benchmarks/queries/h2o/groupby.sql +++ b/benchmarks/queries/h2o/groupby.sql @@ -1,10 +1,19 @@ SELECT id1, SUM(v1) AS v1 FROM x GROUP BY id1; + SELECT id1, id2, SUM(v1) AS v1 FROM x GROUP BY id1, id2; + SELECT id3, SUM(v1) AS v1, AVG(v3) AS v3 FROM x GROUP BY id3; + SELECT id4, AVG(v1) AS v1, AVG(v2) AS v2, AVG(v3) AS v3 FROM x GROUP BY id4; + SELECT id6, SUM(v1) AS v1, SUM(v2) AS v2, SUM(v3) AS v3 FROM x GROUP BY id6; + SELECT id4, id5, MEDIAN(v3) AS median_v3, STDDEV(v3) AS sd_v3 FROM x GROUP BY id4, id5; + SELECT id3, MAX(v1) - MIN(v2) AS range_v1_v2 FROM x GROUP BY id3; + SELECT id6, largest2_v3 FROM (SELECT id6, v3 AS largest2_v3, ROW_NUMBER() OVER (PARTITION BY id6 ORDER BY v3 DESC) AS order_v3 FROM x WHERE v3 IS NOT NULL) sub_query WHERE order_v3 <= 2; + SELECT id2, id4, POWER(CORR(v1, v2), 2) AS r2 FROM x GROUP BY id2, id4; -SELECT id1, id2, id3, id4, id5, id6, SUM(v3) AS v3, COUNT(*) AS count FROM x GROUP BY id1, id2, id3, id4, id5, id6; + +SELECT id1, id2, id3, id4, id5, id6, SUM(v3) AS v3, COUNT(*) AS count FROM x GROUP BY id1, id2, id3, id4, id5, id6; \ No newline at end of file diff --git a/benchmarks/queries/h2o/join.sql b/benchmarks/queries/h2o/join.sql index 8546b9292dbb..84cd661fdd59 100644 --- a/benchmarks/queries/h2o/join.sql +++ b/benchmarks/queries/h2o/join.sql @@ -1,5 +1,9 @@ SELECT x.id1, x.id2, x.id3, x.id4 as xid4, small.id4 as smallid4, x.id5, x.id6, x.v1, small.v2 FROM x INNER JOIN small ON x.id1 = small.id1; + SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x INNER JOIN medium ON x.id2 = medium.id2; + SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x LEFT JOIN medium ON x.id2 = medium.id2; + SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x JOIN medium ON x.id5 = medium.id5; -SELECT x.id1 as xid1, large.id1 as largeid1, x.id2 as xid2, large.id2 as largeid2, x.id3, x.id4 as xid4, large.id4 as largeid4, x.id5 as xid5, large.id5 as largeid5, x.id6 as xid6, large.id6 as largeid6, x.v1, large.v2 FROM x JOIN large ON x.id3 = large.id3; + +SELECT x.id1 as xid1, large.id1 as largeid1, x.id2 as xid2, large.id2 as largeid2, x.id3, x.id4 as xid4, large.id4 as largeid4, x.id5 as xid5, large.id5 as largeid5, x.id6 as xid6, large.id6 as largeid6, x.v1, large.v2 FROM x JOIN large ON x.id3 = large.id3; \ No newline at end of file diff --git a/benchmarks/queries/h2o/window.sql b/benchmarks/queries/h2o/window.sql new file mode 100644 index 000000000000..071540927a4c --- /dev/null +++ b/benchmarks/queries/h2o/window.sql @@ -0,0 +1,112 @@ +-- Basic Window +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER () AS window_basic +FROM large; + +-- Sorted Window +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (ORDER BY id3) AS first_order_by, + row_number() OVER (ORDER BY id3) AS row_number_order_by +FROM large; + +-- PARTITION BY +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (PARTITION BY id1) AS 
sum_by_id1, + sum(v2) OVER (PARTITION BY id2) AS sum_by_id2, + sum(v2) OVER (PARTITION BY id3) AS sum_by_id3 +FROM large; + +-- PARTITION BY ORDER BY +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (PARTITION BY id2 ORDER BY id3) AS first_by_id2_ordered_by_id3 +FROM large; + +-- Lead and Lag +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (ORDER BY id3 ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING) AS my_lag, + first_value(v2) OVER (ORDER BY id3 ROWS BETWEEN 1 FOLLOWING AND 1 FOLLOWING) AS my_lead +FROM large; + +-- Moving Averages +SELECT + id1, + id2, + id3, + v2, + avg(v2) OVER (ORDER BY id3 ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) AS my_moving_average +FROM large; + +-- Rolling Sum +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (ORDER BY id3 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS my_rolling_sum +FROM large; + +-- RANGE BETWEEN +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (ORDER BY v2 RANGE BETWEEN 3 PRECEDING AND CURRENT ROW) AS my_range_between +FROM large; + +-- First PARTITION BY ROWS BETWEEN +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (PARTITION BY id2 ORDER BY id3 ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING) AS my_lag_by_id2, + first_value(v2) OVER (PARTITION BY id2 ORDER BY id3 ROWS BETWEEN 1 FOLLOWING AND 1 FOLLOWING) AS my_lead_by_id2 +FROM large; + +-- Moving Averages PARTITION BY +SELECT + id1, + id2, + id3, + v2, + avg(v2) OVER (PARTITION BY id2 ORDER BY id3 ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) AS my_moving_average_by_id2 +FROM large; + +-- Rolling Sum PARTITION BY +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (PARTITION BY id2 ORDER BY id3 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS my_rolling_sum_by_id2 +FROM large; + +-- RANGE BETWEEN PARTITION BY +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (PARTITION BY id2 ORDER BY v2 RANGE BETWEEN 3 PRECEDING AND CURRENT ROW) AS my_range_between_by_id2 +FROM large; \ No newline at end of file diff --git a/benchmarks/queries/q10.sql b/benchmarks/queries/q10.sql index cf45e43485fb..8613fd496283 100644 --- a/benchmarks/queries/q10.sql +++ b/benchmarks/queries/q10.sql @@ -28,4 +28,5 @@ group by c_address, c_comment order by - revenue desc; \ No newline at end of file + revenue desc +limit 20; diff --git a/benchmarks/queries/q18.sql b/benchmarks/queries/q18.sql index 835de28a57be..ba7ee7f716cf 100644 --- a/benchmarks/queries/q18.sql +++ b/benchmarks/queries/q18.sql @@ -29,4 +29,5 @@ group by o_totalprice order by o_totalprice desc, - o_orderdate; \ No newline at end of file + o_orderdate +limit 100; diff --git a/benchmarks/queries/q2.sql b/benchmarks/queries/q2.sql index f66af210205e..68e478f65d3f 100644 --- a/benchmarks/queries/q2.sql +++ b/benchmarks/queries/q2.sql @@ -40,4 +40,5 @@ order by s_acctbal desc, n_name, s_name, - p_partkey; \ No newline at end of file + p_partkey +limit 100; diff --git a/benchmarks/queries/q21.sql b/benchmarks/queries/q21.sql index 9d2fe32cee22..b95e7b0dfca0 100644 --- a/benchmarks/queries/q21.sql +++ b/benchmarks/queries/q21.sql @@ -36,4 +36,5 @@ group by s_name order by numwait desc, - s_name; \ No newline at end of file + s_name +limit 100; diff --git a/benchmarks/queries/q3.sql b/benchmarks/queries/q3.sql index 7dbc6d9ef678..e5fa9e38664c 100644 --- a/benchmarks/queries/q3.sql +++ b/benchmarks/queries/q3.sql @@ -19,4 +19,5 @@ group by o_shippriority order by revenue desc, - o_orderdate; \ No newline at end of file + o_orderdate +limit 10; diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs index 
578f71f8275d..0e519367badb 100644 --- a/benchmarks/src/bin/external_aggr.rs +++ b/benchmarks/src/bin/external_aggr.rs @@ -40,7 +40,7 @@ use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::*; -use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt}; +use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt, QueryResult}; use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; use datafusion_common::{exec_err, DEFAULT_PARQUET_EXTENSION}; @@ -77,11 +77,6 @@ struct ExternalAggrConfig { output_path: Option, } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - /// Query Memory Limits /// Map query id to predefined memory limits /// @@ -189,7 +184,7 @@ impl ExternalAggrConfig { ) -> Result> { let query_name = format!("Q{query_id}({})", human_readable_size(mem_limit as usize)); - let config = self.common.config(); + let config = self.common.config()?; let memory_pool: Arc = match mem_pool_type { "fair" => Arc::new(FairSpillPool::new(mem_limit as usize)), "greedy" => Arc::new(GreedyMemoryPool::new(mem_limit as usize)), @@ -335,7 +330,7 @@ impl ExternalAggrConfig { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } diff --git a/benchmarks/src/cancellation.rs b/benchmarks/src/cancellation.rs index f5740bdc96e0..fcf03fbc5455 100644 --- a/benchmarks/src/cancellation.rs +++ b/benchmarks/src/cancellation.rs @@ -38,7 +38,7 @@ use futures::TryStreamExt; use object_store::ObjectStore; use parquet::arrow::async_writer::ParquetObjectWriter; use parquet::arrow::AsyncArrowWriter; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::rngs::ThreadRng; use rand::Rng; use structopt::StructOpt; @@ -237,7 +237,7 @@ fn find_files_on_disk(data_dir: impl AsRef) -> Result> { let path = file.unwrap().path(); if path .extension() - .map(|ext| (ext == "parquet")) + .map(|ext| ext == "parquet") .unwrap_or(false) { Some(path) @@ -312,15 +312,15 @@ async fn generate_data( } fn random_data(column_type: &DataType, rows: usize) -> Arc { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let values = (0..rows).map(|_| random_value(&mut rng, column_type)); ScalarValue::iter_to_array(values).unwrap() } fn random_value(rng: &mut ThreadRng, column_type: &DataType) -> ScalarValue { match column_type { - DataType::Float64 => ScalarValue::Float64(Some(rng.gen())), - DataType::Boolean => ScalarValue::Boolean(Some(rng.gen())), + DataType::Float64 => ScalarValue::Float64(Some(rng.random())), + DataType::Boolean => ScalarValue::Boolean(Some(rng.random())), DataType::Utf8 => ScalarValue::Utf8(Some( rng.sample_iter(&Alphanumeric) .take(10) diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index 923c2bdd7cdf..57726fd181f4 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -18,7 +18,7 @@ use std::path::Path; use std::path::PathBuf; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; use datafusion::{ error::{DataFusionError, Result}, prelude::SessionContext, @@ -116,7 +116,7 @@ impl RunOpt { }; // configure parquet options - let mut config = self.common.config(); + let mut config = self.common.config()?; { let parquet_options = &mut config.options_mut().execution.parquet; // The hits_partitioned 
dataset specifies string columns @@ -128,36 +128,58 @@ impl RunOpt { let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); self.register_hits(&ctx).await?; - let iterations = self.common.iterations; let mut benchmark_run = BenchmarkRun::new(); for query_id in query_range { - let mut millis = Vec::with_capacity(iterations); benchmark_run.start_new_case(&format!("Query {query_id}")); - let sql = queries.get_query(query_id)?; - println!("Q{query_id}: {sql}"); - - for i in 0..iterations { - let start = Instant::now(); - let results = ctx.sql(sql).await?.collect().await?; - let elapsed = start.elapsed(); - let ms = elapsed.as_secs_f64() * 1000.0; - millis.push(ms); - let row_count: usize = results.iter().map(|b| b.num_rows()).sum(); - println!( - "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" - ); - benchmark_run.write_iter(elapsed, row_count); + let query_run = self.benchmark_query(&queries, query_id, &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } } - if self.common.debug { - ctx.sql(sql).await?.explain(false, false)?.show().await?; - } - let avg = millis.iter().sum::() / millis.len() as f64; - println!("Query {query_id} avg time: {avg:.2} ms"); } benchmark_run.maybe_write_json(self.output_path.as_ref())?; + benchmark_run.maybe_print_failures(); Ok(()) } + async fn benchmark_query( + &self, + queries: &AllQueries, + query_id: usize, + ctx: &SessionContext, + ) -> Result> { + let sql = queries.get_query(query_id)?; + println!("Q{query_id}: {sql}"); + + let mut millis = Vec::with_capacity(self.iterations()); + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + let results = ctx.sql(sql).await?.collect().await?; + let elapsed = start.elapsed(); + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + let row_count: usize = results.iter().map(|b| b.num_rows()).sum(); + println!( + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }) + } + if self.common.debug { + ctx.sql(sql).await?.explain(false, false)?.show().await?; + } + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Query {query_id} avg time: {avg:.2} ms"); + Ok(query_results) + } + /// Registers the `hits.parquet` as a table named `hits` async fn register_hits(&self, ctx: &SessionContext) -> Result<()> { let options = Default::default(); @@ -171,4 +193,8 @@ impl RunOpt { ) }) } + + fn iterations(&self) -> usize { + self.common.iterations + } } diff --git a/benchmarks/src/h2o.rs b/benchmarks/src/h2o.rs index cc463e70d74a..23dba07f426d 100644 --- a/benchmarks/src/h2o.rs +++ b/benchmarks/src/h2o.rs @@ -15,9 +15,16 @@ // specific language governing permissions and limitations // under the License. +//! H2O benchmark implementation for groupby, join and window operations +//! Reference: +//! - [H2O AI Benchmark](https://duckdb.org/2023/04/14/h2oai.html) +//! 
- [Extended window function benchmark](https://duckdb.org/2024/06/26/benchmarks-over-time.html#window-functions-benchmark) + use crate::util::{BenchmarkRun, CommonOpt}; use datafusion::{error::Result, prelude::SessionContext}; -use datafusion_common::{exec_datafusion_err, instant::Instant, DataFusionError}; +use datafusion_common::{ + exec_datafusion_err, instant::Instant, internal_err, DataFusionError, +}; use std::path::{Path, PathBuf}; use structopt::StructOpt; @@ -77,19 +84,28 @@ impl RunOpt { None => queries.min_query_id()..=queries.max_query_id(), }; - let config = self.common.config(); + let config = self.common.config()?; let rt_builder = self.common.runtime_env_builder()?; let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); - if self.queries_path.to_str().unwrap().contains("join") { + // Register tables depending on which h2o benchmark is being run + // (groupby/join/window) + if self.queries_path.to_str().unwrap().ends_with("groupby.sql") { + self.register_data(&ctx).await?; + } else if self.queries_path.to_str().unwrap().ends_with("join.sql") { let join_paths: Vec<&str> = self.join_paths.split(',').collect(); let table_name: Vec<&str> = vec!["x", "small", "medium", "large"]; for (i, path) in join_paths.iter().enumerate() { ctx.register_csv(table_name[i], path, Default::default()) .await?; } - } else if self.queries_path.to_str().unwrap().contains("groupby") { - self.register_data(&ctx).await?; + } else if self.queries_path.to_str().unwrap().ends_with("window.sql") { + // Only register the 'large' table in h2o-join dataset + let h2o_join_large_path = self.join_paths.split(',').nth(3).unwrap(); + ctx.register_csv("large", h2o_join_large_path, Default::default()) + .await?; + } else { + return internal_err!("Invalid query file path"); } let iterations = self.common.iterations; @@ -171,7 +187,7 @@ impl AllQueries { .map_err(|e| exec_datafusion_err!("Could not open {path:?}: {e}"))?; Ok(Self { - queries: all_queries.lines().map(|s| s.to_string()).collect(), + queries: all_queries.split("\n\n").map(|s| s.to_string()).collect(), }) } diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs index d7d7a56d0540..0d9bdf536d10 100644 --- a/benchmarks/src/imdb/run.rs +++ b/benchmarks/src/imdb/run.rs @@ -19,7 +19,7 @@ use std::path::PathBuf; use std::sync::Arc; use super::{get_imdb_table_schema, get_query_sql, IMDB_TABLES}; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; @@ -303,7 +303,7 @@ impl RunOpt { async fn benchmark_query(&self, query_id: usize) -> Result> { let mut config = self .common - .config() + .config()? 
.with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; let rt_builder = self.common.runtime_env_builder()?; @@ -471,15 +471,10 @@ impl RunOpt { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - #[cfg(test)] // Only run with "ci" mode when we have the data #[cfg(feature = "ci")] @@ -514,7 +509,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, @@ -550,7 +545,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs index 9cf09c57205a..8b2b02670449 100644 --- a/benchmarks/src/sort.rs +++ b/benchmarks/src/sort.rs @@ -149,7 +149,7 @@ impl RunOpt { let config = SessionConfig::new().with_target_partitions( self.common .partitions - .unwrap_or(get_available_parallelism()), + .unwrap_or_else(get_available_parallelism), ); let ctx = SessionContext::new_with_config(config); let (rows, elapsed) = diff --git a/benchmarks/src/sort_tpch.rs b/benchmarks/src/sort_tpch.rs index 176234eca541..21897f5bf2d7 100644 --- a/benchmarks/src/sort_tpch.rs +++ b/benchmarks/src/sort_tpch.rs @@ -40,7 +40,7 @@ use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; use datafusion_common::DEFAULT_PARQUET_EXTENSION; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; #[derive(Debug, StructOpt)] pub struct RunOpt { @@ -74,11 +74,6 @@ pub struct RunOpt { limit: Option, } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - impl RunOpt { const SORT_TABLES: [&'static str; 1] = ["lineitem"]; @@ -179,7 +174,7 @@ impl RunOpt { /// If query is specified from command line, run only that query. /// Otherwise, run all queries. 
pub async fn run(&self) -> Result<()> { - let mut benchmark_run = BenchmarkRun::new(); + let mut benchmark_run: BenchmarkRun = BenchmarkRun::new(); let query_range = match self.query { Some(query_id) => query_id..=query_id, @@ -189,20 +184,28 @@ impl RunOpt { for query_id in query_range { benchmark_run.start_new_case(&format!("{query_id}")); - let query_results = self.benchmark_query(query_id).await?; - for iter in query_results { - benchmark_run.write_iter(iter.elapsed, iter.row_count); + let query_results = self.benchmark_query(query_id).await; + match query_results { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } } } benchmark_run.maybe_write_json(self.output_path.as_ref())?; - + benchmark_run.maybe_print_failures(); Ok(()) } /// Benchmark query `query_id` in `SORT_QUERIES` async fn benchmark_query(&self, query_id: usize) -> Result> { - let config = self.common.config(); + let config = self.common.config()?; let rt_builder = self.common.runtime_env_builder()?; let state = SessionStateBuilder::new() .with_config(config) @@ -294,7 +297,7 @@ impl RunOpt { let mut stream = execute_stream(physical_plan.clone(), state.task_ctx())?; while let Some(batch) = stream.next().await { - row_count += batch.unwrap().num_rows(); + row_count += batch?.num_rows(); } if debug { @@ -352,6 +355,6 @@ impl RunOpt { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index 752a5a1a6ba0..c042b3418069 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use super::{ get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_TABLES, }; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; @@ -109,29 +109,41 @@ impl RunOpt { }; let mut benchmark_run = BenchmarkRun::new(); - for query_id in query_range { - benchmark_run.start_new_case(&format!("Query {query_id}")); - let query_run = self.benchmark_query(query_id).await?; - for iter in query_run { - benchmark_run.write_iter(iter.elapsed, iter.row_count); - } - } - benchmark_run.maybe_write_json(self.output_path.as_ref())?; - Ok(()) - } - - async fn benchmark_query(&self, query_id: usize) -> Result> { let mut config = self .common - .config() + .config()? 
.with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; let rt_builder = self.common.runtime_env_builder()?; let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); - // register tables self.register_tables(&ctx).await?; + for query_id in query_range { + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(query_id, &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } + } + } + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + benchmark_run.maybe_print_failures(); + Ok(()) + } + + async fn benchmark_query( + &self, + query_id: usize, + ctx: &SessionContext, + ) -> Result> { let mut millis = vec![]; // run benchmark let mut query_results = vec![]; @@ -146,14 +158,14 @@ impl RunOpt { if query_id == 15 { for (n, query) in sql.iter().enumerate() { if n == 1 { - result = self.execute_query(&ctx, query).await?; + result = self.execute_query(ctx, query).await?; } else { - self.execute_query(&ctx, query).await?; + self.execute_query(ctx, query).await?; } } } else { for query in sql { - result = self.execute_query(&ctx, query).await?; + result = self.execute_query(ctx, query).await?; } } @@ -313,15 +325,10 @@ impl RunOpt { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - #[cfg(test)] // Only run with "ci" mode when we have the data #[cfg(feature = "ci")] @@ -355,7 +362,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, @@ -392,7 +399,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, diff --git a/benchmarks/src/util/mod.rs b/benchmarks/src/util/mod.rs index 95c6e5f53d0f..420d52401c4e 100644 --- a/benchmarks/src/util/mod.rs +++ b/benchmarks/src/util/mod.rs @@ -22,4 +22,4 @@ mod run; pub use access_log::AccessLogOpt; pub use options::CommonOpt; -pub use run::{BenchQuery, BenchmarkRun}; +pub use run::{BenchQuery, BenchmarkRun, QueryResult}; diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index a1cf31525dd9..6627a287dfcd 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -19,13 +19,13 @@ use std::{num::NonZeroUsize, sync::Arc}; use datafusion::{ execution::{ - disk_manager::DiskManagerConfig, + disk_manager::DiskManagerBuilder, memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool}, runtime_env::RuntimeEnvBuilder, }, prelude::SessionConfig, }; -use datafusion_common::{utils::get_available_parallelism, DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result}; use structopt::StructOpt; // Common benchmark options (don't use doc comments otherwise this doc @@ -41,8 +41,8 @@ pub struct CommonOpt { pub partitions: Option, /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size", default_value = "8192")] - pub batch_size: usize, + 
#[structopt(short = "s", long = "batch-size")] + pub batch_size: Option, /// The memory pool type to use, should be one of "fair" or "greedy" #[structopt(long = "mem-pool-type", default_value = "fair")] @@ -65,21 +65,25 @@ pub struct CommonOpt { impl CommonOpt { /// Return an appropriately configured `SessionConfig` - pub fn config(&self) -> SessionConfig { - self.update_config(SessionConfig::new()) + pub fn config(&self) -> Result { + SessionConfig::from_env().map(|config| self.update_config(config)) } /// Modify the existing config appropriately - pub fn update_config(&self, config: SessionConfig) -> SessionConfig { - let mut config = config - .with_target_partitions( - self.partitions.unwrap_or(get_available_parallelism()), - ) - .with_batch_size(self.batch_size); + pub fn update_config(&self, mut config: SessionConfig) -> SessionConfig { + if let Some(batch_size) = self.batch_size { + config = config.with_batch_size(batch_size); + } + + if let Some(partitions) = self.partitions { + config = config.with_target_partitions(partitions); + } + if let Some(sort_spill_reservation_bytes) = self.sort_spill_reservation_bytes { config = config.with_sort_spill_reservation_bytes(sort_spill_reservation_bytes); } + config } @@ -106,7 +110,7 @@ impl CommonOpt { }; rt_builder = rt_builder .with_memory_pool(pool) - .with_disk_manager(DiskManagerConfig::NewOs); + .with_disk_manager_builder(DiskManagerBuilder::default()); } Ok(rt_builder) } @@ -118,15 +122,14 @@ fn parse_memory_limit(limit: &str) -> Result { let (number, unit) = limit.split_at(limit.len() - 1); let number: f64 = number .parse() - .map_err(|_| format!("Failed to parse number from memory limit '{}'", limit))?; + .map_err(|_| format!("Failed to parse number from memory limit '{limit}'"))?; match unit { "K" => Ok((number * 1024.0) as usize), "M" => Ok((number * 1024.0 * 1024.0) as usize), "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), _ => Err(format!( - "Unsupported unit '{}' in memory limit '{}'", - unit, limit + "Unsupported unit '{unit}' in memory limit '{limit}'" )), } } diff --git a/benchmarks/src/util/run.rs b/benchmarks/src/util/run.rs index 13969f4d3949..764ea648ff72 100644 --- a/benchmarks/src/util/run.rs +++ b/benchmarks/src/util/run.rs @@ -90,8 +90,13 @@ pub struct BenchQuery { iterations: Vec, #[serde(serialize_with = "serialize_start_time")] start_time: SystemTime, + success: bool, +} +/// Internal representation of a single benchmark query iteration result. 
+pub struct QueryResult { + pub elapsed: Duration, + pub row_count: usize, } - /// collects benchmark run data and then serializes it at the end pub struct BenchmarkRun { context: RunContext, @@ -120,6 +125,7 @@ impl BenchmarkRun { query: id.to_owned(), iterations: vec![], start_time: SystemTime::now(), + success: true, }); if let Some(c) = self.current_case.as_mut() { *c += 1; @@ -138,6 +144,28 @@ impl BenchmarkRun { } } + /// Print the names of failed queries, if any + pub fn maybe_print_failures(&self) { + let failed_queries: Vec<&str> = self + .queries + .iter() + .filter_map(|q| (!q.success).then_some(q.query.as_str())) + .collect(); + + if !failed_queries.is_empty() { + println!("Failed Queries: {}", failed_queries.join(", ")); + } + } + + /// Mark current query + pub fn mark_failed(&mut self) { + if let Some(idx) = self.current_case { + self.queries[idx].success = false; + } else { + unreachable!("Cannot mark failure: no current case"); + } + } + /// Stringify data into formatted json pub fn to_json(&self) -> String { let mut output = HashMap::<&str, Value>::new(); diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 566aafb319bf..2eec93628b52 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -37,9 +37,9 @@ backtrace = ["datafusion/backtrace"] [dependencies] arrow = { workspace = true } async-trait = { workspace = true } -aws-config = "1.6.1" +aws-config = "1.6.2" aws-credential-types = "1.2.0" -clap = { version = "4.5.35", features = ["derive", "cargo"] } +clap = { version = "4.5.39", features = ["derive", "cargo"] } datafusion = { workspace = true, features = [ "avro", "crypto_expressions", @@ -60,7 +60,7 @@ object_store = { workspace = true, features = ["aws", "gcp", "http"] } parking_lot = { workspace = true } parquet = { workspace = true, default-features = false } regex = { workspace = true } -rustyline = "15.0" +rustyline = "16.0" tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } url = { workspace = true } diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index ceb72dbc546b..3298b7deaeba 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -337,8 +337,7 @@ mod tests { #[cfg(not(target_os = "windows"))] #[test] fn test_substitute_tilde() { - use std::env; - use std::path::MAIN_SEPARATOR; + use std::{env, path::PathBuf}; let original_home = home_dir(); let test_home_path = if cfg!(windows) { "C:\\Users\\user" @@ -350,17 +349,16 @@ mod tests { test_home_path, ); let input = "~/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet"; - let expected = format!( - "{}{}Code{}datafusion{}benchmarks{}data{}tpch_sf1{}part{}part-0.parquet", - test_home_path, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR - ); + let expected = PathBuf::from(test_home_path) + .join("Code") + .join("datafusion") + .join("benchmarks") + .join("data") + .join("tpch_sf1") + .join("part") + .join("part-0.parquet") + .to_string_lossy() + .to_string(); let actual = substitute_tilde(input.to_string()); assert_eq!(actual, expected); match original_home { diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index fc7d1a2617cf..77bc8d3d2000 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -64,21 +64,28 @@ impl Command { let command_batch = all_commands_info(); let schema = command_batch.schema(); let num_rows = 
command_batch.num_rows(); - print_options.print_batches(schema, &[command_batch], now, num_rows) + let task_ctx = ctx.task_ctx(); + let config = &task_ctx.session_config().options().format; + print_options.print_batches( + schema, + &[command_batch], + now, + num_rows, + config, + ) } Self::ListTables => { exec_and_print(ctx, print_options, "SHOW TABLES".into()).await } Self::DescribeTableStmt(name) => { - exec_and_print(ctx, print_options, format!("SHOW COLUMNS FROM {}", name)) + exec_and_print(ctx, print_options, format!("SHOW COLUMNS FROM {name}")) .await } Self::Include(filename) => { if let Some(filename) = filename { let file = File::open(filename).map_err(|e| { DataFusionError::Execution(format!( - "Error opening {:?} {}", - filename, e + "Error opening {filename:?} {e}" )) })?; exec_from_lines(ctx, &mut BufReader::new(file), print_options) @@ -108,7 +115,7 @@ impl Command { Self::SearchFunctions(function) => { if let Ok(func) = function.parse::() { let details = func.function_details()?; - println!("{}", details); + println!("{details}"); Ok(()) } else { exec_err!("{function} is not a supported function") diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 0f4d70c1cca9..3c2a6e68bbe1 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -200,7 +200,7 @@ pub async fn exec_from_repl( break; } Err(err) => { - eprintln!("Unknown error happened {:?}", err); + eprintln!("Unknown error happened {err:?}"); break; } } @@ -216,7 +216,8 @@ pub(super) async fn exec_and_print( ) -> Result<()> { let now = Instant::now(); let task_ctx = ctx.task_ctx(); - let dialect = &task_ctx.session_config().options().sql_parser.dialect; + let options = task_ctx.session_config().options(); + let dialect = &options.sql_parser.dialect; let dialect = dialect_from_str(dialect).ok_or_else(|| { plan_datafusion_err!( "Unsupported SQL dialect: {dialect}. Available dialects: \ @@ -250,7 +251,9 @@ pub(super) async fn exec_and_print( // As the input stream comes, we can generate results. // However, memory safety is not guaranteed. 
let stream = execute_stream(physical_plan, task_ctx.clone())?; - print_options.print_stream(stream, now).await?; + print_options + .print_stream(stream, now, &options.format) + .await?; } else { // Bounded stream; collected results size is limited by the maxrows option let schema = physical_plan.schema(); @@ -273,9 +276,13 @@ pub(super) async fn exec_and_print( } row_count += curr_num_rows; } - adjusted - .into_inner() - .print_batches(schema, &results, now, row_count)?; + adjusted.into_inner().print_batches( + schema, + &results, + now, + row_count, + &options.format, + )?; reservation.free(); } } @@ -523,7 +530,7 @@ mod tests { ) })?; for location in locations { - let sql = format!("copy (values (1,2)) to '{}' STORED AS PARQUET;", location); + let sql = format!("copy (values (1,2)) to '{location}' STORED AS PARQUET;"); let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { //Should not fail diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 13d2d5fd3547..f07dac649df9 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -205,7 +205,7 @@ pub fn display_all_functions() -> Result<()> { let array = StringArray::from( ALL_FUNCTIONS .iter() - .map(|f| format!("{}", f)) + .map(|f| format!("{f}")) .collect::>(), ); let schema = Schema::new(vec![Field::new("Function", DataType::Utf8, false)]); @@ -322,7 +322,7 @@ pub struct ParquetMetadataFunc {} impl TableFunctionImpl for ParquetMetadataFunc { fn call(&self, exprs: &[Expr]) -> Result> { let filename = match exprs.first() { - Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single quote: parquet_metadata('x.parquet') Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") _ => { return plan_err!( diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index dad2d15f01a1..fdecb185e33e 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -17,15 +17,17 @@ use std::collections::HashMap; use std::env; +use std::num::NonZeroUsize; use std::path::Path; use std::process::ExitCode; use std::sync::{Arc, LazyLock}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionConfig; -use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool}; +use datafusion::execution::memory_pool::{ + FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, +}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; -use datafusion::execution::DiskManager; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicObjectStoreCatalog; use datafusion_cli::functions::ParquetMetadataFunc; @@ -40,7 +42,7 @@ use datafusion_cli::{ use clap::Parser; use datafusion::common::config_err; use datafusion::config::ConfigOptions; -use datafusion::execution::disk_manager::DiskManagerConfig; +use datafusion::execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use mimalloc::MiMalloc; #[global_allocator] @@ -118,6 +120,13 @@ struct Args { )] mem_pool_type: PoolType, + #[clap( + long, + help = "The number of top memory consumers to display when query fails due to memory exhaustion. 
To disable memory consumer tracking, set this value to 0", + default_value = "3" + )] + top_memory_consumers: usize, + #[clap( long, help = "The max number of rows to display for 'Table' format\n[possible values: numbers(0/10/...), inf(no limit)]", @@ -154,7 +163,7 @@ async fn main_inner() -> Result<()> { let args = Args::parse(); if !args.quiet { - println!("DataFusion CLI v{}", DATAFUSION_CLI_VERSION); + println!("DataFusion CLI v{DATAFUSION_CLI_VERSION}"); } if let Some(ref path) = args.data_path { @@ -169,22 +178,31 @@ async fn main_inner() -> Result<()> { if let Some(memory_limit) = args.memory_limit { // set memory pool type let pool: Arc = match args.mem_pool_type { - PoolType::Fair => Arc::new(FairSpillPool::new(memory_limit)), - PoolType::Greedy => Arc::new(GreedyMemoryPool::new(memory_limit)), + PoolType::Fair if args.top_memory_consumers == 0 => { + Arc::new(FairSpillPool::new(memory_limit)) + } + PoolType::Fair => Arc::new(TrackConsumersPool::new( + FairSpillPool::new(memory_limit), + NonZeroUsize::new(args.top_memory_consumers).unwrap(), + )), + PoolType::Greedy if args.top_memory_consumers == 0 => { + Arc::new(GreedyMemoryPool::new(memory_limit)) + } + PoolType::Greedy => Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(memory_limit), + NonZeroUsize::new(args.top_memory_consumers).unwrap(), + )), }; + rt_builder = rt_builder.with_memory_pool(pool) } // set disk limit if let Some(disk_limit) = args.disk_limit { - let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; - - let disk_manager = Arc::try_unwrap(disk_manager) - .expect("DiskManager should be a single instance") - .with_max_temp_directory_size(disk_limit.try_into().unwrap())?; - - let disk_config = DiskManagerConfig::new_existing(Arc::new(disk_manager)); - rt_builder = rt_builder.with_disk_manager(disk_config); + let builder = DiskManagerBuilder::default() + .with_mode(DiskManagerMode::OsTmpDirectory) + .with_max_temp_directory_size(disk_limit.try_into().unwrap()); + rt_builder = rt_builder.with_disk_manager_builder(builder); } let runtime_env = rt_builder.build_arc()?; @@ -265,6 +283,11 @@ fn get_session_config(args: &Args) -> Result { config_options.explain.format = String::from("tree"); } + // in the CLI, we want to show NULL values rather the empty strings + if env::var_os("DATAFUSION_FORMAT_NULL").is_none() { + config_options.format.null = String::from("NULL"); + } + let session_config = SessionConfig::from(config_options).with_information_schema(true); Ok(session_config) @@ -274,7 +297,7 @@ fn parse_valid_file(dir: &str) -> Result { if Path::new(dir).is_file() { Ok(dir.to_string()) } else { - Err(format!("Invalid file '{}'", dir)) + Err(format!("Invalid file '{dir}'")) } } @@ -282,14 +305,14 @@ fn parse_valid_data_dir(dir: &str) -> Result { if Path::new(dir).is_dir() { Ok(dir.to_string()) } else { - Err(format!("Invalid data directory '{}'", dir)) + Err(format!("Invalid data directory '{dir}'")) } } fn parse_batch_size(size: &str) -> Result { match size.parse::() { Ok(size) if size > 0 => Ok(size), - _ => Err(format!("Invalid batch size '{}'", size)), + _ => Err(format!("Invalid batch size '{size}'")), } } @@ -346,20 +369,20 @@ fn parse_size_string(size: &str, label: &str) -> Result { let num_str = caps.get(1).unwrap().as_str(); let num = num_str .parse::() - .map_err(|_| format!("Invalid numeric value in {} '{}'", label, size))?; + .map_err(|_| format!("Invalid numeric value in {label} '{size}'"))?; let suffix = caps.get(2).map(|m| m.as_str()).unwrap_or("b"); let unit = BYTE_SUFFIXES 
.get(suffix) - .ok_or_else(|| format!("Invalid {} '{}'", label, size))?; + .ok_or_else(|| format!("Invalid {label} '{size}'"))?; let total_bytes = usize::try_from(unit.multiplier()) .ok() .and_then(|multiplier| num.checked_mul(multiplier)) - .ok_or_else(|| format!("{} '{}' is too large", label, size))?; + .ok_or_else(|| format!("{label} '{size}' is too large"))?; Ok(total_bytes) } else { - Err(format!("Invalid {} '{}'", label, size)) + Err(format!("Invalid {label} '{size}'")) } } diff --git a/datafusion-cli/src/pool_type.rs b/datafusion-cli/src/pool_type.rs index 269790b61f5a..a2164cc3c739 100644 --- a/datafusion-cli/src/pool_type.rs +++ b/datafusion-cli/src/pool_type.rs @@ -33,7 +33,7 @@ impl FromStr for PoolType { match s { "Greedy" | "greedy" => Ok(PoolType::Greedy), "Fair" | "fair" => Ok(PoolType::Fair), - _ => Err(format!("Invalid memory pool type '{}'", s)), + _ => Err(format!("Invalid memory pool type '{s}'")), } } } diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 1fc949593512..1d6a8396aee7 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -26,7 +26,7 @@ use arrow::datatypes::SchemaRef; use arrow::json::{ArrayWriter, LineDelimitedWriter}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches_with_options; -use datafusion::common::format::DEFAULT_CLI_FORMAT_OPTIONS; +use datafusion::config::FormatOptions; use datafusion::error::Result; /// Allow records to be printed in different formats @@ -110,7 +110,10 @@ fn format_batches_with_maxrows( writer: &mut W, batches: &[RecordBatch], maxrows: MaxRows, + format_options: &FormatOptions, ) -> Result<()> { + let options: arrow::util::display::FormatOptions = format_options.try_into()?; + match maxrows { MaxRows::Limited(maxrows) => { // Filter batches to meet the maxrows condition @@ -131,22 +134,19 @@ fn format_batches_with_maxrows( } } - let formatted = pretty_format_batches_with_options( - &filtered_batches, - &DEFAULT_CLI_FORMAT_OPTIONS, - )?; + let formatted = + pretty_format_batches_with_options(&filtered_batches, &options)?; if over_limit { - let mut formatted_str = format!("{}", formatted); + let mut formatted_str = format!("{formatted}"); formatted_str = keep_only_maxrows(&formatted_str, maxrows); - writeln!(writer, "{}", formatted_str)?; + writeln!(writer, "{formatted_str}")?; } else { - writeln!(writer, "{}", formatted)?; + writeln!(writer, "{formatted}")?; } } MaxRows::Unlimited => { - let formatted = - pretty_format_batches_with_options(batches, &DEFAULT_CLI_FORMAT_OPTIONS)?; - writeln!(writer, "{}", formatted)?; + let formatted = pretty_format_batches_with_options(batches, &options)?; + writeln!(writer, "{formatted}")?; } } @@ -162,6 +162,7 @@ impl PrintFormat { batches: &[RecordBatch], maxrows: MaxRows, with_header: bool, + format_options: &FormatOptions, ) -> Result<()> { // filter out any empty batches let batches: Vec<_> = batches @@ -170,7 +171,7 @@ impl PrintFormat { .cloned() .collect(); if batches.is_empty() { - return self.print_empty(writer, schema); + return self.print_empty(writer, schema, format_options); } match self { @@ -182,7 +183,7 @@ impl PrintFormat { if maxrows == MaxRows::Limited(0) { return Ok(()); } - format_batches_with_maxrows(writer, &batches, maxrows) + format_batches_with_maxrows(writer, &batches, maxrows, format_options) } Self::Json => batches_to_json!(ArrayWriter, writer, &batches), Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, &batches), @@ -194,16 +195,18 @@ impl 
PrintFormat { &self, writer: &mut W, schema: SchemaRef, + format_options: &FormatOptions, ) -> Result<()> { match self { // Print column headers for Table format Self::Table if !schema.fields().is_empty() => { + let format_options: arrow::util::display::FormatOptions = + format_options.try_into()?; + let empty_batch = RecordBatch::new_empty(schema); - let formatted = pretty_format_batches_with_options( - &[empty_batch], - &DEFAULT_CLI_FORMAT_OPTIONS, - )?; - writeln!(writer, "{}", formatted)?; + let formatted = + pretty_format_batches_with_options(&[empty_batch], &format_options)?; + writeln!(writer, "{formatted}")?; } _ => {} } @@ -644,6 +647,7 @@ mod tests { &self.batches, self.maxrows, with_header, + &FormatOptions::default(), ) .unwrap(); String::from_utf8(buffer).unwrap() diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 9557e783e8a7..56d787b0fe08 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -29,6 +29,7 @@ use datafusion::common::DataFusionError; use datafusion::error::Result; use datafusion::physical_plan::RecordBatchStream; +use datafusion::config::FormatOptions; use futures::StreamExt; #[derive(Debug, Clone, PartialEq, Copy)] @@ -51,7 +52,7 @@ impl FromStr for MaxRows { } else { match maxrows.parse::() { Ok(nrows) => Ok(Self::Limited(nrows)), - _ => Err(format!("Invalid maxrows {}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit.", maxrows)), + _ => Err(format!("Invalid maxrows {maxrows}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit.")), } } } @@ -103,12 +104,19 @@ impl PrintOptions { batches: &[RecordBatch], query_start_time: Instant, row_count: usize, + format_options: &FormatOptions, ) -> Result<()> { let stdout = std::io::stdout(); let mut writer = stdout.lock(); - self.format - .print_batches(&mut writer, schema, batches, self.maxrows, true)?; + self.format.print_batches( + &mut writer, + schema, + batches, + self.maxrows, + true, + format_options, + )?; let formatted_exec_details = get_execution_details_formatted( row_count, @@ -132,6 +140,7 @@ impl PrintOptions { &self, mut stream: Pin>, query_start_time: Instant, + format_options: &FormatOptions, ) -> Result<()> { if self.format == PrintFormat::Table { return Err(DataFusionError::External( @@ -154,6 +163,7 @@ impl PrintOptions { &[batch], MaxRows::Unlimited, with_header, + format_options, )?; with_header = false; } diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index 9ac09955512b..fb2f08157f67 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -69,6 +69,10 @@ fn init() { // can choose the old explain format too ["--command", "EXPLAIN FORMAT indent SELECT 123"], )] +#[case::change_format_version( + "change_format_version", + ["--file", "tests/sql/types_format.sql", "-q"], +)] #[test] fn cli_quick_test<'a>( #[case] snapshot_name: &'a str, @@ -118,6 +122,42 @@ fn test_cli_format<'a>(#[case] format: &'a str) { assert_cmd_snapshot!(cmd); } +#[rstest] +#[case("no_track", ["--top-memory-consumers", "0"])] +#[case("top2", ["--top-memory-consumers", "2"])] +#[case("top3_default", [])] +#[test] +fn test_cli_top_memory_consumers<'a>( + #[case] snapshot_name: &str, + #[case] top_memory_consumers: impl IntoIterator, +) { + let mut settings = make_settings(); + + settings.set_snapshot_suffix(snapshot_name); + + settings.add_filter( + r"[^\s]+\#\d+\(can spill: (true|false)\) 
consumed .*?B", + "Consumer(can spill: bool) consumed XB", + ); + settings.add_filter( + r"Error: Failed to allocate additional .*? for .*? with .*? already allocated for this reservation - .*? remain available for the total pool", + "Error: Failed to allocate ", + ); + settings.add_filter( + r"Resources exhausted: Failed to allocate additional .*? for .*? with .*? already allocated for this reservation - .*? remain available for the total pool", + "Resources exhausted: Failed to allocate", + ); + + let _bound = settings.bind_to_scope(); + + let mut cmd = cli(); + let sql = "select * from generate_series(1,500000) as t1(v1) order by v1;"; + cmd.args(["--memory-limit", "10M", "--command", sql]); + cmd.args(top_memory_consumers); + + assert_cmd_snapshot!(cmd); +} + #[tokio::test] async fn test_cli() { if env::var("TEST_STORAGE_INTEGRATION").is_err() { @@ -157,15 +197,14 @@ async fn test_aws_options() { STORED AS CSV LOCATION 's3://data/cars.csv' OPTIONS( - 'aws.access_key_id' '{}', - 'aws.secret_access_key' '{}', - 'aws.endpoint' '{}', + 'aws.access_key_id' '{access_key_id}', + 'aws.secret_access_key' '{secret_access_key}', + 'aws.endpoint' '{endpoint_url}', 'aws.allow_http' 'true' ); SELECT * FROM CARS limit 1; -"#, - access_key_id, secret_access_key, endpoint_url +"# ); assert_cmd_snapshot!(cli().env_clear().pass_stdin(input)); diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@change_format_version.snap b/datafusion-cli/tests/snapshots/cli_quick_test@change_format_version.snap new file mode 100644 index 000000000000..74059b2a6103 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_quick_test@change_format_version.snap @@ -0,0 +1,20 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--file" + - tests/sql/types_format.sql + - "-q" +--- +success: true +exit_code: 0 +----- stdout ----- ++-----------+ +| Int64(54) | +| Int64 | ++-----------+ +| 54 | ++-----------+ + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap new file mode 100644 index 000000000000..89b646a531f8 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap @@ -0,0 +1,21 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" + - "--top-memory-consumers" + - "0" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +caused by +Resources exhausted: Failed to allocate + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap new file mode 100644 index 000000000000..ed925a6f6461 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap @@ -0,0 +1,24 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" + - "--top-memory-consumers" + - "2" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. 
Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +caused by +Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + Consumer(can spill: bool) consumed XB, + Consumer(can spill: bool) consumed XB. +Error: Failed to allocate + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap new file mode 100644 index 000000000000..f35e3b117178 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap @@ -0,0 +1,23 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +caused by +Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + Consumer(can spill: bool) consumed XB, + Consumer(can spill: bool) consumed XB, + Consumer(can spill: bool) consumed XB. +Error: Failed to allocate + +----- stderr ----- diff --git a/datafusion-cli/tests/sql/types_format.sql b/datafusion-cli/tests/sql/types_format.sql new file mode 100644 index 000000000000..637929c980a1 --- /dev/null +++ b/datafusion-cli/tests/sql/types_format.sql @@ -0,0 +1,3 @@ +set datafusion.format.types_info to true; + +select 54 \ No newline at end of file diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 2ba1673d97b9..b31708a5c1cc 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -77,7 +77,7 @@ tonic = "0.12.1" tracing = { version = "0.1" } tracing-subscriber = { version = "0.3" } url = { workspace = true } -uuid = "1.16" +uuid = "1.17" [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.29.0", features = ["fs"] } +nix = { version = "0.30.1", features = ["fs"] } diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 9cda726db719..7b1d3e94b2ef 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -25,6 +25,7 @@ use arrow::array::{ }; use arrow::datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt32Type}; use arrow::record_batch::RecordBatch; +use arrow_schema::FieldRef; use datafusion::common::{cast::as_float64_array, ScalarValue}; use datafusion::error::Result; use datafusion::logical_expr::{ @@ -92,10 +93,10 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. 
- fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ - Field::new("prod", args.return_type.clone(), true), - Field::new("n", DataType::UInt32, true), + Field::new("prod", args.return_type().clone(), true).into(), + Field::new("n", DataType::UInt32, true).into(), ]) } @@ -401,7 +402,7 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { unimplemented!("should not be invoked") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!("should not be invoked") } @@ -482,7 +483,7 @@ async fn main() -> Result<()> { ctx.register_udaf(udf.clone()); let sql_df = ctx - .sql(&format!("SELECT {}(a) FROM t GROUP BY b", udf_name)) + .sql(&format!("SELECT {udf_name}(a) FROM t GROUP BY b")) .await?; sql_df.show().await?; diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 8330e783319d..f7316ddc1bec 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -23,6 +23,7 @@ use arrow::{ array::{ArrayRef, AsArray, Float64Array}, datatypes::Float64Type, }; +use arrow_schema::FieldRef; use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::functions_aggregate::average::avg_udaf; @@ -87,8 +88,8 @@ impl WindowUDFImpl for SmoothItUdf { Ok(Box::new(MyPartitionEvaluator::new())) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, true).into()) } } @@ -190,7 +191,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// default implementation will not be called (left as `todo!()`) fn simplify(&self) -> Option { let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { - Ok(Expr::WindowFunction(WindowFunction { + Ok(Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(avg_udaf()), params: WindowFunctionParams { args: window_function.params.args, @@ -205,8 +206,8 @@ impl WindowUDFImpl for SimplifySmoothItUdf { Some(Box::new(simplify)) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, true).into()) } } diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/catalog.rs index 655438b78b9f..229867cdfc5b 100644 --- a/datafusion-examples/examples/catalog.rs +++ b/datafusion-examples/examples/catalog.rs @@ -309,7 +309,7 @@ fn prepare_example_data() -> Result { 3,baz"#; for i in 0..5 { - let mut file = File::create(path.join(format!("{}.csv", i)))?; + let mut file = File::create(path.join(format!("{i}.csv")))?; file.write_all(content.as_bytes())?; } diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs index 165d82627061..5ebb2a6b4e58 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_file_format.rs @@ -21,8 +21,7 @@ use arrow::{ array::{AsArray, RecordBatch, StringArray, UInt8Array}, datatypes::{DataType, Field, Schema, SchemaRef, UInt64Type}, }; -use datafusion::physical_expr::LexRequirement; -use datafusion::physical_expr::PhysicalExpr; +use 
datafusion::physical_expr::{LexRequirement, PhysicalExpr}; use datafusion::{ catalog::Session, common::{GetExt, Statistics}, @@ -114,9 +113,7 @@ impl FileFormat for TSVFileFormat { conf: FileScanConfig, filters: Option<&Arc>, ) -> Result> { - self.csv_file_format - .create_physical_plan(state, conf, filters) - .await + self.csv_file_format.create_physical_plan(state, conf, filters).await } async fn create_writer_physical_plan( diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index 6f61c164f41d..57a28aeca0de 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray, StringViewArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::catalog::MemTable; use datafusion::common::config::CsvOptions; use datafusion::common::parsers::CompressionTypeVariant; use datafusion::common::DataFusionError; @@ -63,6 +64,7 @@ async fn main() -> Result<()> { read_parquet(&ctx).await?; read_csv(&ctx).await?; read_memory(&ctx).await?; + read_memory_macro().await?; write_out(&ctx).await?; register_aggregate_test_data("t1", &ctx).await?; register_aggregate_test_data("t2", &ctx).await?; @@ -144,7 +146,7 @@ async fn read_csv(ctx: &SessionContext) -> Result<()> { // and using the `enable_url_table` refer to local files directly let dyn_ctx = ctx.clone().enable_url_table(); let csv_df = dyn_ctx - .sql(&format!("SELECT rating, unixtime FROM '{}'", file_path)) + .sql(&format!("SELECT rating, unixtime FROM '{file_path}'")) .await?; csv_df.show().await?; @@ -173,16 +175,40 @@ async fn read_memory(ctx: &SessionContext) -> Result<()> { Ok(()) } +/// Use the DataFrame API to: +/// 1. Read in-memory data. +async fn read_memory_macro() -> Result<()> { + // create a DataFrame using macro + let df = dataframe!( + "a" => ["a", "b", "c", "d"], + "b" => [1, 10, 10, 100] + )?; + // print the results + df.show().await?; + + // create empty DataFrame using macro + let df_empty = dataframe!()?; + df_empty.show().await?; + + Ok(()) +} + /// Use the DataFrame API to: /// 1. Write out a DataFrame to a table /// 2. Write out a DataFrame to a parquet file /// 3. Write out a DataFrame to a csv file /// 4. 
Write out a DataFrame to a json file async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionError> { - let mut df = ctx.sql("values ('a'), ('b'), ('c')").await.unwrap(); - - // Ensure the column names and types match the target table - df = df.with_column_renamed("column1", "tablecol1").unwrap(); + let array = StringViewArray::from(vec!["a", "b", "c"]); + let schema = Arc::new(Schema::new(vec![Field::new( + "tablecol1", + DataType::Utf8View, + false, + )])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)])?; + let mem_table = MemTable::try_new(schema.clone(), vec![vec![batch]])?; + ctx.register_table("initial_data", Arc::new(mem_table))?; + let df = ctx.table("initial_data").await?; ctx.sql( "create external table diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index b61a350a5a9c..92cf33f4fdf6 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -65,7 +65,7 @@ async fn main() -> Result<()> { let expr2 = Expr::BinaryExpr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, - Box::new(Expr::Literal(ScalarValue::Int32(Some(5)))), + Box::new(Expr::Literal(ScalarValue::Int32(Some(5)), None)), )); assert_eq!(expr, expr2); @@ -147,8 +147,7 @@ fn evaluate_demo() -> Result<()> { ])) as _; assert!( matches!(&result, ColumnarValue::Array(r) if r == &expected_result), - "result: {:?}", - result + "result: {result:?}" ); Ok(()) @@ -424,7 +423,7 @@ fn boundary_analysis_in_conjuctions_demo() -> Result<()> { // // But `AND` conjunctions are easier to reason with because their interval // arithmetic follows naturally from set intersection operations, let us - // now look at an example that is a tad more complicated `OR` conjunctions. + // now look at an example that is a tad more complicated `OR` disjunctions. // The expression we will look at is `age > 60 OR age <= 18`. let age_greater_than_60_less_than_18 = @@ -435,7 +434,7 @@ fn boundary_analysis_in_conjuctions_demo() -> Result<()> { // // Initial range: [14, 79] as described in our column statistics. // - // From the left-hand side and right-hand side of our `OR` conjunctions + // From the left-hand side and right-hand side of our `OR` disjunctions // we end up with two ranges, instead of just one. // // - age > 60: [61, 79] @@ -446,7 +445,8 @@ fn boundary_analysis_in_conjuctions_demo() -> Result<()> { let physical_expr = SessionContext::new() .create_physical_expr(age_greater_than_60_less_than_18, &df_schema)?; - // Since we don't handle interval arithmetic for `OR` operator this will error out. + // However, analysis only supports a single interval, so we don't yet deal + // with the multiple possibilities of the `OR` disjunctions. 
let analysis = analyze( &physical_expr, AnalysisContext::new(initial_boundaries), diff --git a/datafusion-examples/examples/flight/flight_sql_server.rs b/datafusion-examples/examples/flight/flight_sql_server.rs index 54e8de7177cb..5a573ed52320 100644 --- a/datafusion-examples/examples/flight/flight_sql_server.rs +++ b/datafusion-examples/examples/flight/flight_sql_server.rs @@ -115,6 +115,7 @@ impl FlightSqlServiceImpl { Ok(uuid) } + #[allow(clippy::result_large_err)] fn get_ctx(&self, req: &Request) -> Result, Status> { // get the token from the authorization header on Request let auth = req @@ -140,6 +141,7 @@ impl FlightSqlServiceImpl { } } + #[allow(clippy::result_large_err)] fn get_plan(&self, handle: &str) -> Result { if let Some(plan) = self.statements.get(handle) { Ok(plan.clone()) @@ -148,6 +150,7 @@ impl FlightSqlServiceImpl { } } + #[allow(clippy::result_large_err)] fn get_result(&self, handle: &str) -> Result, Status> { if let Some(result) = self.results.get(handle) { Ok(result.clone()) @@ -195,11 +198,13 @@ impl FlightSqlServiceImpl { .unwrap() } + #[allow(clippy::result_large_err)] fn remove_plan(&self, handle: &str) -> Result<(), Status> { self.statements.remove(&handle.to_string()); Ok(()) } + #[allow(clippy::result_large_err)] fn remove_result(&self, handle: &str) -> Result<(), Status> { self.results.remove(&handle.to_string()); Ok(()) diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index 06367f5c09e3..e712f4ea8eaa 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -189,8 +189,7 @@ impl ScalarFunctionWrapper { if let Some(value) = placeholder.strip_prefix('$') { Ok(value.parse().map(|v: usize| v - 1).map_err(|e| { DataFusionError::Execution(format!( - "Placeholder `{}` parsing error: {}!", - placeholder, e + "Placeholder `{placeholder}` parsing error: {e}!" )) })?) 
} else { diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index 63f17484809e..176b1a69808c 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -171,7 +171,7 @@ fn is_binary_eq(binary_expr: &BinaryExpr) -> bool { /// Return true if the expression is a literal or column reference fn is_lit_or_col(expr: &Expr) -> bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) + matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) } /// A simple user defined filter function diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/parquet_index.rs index 7d6ce4d86af1..4fcbf6c67679 100644 --- a/datafusion-examples/examples/parquet_index.rs +++ b/datafusion-examples/examples/parquet_index.rs @@ -23,6 +23,7 @@ use arrow::datatypes::{Int32Type, SchemaRef}; use arrow::util::pretty::pretty_format_batches; use async_trait::async_trait; use datafusion::catalog::Session; +use datafusion::common::pruning::PruningStatistics; use datafusion::common::{ internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, }; @@ -39,7 +40,7 @@ use datafusion::parquet::arrow::{ arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter, }; use datafusion::physical_expr::PhysicalExpr; -use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion::physical_optimizer::pruning::PruningPredicate; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; use std::any::Any; @@ -242,8 +243,7 @@ impl TableProvider for IndexTableProvider { let files = self.index.get_files(predicate.clone())?; let object_store_url = ObjectStoreUrl::parse("file://")?; - let source = - Arc::new(ParquetSource::default().with_predicate(self.schema(), predicate)); + let source = Arc::new(ParquetSource::default().with_predicate(self.schema(), predicate)); let mut file_scan_config_builder = FileScanConfigBuilder::new(object_store_url, self.schema(), source) .with_projection(projection.cloned()) diff --git a/datafusion-examples/examples/planner_api.rs b/datafusion-examples/examples/planner_api.rs index 4943e593bd0b..55aec7b0108a 100644 --- a/datafusion-examples/examples/planner_api.rs +++ b/datafusion-examples/examples/planner_api.rs @@ -96,7 +96,7 @@ async fn to_physical_plan_step_by_step_demo( ctx.state().config_options(), |_, _| (), )?; - println!("Analyzed logical plan:\n\n{:?}\n\n", analyzed_logical_plan); + println!("Analyzed logical plan:\n\n{analyzed_logical_plan:?}\n\n"); // Optimize the analyzed logical plan let optimized_logical_plan = ctx.state().optimizer().optimize( @@ -104,10 +104,7 @@ async fn to_physical_plan_step_by_step_demo( &ctx.state(), |_, _| (), )?; - println!( - "Optimized logical plan:\n\n{:?}\n\n", - optimized_logical_plan - ); + println!("Optimized logical plan:\n\n{optimized_logical_plan:?}\n\n"); // Create the physical plan let physical_plan = ctx diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/pruning.rs index 4c802bcdbda0..b2d2fa13b7ed 100644 --- a/datafusion-examples/examples/pruning.rs +++ b/datafusion-examples/examples/pruning.rs @@ -20,10 +20,11 @@ use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::common::pruning::PruningStatistics; use datafusion::common::{DFSchema, ScalarValue}; use datafusion::execution::context::ExecutionProps; use 
datafusion::physical_expr::create_physical_expr; -use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion::physical_optimizer::pruning::PruningPredicate; use datafusion::prelude::*; /// This example shows how to use DataFusion's `PruningPredicate` to prove diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index d2b2d1bf9655..b65ffb8d7174 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -133,7 +133,8 @@ struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)), _)) = exprs.first() + else { return plan_err!("read_csv requires at least one string argument"); }; @@ -145,7 +146,7 @@ impl TableFunctionImpl for LocalCsvTableFunc { let info = SimplifyContext::new(&execution_props); let expr = ExprSimplifier::new(info).simplify(expr.clone())?; - if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + if let Expr::Literal(ScalarValue::Int64(Some(limit)), _) = expr { Ok(limit as usize) } else { plan_err!("Limit must be an integer") diff --git a/datafusion-examples/examples/sql_dialect.rs b/datafusion-examples/examples/sql_dialect.rs index 12141847ca36..20b515506f3b 100644 --- a/datafusion-examples/examples/sql_dialect.rs +++ b/datafusion-examples/examples/sql_dialect.rs @@ -17,10 +17,10 @@ use std::fmt::Display; -use datafusion::error::Result; +use datafusion::error::{DataFusionError, Result}; use datafusion::sql::{ parser::{CopyToSource, CopyToStatement, DFParser, DFParserBuilder, Statement}, - sqlparser::{keywords::Keyword, parser::ParserError, tokenizer::Token}, + sqlparser::{keywords::Keyword, tokenizer::Token}, }; /// This example demonstrates how to use the DFParser to parse a statement in a custom way @@ -34,8 +34,8 @@ async fn main() -> Result<()> { let my_statement = my_parser.parse_statement()?; match my_statement { - MyStatement::DFStatement(s) => println!("df: {}", s), - MyStatement::MyCopyTo(s) => println!("my_copy: {}", s), + MyStatement::DFStatement(s) => println!("df: {s}"), + MyStatement::MyCopyTo(s) => println!("my_copy: {s}"), } Ok(()) @@ -62,7 +62,7 @@ impl<'a> MyParser<'a> { /// This is the entry point to our parser -- it handles `COPY` statements specially /// but otherwise delegates to the existing DataFusion parser. 
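Several of the example updates above reflect that `Expr::Literal` now carries a second field (optional literal metadata) alongside the `ScalarValue`. A hedged sketch of how construction and pattern matching change; the variable name is illustrative, and `None` is passed for the metadata field exactly as the updated examples do:

```rust
use datafusion::common::ScalarValue;
use datafusion::logical_expr::Expr;

fn main() {
    // Construction: pass `None` for the new metadata field.
    let lit_expr = Expr::Literal(ScalarValue::Int64(Some(42)), None);

    // Matching: add a wildcard (or a binding) for the metadata field.
    if let Expr::Literal(ScalarValue::Int64(Some(v)), _) = &lit_expr {
        println!("literal value: {v}");
    }
}
```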
- pub fn parse_statement(&mut self) -> Result { + pub fn parse_statement(&mut self) -> Result { if self.is_copy() { self.df_parser.parser.next_token(); // COPY let df_statement = self.df_parser.parse_copy()?; @@ -87,8 +87,8 @@ enum MyStatement { impl Display for MyStatement { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - MyStatement::DFStatement(s) => write!(f, "{}", s), - MyStatement::MyCopyTo(s) => write!(f, "{}", s), + MyStatement::DFStatement(s) => write!(f, "{s}"), + MyStatement::MyCopyTo(s) => write!(f, "{s}"), } } } diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 8fd48715b3e6..7ac13b99cce6 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -63,7 +63,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { Ok(TreeNodeRecursion::Stop) } } - Expr::Literal(_) + Expr::Literal(_, _) | Expr::Alias(_) | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) @@ -372,8 +372,8 @@ fn populate_partition_values<'a>( { match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(Column { ref name, .. }), Expr::Literal(val)) - | (Expr::Literal(val), Expr::Column(Column { ref name, .. })) => { + (Expr::Column(Column { ref name, .. }), Expr::Literal(val, _)) + | (Expr::Literal(val, _), Expr::Column(Column { ref name, .. })) => { if partition_values .insert(name, PartitionValue::Single(val.to_string())) .is_some() @@ -533,11 +533,7 @@ where Some((name, val)) if name == pn => part_values.push(val), _ => { debug!( - "Ignoring file: file_path='{}', table_path='{}', part='{}', partition_col='{}'", - file_path, - table_path, - part, - pn, + "Ignoring file: file_path='{file_path}', table_path='{table_path}', part='{part}', partition_col='{pn}'", ); return None; } @@ -1043,7 +1039,7 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3))))], + &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3)), None))], ), Some(Path::from("a=1970-01-04")), ); @@ -1052,9 +1048,10 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date64(Some( - 4 * 24 * 60 * 60 * 1000 - )))),], + &[col("a").eq(Expr::Literal( + ScalarValue::Date64(Some(4 * 24 * 60 * 60 * 1000)), + None + )),], ), Some(Path::from("a=1970-01-05")), ); diff --git a/datafusion/catalog/src/information_schema.rs b/datafusion/catalog/src/information_schema.rs index 7948c0299d39..057d1a819882 100644 --- a/datafusion/catalog/src/information_schema.rs +++ b/datafusion/catalog/src/information_schema.rs @@ -30,6 +30,7 @@ use arrow::{ use async_trait::async_trait; use datafusion_common::config::{ConfigEntry, ConfigOptions}; use datafusion_common::error::Result; +use datafusion_common::types::NativeType; use datafusion_common::DataFusionError; use datafusion_execution::TaskContext; use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; @@ -37,7 +38,7 @@ use datafusion_expr::{TableType, Volatility}; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::streaming::PartitionStream; use datafusion_physical_plan::SendableRecordBatchStream; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::Debug; use std::{any::Any, sync::Arc}; @@ -403,58 +404,63 @@ impl InformationSchemaConfig { /// returns a tuple of (arg_types, 
return_type) fn get_udf_args_and_return_types( udf: &Arc, -) -> Result, Option)>> { +) -> Result, Option)>> { let signature = udf.signature(); let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { - Ok(vec![(vec![], None)]) + Ok(vec![(vec![], None)].into_iter().collect::>()) } else { Ok(arg_types .into_iter() .map(|arg_types| { // only handle the function which implemented [`ScalarUDFImpl::return_type`] method - let return_type = udf.return_type(&arg_types).ok().map(|t| t.to_string()); + let return_type = udf + .return_type(&arg_types) + .map(|t| remove_native_type_prefix(NativeType::from(t))) + .ok(); let arg_types = arg_types .into_iter() - .map(|t| t.to_string()) + .map(|t| remove_native_type_prefix(NativeType::from(t))) .collect::>(); (arg_types, return_type) }) - .collect::>()) + .collect::>()) } } fn get_udaf_args_and_return_types( udaf: &Arc, -) -> Result, Option)>> { +) -> Result, Option)>> { let signature = udaf.signature(); let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { - Ok(vec![(vec![], None)]) + Ok(vec![(vec![], None)].into_iter().collect::>()) } else { Ok(arg_types .into_iter() .map(|arg_types| { // only handle the function which implemented [`ScalarUDFImpl::return_type`] method - let return_type = - udaf.return_type(&arg_types).ok().map(|t| t.to_string()); + let return_type = udaf + .return_type(&arg_types) + .ok() + .map(|t| remove_native_type_prefix(NativeType::from(t))); let arg_types = arg_types .into_iter() - .map(|t| t.to_string()) + .map(|t| remove_native_type_prefix(NativeType::from(t))) .collect::>(); (arg_types, return_type) }) - .collect::>()) + .collect::>()) } } fn get_udwf_args_and_return_types( udwf: &Arc, -) -> Result, Option)>> { +) -> Result, Option)>> { let signature = udwf.signature(); let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { - Ok(vec![(vec![], None)]) + Ok(vec![(vec![], None)].into_iter().collect::>()) } else { Ok(arg_types .into_iter() @@ -462,14 +468,19 @@ fn get_udwf_args_and_return_types( // only handle the function which implemented [`ScalarUDFImpl::return_type`] method let arg_types = arg_types .into_iter() - .map(|t| t.to_string()) + .map(|t| remove_native_type_prefix(NativeType::from(t))) .collect::>(); (arg_types, None) }) - .collect::>()) + .collect::>()) } } +#[inline] +fn remove_native_type_prefix(native_type: NativeType) -> String { + format!("{native_type:?}") +} + #[async_trait] impl SchemaProvider for InformationSchemaProvider { fn as_any(&self) -> &dyn Any { diff --git a/datafusion/common-runtime/Cargo.toml b/datafusion/common-runtime/Cargo.toml index 5e7816b669de..7ddc021e640c 100644 --- a/datafusion/common-runtime/Cargo.toml +++ b/datafusion/common-runtime/Cargo.toml @@ -43,4 +43,4 @@ log = { workspace = true } tokio = { workspace = true } [dev-dependencies] -tokio = { version = "1.44", features = ["rt", "rt-multi-thread", "time"] } +tokio = { version = "1.45", features = ["rt", "rt-multi-thread", "time"] } diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 74e99163955e..d471e48be4e7 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -58,12 +58,12 @@ base64 = "0.22.1" half = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } -libc = "0.2.171" +libc = "0.2.172" log = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } paste = 
"1.0.15" -pyo3 = { version = "0.24.0", optional = true } +pyo3 = { version = "0.24.2", optional = true } recursive = { workspace = true, optional = true } sqlparser = { workspace = true } tokio = { workspace = true } diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index 50a4e257d1c9..b3acaeee5a54 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -130,8 +130,8 @@ impl Column { /// where `"foo.BAR"` would be parsed to a reference to column named `foo.BAR` pub fn from_qualified_name(flat_name: impl Into) -> Self { let flat_name = flat_name.into(); - Self::from_idents(parse_identifiers_normalized(&flat_name, false)).unwrap_or( - Self { + Self::from_idents(parse_identifiers_normalized(&flat_name, false)).unwrap_or_else( + || Self { relation: None, name: flat_name, spans: Spans::new(), @@ -142,8 +142,8 @@ impl Column { /// Deserialize a fully qualified name string into a column preserving column text case pub fn from_qualified_name_ignore_case(flat_name: impl Into) -> Self { let flat_name = flat_name.into(); - Self::from_idents(parse_identifiers_normalized(&flat_name, true)).unwrap_or( - Self { + Self::from_idents(parse_identifiers_normalized(&flat_name, true)).unwrap_or_else( + || Self { relation: None, name: flat_name, spans: Spans::new(), diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 1c746a4e9840..883d2b60a897 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -17,17 +17,16 @@ //! Runtime configuration, via [`ConfigOptions`] +use crate::error::_config_err; +use crate::parsers::CompressionTypeVariant; +use crate::utils::get_available_parallelism; +use crate::{DataFusionError, Result}; use std::any::Any; use std::collections::{BTreeMap, HashMap}; use std::error::Error; use std::fmt::{self, Display}; use std::str::FromStr; -use crate::error::_config_err; -use crate::parsers::CompressionTypeVariant; -use crate::utils::get_available_parallelism; -use crate::{DataFusionError, Result}; - /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used /// in the [`ConfigOptions`] configuration tree. @@ -263,7 +262,7 @@ config_namespace! { /// If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. /// If false, `VARCHAR` is mapped to `Utf8` during SQL planning. /// Default is false. - pub map_varchar_to_utf8view: bool, default = false + pub map_varchar_to_utf8view: bool, default = true /// When set to true, the source locations relative to the original SQL /// query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected @@ -293,14 +292,16 @@ config_namespace! { /// target batch size is determined by the configuration setting pub coalesce_batches: bool, default = true - /// Should DataFusion collect statistics after listing files - pub collect_statistics: bool, default = false + /// Should DataFusion collect statistics when first creating a table. + /// Has no effect after the table is created. Applies to the default + /// `ListingTableProvider` in DataFusion. Defaults to true. + pub collect_statistics: bool, default = true /// Number of partitions for query execution. Increasing partitions can increase /// concurrency. 
/// /// Defaults to the number of CPU cores on the system - pub target_partitions: usize, default = get_available_parallelism() + pub target_partitions: usize, transform = ExecutionOptions::normalized_parallelism, default = get_available_parallelism() /// The default time zone /// @@ -316,7 +317,7 @@ config_namespace! { /// This is mostly use to plan `UNION` children in parallel. /// /// Defaults to the number of CPU cores on the system - pub planning_concurrency: usize, default = get_available_parallelism() + pub planning_concurrency: usize, transform = ExecutionOptions::normalized_parallelism, default = get_available_parallelism() /// When set to true, skips verifying that the schema produced by /// planning the input of `LogicalPlan::Aggregate` exactly matches the @@ -405,6 +406,13 @@ config_namespace! { /// in joins can reduce memory usage when joining large /// tables with a highly-selective join filter, but is also slightly slower. pub enforce_batch_size_in_joins: bool, default = false + + /// Size (bytes) of data buffer DataFusion uses when writing output files. + /// This affects the size of the data chunks that are uploaded to remote + /// object stores (e.g. AWS S3). If very large (>= 100 GiB) output files are being + /// written, it may be necessary to increase this size to avoid errors from + /// the remote end point. + pub objectstore_writer_buffer_size: usize, default = 10 * 1024 * 1024 } } @@ -467,6 +475,9 @@ config_namespace! { /// nanosecond resolution. pub coerce_int96: Option, transform = str::to_lowercase, default = None + /// (reading) Use any available bloom filters when reading parquet files + pub bloom_filter_on_read: bool, default = true + // The following options affect writing to parquet files // and map to parquet::file::properties::WriterProperties @@ -542,9 +553,6 @@ config_namespace! { /// default parquet writer setting pub encoding: Option, transform = str::to_lowercase, default = None - /// (writing) Use any available bloom filters when reading parquet files - pub bloom_filter_on_read: bool, default = true - /// (writing) Write bloom filters for all columns when creating parquet files pub bloom_filter_on_write: bool, default = false @@ -632,13 +640,20 @@ config_namespace! { /// long runner execution, all types of joins may encounter out-of-memory errors. pub allow_symmetric_joins_without_pruning: bool, default = true - /// When set to `true`, file groups will be repartitioned to achieve maximum parallelism. - /// Currently Parquet and CSV formats are supported. + /// When set to `true`, datasource partitions will be repartitioned to achieve maximum parallelism. + /// This applies to both in-memory partitions and FileSource's file groups (1 group is 1 partition). /// - /// If set to `true`, all files will be repartitioned evenly (i.e., a single large file + /// For FileSources, only Parquet and CSV formats are currently supported. + /// + /// If set to `true` for a FileSource, all files will be repartitioned evenly (i.e., a single large file /// might be partitioned into smaller chunks) for parallel scanning. - /// If set to `false`, different files will be read in parallel, but repartitioning won't + /// If set to `false` for a FileSource, different files will be read in parallel, but repartitioning won't /// happen within a single file. + /// + /// If set to `true` for an in-memory source, all memtable's partitions will have their batches + /// repartitioned evenly to the desired number of `target_partitions`. 
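config.rs above changes several execution defaults and adds new knobs, including `objectstore_writer_buffer_size` and a transform that normalizes `target_partitions = 0` to the number of available cores. A hedged sketch of setting them programmatically; the keys follow the `datafusion.execution.*` naming used by `ConfigOptions`, and the values are illustrative:

```rust
use datafusion::common::config::ConfigOptions;
use datafusion::common::Result;

fn main() -> Result<()> {
    let mut options = ConfigOptions::new();

    // Larger write buffer (in bytes) for very large object store uploads.
    options.set(
        "datafusion.execution.objectstore_writer_buffer_size",
        "67108864",
    )?;

    // Per the new `normalized_parallelism` transform, `0` falls back to the
    // number of available CPU cores instead of being taken literally.
    options.set("datafusion.execution.target_partitions", "0")?;
    println!("target_partitions = {}", options.execution.target_partitions);

    Ok(())
}
```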
Repartitioning can change + /// the total number of partitions and batches per partition, but does not slice the initial + /// record tables provided to the MemTable on creation. pub repartition_file_scans: bool, default = true /// Should DataFusion repartition data using the partitions keys to execute window @@ -739,6 +754,72 @@ config_namespace! { } } +impl ExecutionOptions { + /// Returns the correct parallelism based on the provided `value`. + /// If `value` is `"0"`, returns the default available parallelism, computed with + /// `get_available_parallelism`. Otherwise, returns `value`. + fn normalized_parallelism(value: &str) -> String { + if value.parse::() == Ok(0) { + get_available_parallelism().to_string() + } else { + value.to_owned() + } + } +} + +config_namespace! { + /// Options controlling the format of output when printing record batches + /// Copies [`arrow::util::display::FormatOptions`] + pub struct FormatOptions { + /// If set to `true` any formatting errors will be written to the output + /// instead of being converted into a [`std::fmt::Error`] + pub safe: bool, default = true + /// Format string for nulls + pub null: String, default = "".into() + /// Date format for date arrays + pub date_format: Option, default = Some("%Y-%m-%d".to_string()) + /// Format for DateTime arrays + pub datetime_format: Option, default = Some("%Y-%m-%dT%H:%M:%S%.f".to_string()) + /// Timestamp format for timestamp arrays + pub timestamp_format: Option, default = Some("%Y-%m-%dT%H:%M:%S%.f".to_string()) + /// Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. + pub timestamp_tz_format: Option, default = None + /// Time format for time arrays + pub time_format: Option, default = Some("%H:%M:%S%.f".to_string()) + /// Duration format. Can be either `"pretty"` or `"ISO8601"` + pub duration_format: String, transform = str::to_lowercase, default = "pretty".into() + /// Show types in visual representation batches + pub types_info: bool, default = false + } +} + +impl<'a> TryInto> for &'a FormatOptions { + type Error = DataFusionError; + fn try_into(self) -> Result> { + let duration_format = match self.duration_format.as_str() { + "pretty" => arrow::util::display::DurationFormat::Pretty, + "iso8601" => arrow::util::display::DurationFormat::ISO8601, + _ => { + return _config_err!( + "Invalid duration format: {}. 
Valid values are pretty or iso8601", + self.duration_format + ) + } + }; + + Ok(arrow::util::display::FormatOptions::new() + .with_display_error(self.safe) + .with_null(&self.null) + .with_date_format(self.date_format.as_deref()) + .with_datetime_format(self.datetime_format.as_deref()) + .with_timestamp_format(self.timestamp_format.as_deref()) + .with_timestamp_tz_format(self.timestamp_tz_format.as_deref()) + .with_time_format(self.time_format.as_deref()) + .with_duration_format(duration_format) + .with_types_info(self.types_info)) + } +} + /// A key value pair, with a corresponding description #[derive(Debug)] pub struct ConfigEntry { @@ -768,6 +849,8 @@ pub struct ConfigOptions { pub explain: ExplainOptions, /// Optional extensions registered using [`Extensions::insert`] pub extensions: Extensions, + /// Formatting options when printing batches + pub format: FormatOptions, } impl ConfigField for ConfigOptions { @@ -780,6 +863,7 @@ impl ConfigField for ConfigOptions { "optimizer" => self.optimizer.set(rem, value), "explain" => self.explain.set(rem, value), "sql_parser" => self.sql_parser.set(rem, value), + "format" => self.format.set(rem, value), _ => _config_err!("Config value \"{key}\" not found on ConfigOptions"), } } @@ -790,6 +874,7 @@ impl ConfigField for ConfigOptions { self.optimizer.visit(v, "datafusion.optimizer", ""); self.explain.visit(v, "datafusion.explain", ""); self.sql_parser.visit(v, "datafusion.sql_parser", ""); + self.format.visit(v, "datafusion.format", ""); } } @@ -851,7 +936,9 @@ impl ConfigOptions { for key in keys.0 { let env = key.to_uppercase().replace('.', "_"); if let Some(var) = std::env::var_os(env) { - ret.set(&key, var.to_string_lossy().as_ref())?; + let value = var.to_string_lossy(); + log::info!("Set {key} to {value} from the environment variable"); + ret.set(&key, value.as_ref())?; } } @@ -1138,8 +1225,7 @@ impl ConfigField for u8 { fn set(&mut self, key: &str, value: &str) -> Result<()> { if value.is_empty() { return Err(DataFusionError::Configuration(format!( - "Input string for {} key is empty", - key + "Input string for {key} key is empty" ))); } // Check if the string is a valid number @@ -1151,8 +1237,7 @@ impl ConfigField for u8 { // Check if the first character is ASCII (single byte) if bytes.len() > 1 || !value.chars().next().unwrap().is_ascii() { return Err(DataFusionError::Configuration(format!( - "Error parsing {} as u8. Non-ASCII string provided", - value + "Error parsing {value} as u8. Non-ASCII string provided" ))); } *self = bytes[0]; @@ -1982,11 +2067,11 @@ config_namespace! 
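The new `datafusion.format` namespace above mirrors arrow's display options, and the `TryInto` impl converts the DataFusion-level struct into `arrow::util::display::FormatOptions`. A hedged sketch of setting and converting these options; the keys and values are taken from the defaults shown above:

```rust
use datafusion::arrow::util::display::FormatOptions as ArrowFormatOptions;
use datafusion::common::config::ConfigOptions;
use datafusion::common::Result;

fn main() -> Result<()> {
    let mut options = ConfigOptions::new();
    options.set("datafusion.format.null", "NULL")?;
    options.set("datafusion.format.duration_format", "iso8601")?;

    // Convert into arrow's display options via the TryInto impl added above.
    let arrow_options: ArrowFormatOptions<'_> = (&options.format).try_into()?;
    let _ = arrow_options;
    Ok(())
}
```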
{ } } -pub trait FormatOptionsExt: Display {} +pub trait OutputFormatExt: Display {} #[derive(Debug, Clone, PartialEq)] #[allow(clippy::large_enum_variant)] -pub enum FormatOptions { +pub enum OutputFormat { CSV(CsvOptions), JSON(JsonOptions), #[cfg(feature = "parquet")] @@ -1995,17 +2080,17 @@ pub enum FormatOptions { ARROW, } -impl Display for FormatOptions { +impl Display for OutputFormat { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let out = match self { - FormatOptions::CSV(_) => "csv", - FormatOptions::JSON(_) => "json", + OutputFormat::CSV(_) => "csv", + OutputFormat::JSON(_) => "json", #[cfg(feature = "parquet")] - FormatOptions::PARQUET(_) => "parquet", - FormatOptions::AVRO => "avro", - FormatOptions::ARROW => "arrow", + OutputFormat::PARQUET(_) => "parquet", + OutputFormat::AVRO => "avro", + OutputFormat::ARROW => "arrow", }; - write!(f, "{}", out) + write!(f, "{out}") } } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 66a26a18c0dc..804e14bf72fb 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -472,7 +472,7 @@ impl DFSchema { let matches = self.qualified_fields_with_unqualified_name(name); match matches.len() { 0 => Err(unqualified_field_not_found(name, self)), - 1 => Ok((matches[0].0, (matches[0].1))), + 1 => Ok((matches[0].0, matches[0].1)), _ => { // When `matches` size > 1, it doesn't necessarily mean an `ambiguous name` problem. // Because name may generate from Alias/... . It means that it don't own qualifier. @@ -515,14 +515,6 @@ impl DFSchema { Ok(self.field(idx)) } - /// Find the field with the given qualified column - pub fn field_from_column(&self, column: &Column) -> Result<&Field> { - match &column.relation { - Some(r) => self.field_with_qualified_name(r, &column.name), - None => self.field_with_unqualified_name(&column.name), - } - } - /// Find the field with the given qualified column pub fn qualified_field_from_column( &self, @@ -969,16 +961,28 @@ impl Display for DFSchema { /// widely used in the DataFusion codebase. pub trait ExprSchema: std::fmt::Debug { /// Is this column reference nullable? - fn nullable(&self, col: &Column) -> Result; + fn nullable(&self, col: &Column) -> Result { + Ok(self.field_from_column(col)?.is_nullable()) + } /// What is the datatype of this column? - fn data_type(&self, col: &Column) -> Result<&DataType>; + fn data_type(&self, col: &Column) -> Result<&DataType> { + Ok(self.field_from_column(col)?.data_type()) + } /// Returns the column's optional metadata. 
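With the `ExprSchema` refactor here, `nullable`, `data_type`, `metadata`, and `data_type_and_nullable` become provided methods, so a custom schema source only needs to supply `field_from_column`. A minimal, hedged sketch; `SingleFieldSchema` is a made-up type for illustration:

```rust
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{Column, DataFusionError, ExprSchema, Result};

#[derive(Debug)]
struct SingleFieldSchema {
    field: Field,
}

impl ExprSchema for SingleFieldSchema {
    // The only required method after the refactor; everything else has a default.
    fn field_from_column(&self, col: &Column) -> Result<&Field> {
        if col.name() == self.field.name().as_str() {
            Ok(&self.field)
        } else {
            Err(DataFusionError::Plan(format!("column {col} not found")))
        }
    }
}

fn main() -> Result<()> {
    let schema = SingleFieldSchema {
        field: Field::new("a", DataType::Int64, true),
    };
    let col = Column::new_unqualified("a");
    // `nullable` and `data_type` come from the trait's new default implementations.
    println!("nullable: {}", schema.nullable(&col)?);
    println!("data type: {}", schema.data_type(&col)?);
    Ok(())
}
```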
- fn metadata(&self, col: &Column) -> Result<&HashMap>; + fn metadata(&self, col: &Column) -> Result<&HashMap> { + Ok(self.field_from_column(col)?.metadata()) + } /// Return the column's datatype and nullability - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)>; + fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + let field = self.field_from_column(col)?; + Ok((field.data_type(), field.is_nullable())) + } + + // Return the column's field + fn field_from_column(&self, col: &Column) -> Result<&Field>; } // Implement `ExprSchema` for `Arc` @@ -998,24 +1002,18 @@ impl + std::fmt::Debug> ExprSchema for P { fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { self.as_ref().data_type_and_nullable(col) } -} - -impl ExprSchema for DFSchema { - fn nullable(&self, col: &Column) -> Result { - Ok(self.field_from_column(col)?.is_nullable()) - } - - fn data_type(&self, col: &Column) -> Result<&DataType> { - Ok(self.field_from_column(col)?.data_type()) - } - fn metadata(&self, col: &Column) -> Result<&HashMap> { - Ok(self.field_from_column(col)?.metadata()) + fn field_from_column(&self, col: &Column) -> Result<&Field> { + self.as_ref().field_from_column(col) } +} - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - let field = self.field_from_column(col)?; - Ok((field.data_type(), field.is_nullable())) +impl ExprSchema for DFSchema { + fn field_from_column(&self, col: &Column) -> Result<&Field> { + match &col.relation { + Some(r) => self.field_with_qualified_name(r, &col.name), + None => self.field_with_unqualified_name(&col.name), + } } } @@ -1090,7 +1088,7 @@ impl SchemaExt for Schema { pub fn qualified_name(qualifier: Option<&TableReference>, name: &str) -> String { match qualifier { - Some(q) => format!("{}.{}", q, name), + Some(q) => format!("{q}.{name}"), None => name.to_string(), } } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index c50ec64759d5..b4a537fdce7e 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -59,7 +59,7 @@ pub enum DataFusionError { ParquetError(ParquetError), /// Error when reading Avro data. #[cfg(feature = "avro")] - AvroError(AvroError), + AvroError(Box), /// Error when reading / writing to / from an object_store (e.g. 
S3 or LocalFile) #[cfg(feature = "object_store")] ObjectStore(object_store::Error), @@ -311,7 +311,7 @@ impl From for DataFusionError { #[cfg(feature = "avro")] impl From for DataFusionError { fn from(e: AvroError) -> Self { - DataFusionError::AvroError(e) + DataFusionError::AvroError(Box::new(e)) } } @@ -397,7 +397,7 @@ impl Error for DataFusionError { impl From for io::Error { fn from(e: DataFusionError) -> Self { - io::Error::new(io::ErrorKind::Other, e) + io::Error::other(e) } } @@ -526,7 +526,7 @@ impl DataFusionError { pub fn message(&self) -> Cow { match *self { DataFusionError::ArrowError(ref desc, ref backtrace) => { - let backtrace = backtrace.clone().unwrap_or("".to_owned()); + let backtrace = backtrace.clone().unwrap_or_else(|| "".to_owned()); Cow::Owned(format!("{desc}{backtrace}")) } #[cfg(feature = "parquet")] @@ -535,7 +535,8 @@ impl DataFusionError { DataFusionError::AvroError(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::IoError(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::SQL(ref desc, ref backtrace) => { - let backtrace: String = backtrace.clone().unwrap_or("".to_owned()); + let backtrace: String = + backtrace.clone().unwrap_or_else(|| "".to_owned()); Cow::Owned(format!("{desc:?}{backtrace}")) } DataFusionError::Configuration(ref desc) => Cow::Owned(desc.to_string()), @@ -547,7 +548,7 @@ impl DataFusionError { DataFusionError::Plan(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::SchemaError(ref desc, ref backtrace) => { let backtrace: &str = - &backtrace.as_ref().clone().unwrap_or("".to_owned()); + &backtrace.as_ref().clone().unwrap_or_else(|| "".to_owned()); Cow::Owned(format!("{desc}{backtrace}")) } DataFusionError::Execution(ref desc) => Cow::Owned(desc.to_string()), @@ -759,23 +760,33 @@ macro_rules! make_error { /// Macro wraps `$ERR` to add backtrace feature #[macro_export] macro_rules! $NAME_DF_ERR { - ($d($d args:expr),*) => { - $crate::DataFusionError::$ERR( + ($d($d args:expr),* $d(; diagnostic=$d DIAG:expr)?) => {{ + let err =$crate::DataFusionError::$ERR( ::std::format!( "{}{}", ::std::format!($d($d args),*), $crate::DataFusionError::get_back_trace(), ).into() - ) + ); + $d ( + let err = err.with_diagnostic($d DIAG); + )? + err } } + } /// Macro wraps Err(`$ERR`) to add backtrace feature #[macro_export] macro_rules! $NAME_ERR { - ($d($d args:expr),*) => { - Err($crate::[<_ $NAME_DF_ERR>]!($d($d args),*)) - } + ($d($d args:expr),* $d(; diagnostic = $d DIAG:expr)?) => {{ + let err = $crate::[<_ $NAME_DF_ERR>]!($d($d args),*); + $d ( + let err = err.with_diagnostic($d DIAG); + )? + Err(err) + + }} } @@ -816,54 +827,80 @@ make_error!(resources_err, resources_datafusion_err, ResourcesExhausted); // Exposes a macro to create `DataFusionError::SQL` with optional backtrace #[macro_export] macro_rules! sql_datafusion_err { - ($ERR:expr) => { - DataFusionError::SQL($ERR, Some(DataFusionError::get_back_trace())) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = DataFusionError::SQL($ERR, Some(DataFusionError::get_back_trace())); + $( + let err = err.with_diagnostic($DIAG); + )? + err + }}; } // Exposes a macro to create `Err(DataFusionError::SQL)` with optional backtrace #[macro_export] macro_rules! sql_err { - ($ERR:expr) => { - Err(datafusion_common::sql_datafusion_err!($ERR)) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = datafusion_common::sql_datafusion_err!($ERR); + $( + let err = err.with_diagnostic($DIAG); + )? 
+ Err(err) + }}; } // Exposes a macro to create `DataFusionError::ArrowError` with optional backtrace #[macro_export] macro_rules! arrow_datafusion_err { - ($ERR:expr) => { - DataFusionError::ArrowError($ERR, Some(DataFusionError::get_back_trace())) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = DataFusionError::ArrowError($ERR, Some(DataFusionError::get_back_trace())); + $( + let err = err.with_diagnostic($DIAG); + )? + err + }}; } // Exposes a macro to create `Err(DataFusionError::ArrowError)` with optional backtrace #[macro_export] macro_rules! arrow_err { - ($ERR:expr) => { - Err(datafusion_common::arrow_datafusion_err!($ERR)) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => { + { + let err = datafusion_common::arrow_datafusion_err!($ERR); + $( + let err = err.with_diagnostic($DIAG); + )? + Err(err) + }}; } // Exposes a macro to create `DataFusionError::SchemaError` with optional backtrace #[macro_export] macro_rules! schema_datafusion_err { - ($ERR:expr) => { - $crate::error::DataFusionError::SchemaError( + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = $crate::error::DataFusionError::SchemaError( $ERR, Box::new(Some($crate::error::DataFusionError::get_back_trace())), - ) - }; + ); + $( + let err = err.with_diagnostic($DIAG); + )? + err + }}; } // Exposes a macro to create `Err(DataFusionError::SchemaError)` with optional backtrace #[macro_export] macro_rules! schema_err { - ($ERR:expr) => { - Err($crate::error::DataFusionError::SchemaError( + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = $crate::error::DataFusionError::SchemaError( $ERR, Box::new(Some($crate::error::DataFusionError::get_back_trace())), - )) + ); + $( + let err = err.with_diagnostic($DIAG); + )? + Err(err) + } }; } @@ -908,7 +945,7 @@ pub fn add_possible_columns_to_diag( .collect(); for name in field_names { - diagnostic.add_note(format!("possible column {}", name), None); + diagnostic.add_note(format!("possible column {name}"), None); } } @@ -1083,8 +1120,7 @@ mod test { ); // assert wrapping other Error - let generic_error: GenericError = - Box::new(std::io::Error::new(std::io::ErrorKind::Other, "io error")); + let generic_error: GenericError = Box::new(std::io::Error::other("io error")); let datafusion_error: DataFusionError = generic_error.into(); println!("{}", datafusion_error.strip_backtrace()); assert_eq!( @@ -1095,13 +1131,12 @@ mod test { #[test] fn external_error_no_recursive() { - let generic_error_1: GenericError = - Box::new(std::io::Error::new(std::io::ErrorKind::Other, "io error")); + let generic_error_1: GenericError = Box::new(std::io::Error::other("io error")); let external_error_1: DataFusionError = generic_error_1.into(); let generic_error_2: GenericError = Box::new(external_error_1); let external_error_2: DataFusionError = generic_error_2.into(); - println!("{}", external_error_2); + println!("{external_error_2}"); assert!(external_error_2 .to_string() .starts_with("External error: io error")); diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 3e33466edf50..07e763f0ee6f 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -330,8 +330,7 @@ fn split_compression_string(str_setting: &str) -> Result<(String, Option)> let level = &rh[..rh.len() - 1].parse::().map_err(|_| { DataFusionError::Configuration(format!( "Could not parse compression string. 
\ - Got codec: {} and unknown level from {}", - codec, str_setting + Got codec: {codec} and unknown level from {str_setting}" )) })?; Ok((codec.to_owned(), Some(*level))) diff --git a/datafusion/common/src/format.rs b/datafusion/common/src/format.rs index 23cfb72314a3..a4ebd1753999 100644 --- a/datafusion/common/src/format.rs +++ b/datafusion/common/src/format.rs @@ -19,6 +19,7 @@ use arrow::compute::CastOptions; use arrow::util::display::{DurationFormat, FormatOptions}; /// The default [`FormatOptions`] to use within DataFusion +/// Also see [`crate::config::FormatOptions`] pub const DEFAULT_FORMAT_OPTIONS: FormatOptions<'static> = FormatOptions::new().with_duration_format(DurationFormat::Pretty); @@ -27,7 +28,3 @@ pub const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { safe: false, format_options: DEFAULT_FORMAT_OPTIONS, }; - -pub const DEFAULT_CLI_FORMAT_OPTIONS: FormatOptions<'static> = FormatOptions::new() - .with_duration_format(DurationFormat::Pretty) - .with_null("NULL"); diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index c4f2805f8285..77e00d6dcda2 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -113,7 +113,7 @@ impl Display for Constraints { let pk = self .inner .iter() - .map(|c| format!("{:?}", c)) + .map(|c| format!("{c:?}")) .collect::>(); let pk = pk.join(", "); write!(f, "constraints=[{pk}]") diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index b137624532b9..7b2c86d3975f 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -47,6 +47,7 @@ pub mod format; pub mod hash_utils; pub mod instant; pub mod parsers; +pub mod pruning; pub mod rounding; pub mod scalar; pub mod spans; @@ -185,9 +186,7 @@ mod tests { let expected_prefix = expected_prefix.as_ref(); assert!( actual.starts_with(expected_prefix), - "Expected '{}' to start with '{}'", - actual, - expected_prefix + "Expected '{actual}' to start with '{expected_prefix}'" ); } } diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index c73c8a55f18c..41571ebb8576 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -64,7 +64,7 @@ impl Display for CompressionTypeVariant { Self::ZSTD => "ZSTD", Self::UNCOMPRESSED => "", }; - write!(f, "{}", str) + write!(f, "{str}") } } diff --git a/datafusion/common/src/pruning.rs b/datafusion/common/src/pruning.rs new file mode 100644 index 000000000000..48750e3c995c --- /dev/null +++ b/datafusion/common/src/pruning.rs @@ -0,0 +1,1122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
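Stepping back to the error.rs changes above: the error macros now accept an optional trailing `; diagnostic = …` argument that attaches a `Diagnostic` to the error. A hedged sketch of calling one of them this way; it assumes `Diagnostic::new_error(message, span)` from datafusion-common, and the function and message are illustrative:

```rust
use datafusion::common::{plan_err, Diagnostic, Result};

fn check_column(name: &str) -> Result<()> {
    if name.is_empty() {
        let diagnostic = Diagnostic::new_error("column name must not be empty", None);
        // The optional trailing argument attaches the diagnostic to the planning error.
        return plan_err!("invalid column name {name:?}"; diagnostic = diagnostic);
    }
    Ok(())
}

fn main() {
    let err = check_column("").unwrap_err();
    println!("{err}");
}
```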
+ +use arrow::array::{Array, NullArray, UInt64Array}; +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::datatypes::{FieldRef, Schema, SchemaRef}; +use std::collections::HashSet; +use std::sync::Arc; + +use crate::error::DataFusionError; +use crate::stats::Precision; +use crate::{Column, Statistics}; +use crate::{ColumnStatistics, ScalarValue}; + +/// A source of runtime statistical information to [`PruningPredicate`]s. +/// +/// # Supported Information +/// +/// 1. Minimum and maximum values for columns +/// +/// 2. Null counts and row counts for columns +/// +/// 3. Whether the values in a column are contained in a set of literals +/// +/// # Vectorized Interface +/// +/// Information for containers / files are returned as Arrow [`ArrayRef`], so +/// the evaluation happens once on a single `RecordBatch`, which amortizes the +/// overhead of evaluating the predicate. This is important when pruning 1000s +/// of containers which often happens in analytic systems that have 1000s of +/// potential files to consider. +/// +/// For example, for the following three files with a single column `a`: +/// ```text +/// file1: column a: min=5, max=10 +/// file2: column a: No stats +/// file2: column a: min=20, max=30 +/// ``` +/// +/// PruningStatistics would return: +/// +/// ```text +/// min_values("a") -> Some([5, Null, 20]) +/// max_values("a") -> Some([10, Null, 30]) +/// min_values("X") -> None +/// ``` +/// +/// [`PruningPredicate`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html +pub trait PruningStatistics { + /// Return the minimum values for the named column, if known. + /// + /// If the minimum value for a particular container is not known, the + /// returned array should have `null` in that row. If the minimum value is + /// not known for any row, return `None`. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn min_values(&self, column: &Column) -> Option; + + /// Return the maximum values for the named column, if known. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn max_values(&self, column: &Column) -> Option; + + /// Return the number of containers (e.g. Row Groups) being pruned with + /// these statistics. + /// + /// This value corresponds to the size of the [`ArrayRef`] returned by + /// [`Self::min_values`], [`Self::max_values`], [`Self::null_counts`], + /// and [`Self::row_counts`]. + fn num_containers(&self) -> usize; + + /// Return the number of null values for the named column as an + /// [`UInt64Array`] + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + /// + /// [`UInt64Array`]: arrow::array::UInt64Array + fn null_counts(&self, column: &Column) -> Option; + + /// Return the number of rows for the named column in each container + /// as an [`UInt64Array`]. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + /// + /// [`UInt64Array`]: arrow::array::UInt64Array + fn row_counts(&self, column: &Column) -> Option; + + /// Returns [`BooleanArray`] where each row represents information known + /// about specific literal `values` in a column. 
+ /// + /// For example, Parquet Bloom Filters implement this API to communicate + /// that `values` are known not to be present in a Row Group. + /// + /// The returned array has one row for each container, with the following + /// meanings: + /// * `true` if the values in `column` ONLY contain values from `values` + /// * `false` if the values in `column` are NOT ANY of `values` + /// * `null` if the neither of the above holds or is unknown. + /// + /// If these statistics can not determine column membership for any + /// container, return `None` (the default). + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option; +} + +/// Prune files based on their partition values. +/// +/// This is used both at planning time and execution time to prune +/// files based on their partition values. +/// This feeds into [`CompositePruningStatistics`] to allow pruning +/// with filters that depend both on partition columns and data columns +/// (e.g. `WHERE partition_col = data_col`). +#[derive(Clone)] +pub struct PartitionPruningStatistics { + /// Values for each column for each container. + /// + /// The outer vectors represent the columns while the inner vectors + /// represent the containers. The order must match the order of the + /// partition columns in [`PartitionPruningStatistics::partition_schema`]. + partition_values: Vec, + /// The number of containers. + /// + /// Stored since the partition values are column-major and if + /// there are no columns we wouldn't know the number of containers. + num_containers: usize, + /// The schema of the partition columns. + /// + /// This must **not** be the schema of the entire file or table: it must + /// only be the schema of the partition columns, in the same order as the + /// values in [`PartitionPruningStatistics::partition_values`]. + partition_schema: SchemaRef, +} + +impl PartitionPruningStatistics { + /// Create a new instance of [`PartitionPruningStatistics`]. + /// + /// Args: + /// * `partition_values`: A vector of vectors of [`ScalarValue`]s. + /// The outer vector represents the containers while the inner + /// vector represents the partition values for each column. + /// Note that this is the **opposite** of the order of the + /// partition columns in `PartitionPruningStatistics::partition_schema`. + /// * `partition_schema`: The schema of the partition columns. + /// This must **not** be the schema of the entire file or table: + /// instead it must only be the schema of the partition columns, + /// in the same order as the values in `partition_values`. 
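Since `PruningStatistics` is now available from datafusion-common, custom catalogs and table providers can expose their own statistics to pruning. A minimal, hedged sketch of an implementation with two containers and known i32 min/max values for a single column named "a"; everything here is illustrative:

```rust
use std::collections::HashSet;
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, BooleanArray, Int32Array};
use datafusion::common::pruning::PruningStatistics;
use datafusion::common::{Column, ScalarValue};

struct MinMaxStats {
    mins: Vec<i32>,
    maxes: Vec<i32>,
}

impl PruningStatistics for MinMaxStats {
    fn min_values(&self, column: &Column) -> Option<ArrayRef> {
        (column.name() == "a")
            .then(|| Arc::new(Int32Array::from(self.mins.clone())) as ArrayRef)
    }
    fn max_values(&self, column: &Column) -> Option<ArrayRef> {
        (column.name() == "a")
            .then(|| Arc::new(Int32Array::from(self.maxes.clone())) as ArrayRef)
    }
    fn num_containers(&self) -> usize {
        self.mins.len()
    }
    fn null_counts(&self, _column: &Column) -> Option<ArrayRef> {
        None // unknown
    }
    fn row_counts(&self, _column: &Column) -> Option<ArrayRef> {
        None // unknown
    }
    fn contained(
        &self,
        _column: &Column,
        _values: &HashSet<ScalarValue>,
    ) -> Option<BooleanArray> {
        None // no set-membership information (e.g. no bloom filters)
    }
}

fn main() {
    let stats = MinMaxStats { mins: vec![5, 20], maxes: vec![10, 30] };
    let col = Column::new_unqualified("a");
    assert_eq!(stats.num_containers(), 2);
    assert_eq!(stats.min_values(&col).unwrap().len(), 2);
}
```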
+ pub fn try_new( + partition_values: Vec>, + partition_fields: Vec, + ) -> Result { + let num_containers = partition_values.len(); + let partition_schema = Arc::new(Schema::new(partition_fields)); + let mut partition_values_by_column = + vec![ + Vec::with_capacity(partition_values.len()); + partition_schema.fields().len() + ]; + for partition_value in partition_values { + for (i, value) in partition_value.into_iter().enumerate() { + partition_values_by_column[i].push(value); + } + } + Ok(Self { + partition_values: partition_values_by_column + .into_iter() + .map(|v| { + if v.is_empty() { + Ok(Arc::new(NullArray::new(0)) as ArrayRef) + } else { + ScalarValue::iter_to_array(v) + } + }) + .collect::, _>>()?, + num_containers, + partition_schema, + }) + } +} + +impl PruningStatistics for PartitionPruningStatistics { + fn min_values(&self, column: &Column) -> Option { + let index = self.partition_schema.index_of(column.name()).ok()?; + self.partition_values.get(index).and_then(|v| { + if v.is_empty() || v.null_count() == v.len() { + // If the array is empty or all nulls, return None + None + } else { + // Otherwise, return the array as is + Some(Arc::clone(v)) + } + }) + } + + fn max_values(&self, column: &Column) -> Option { + self.min_values(column) + } + + fn num_containers(&self) -> usize { + self.num_containers + } + + fn null_counts(&self, _column: &Column) -> Option { + None + } + + fn row_counts(&self, _column: &Column) -> Option { + None + } + + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + let index = self.partition_schema.index_of(column.name()).ok()?; + let array = self.partition_values.get(index)?; + let boolean_array = values.iter().try_fold(None, |acc, v| { + let arrow_value = v.to_scalar().ok()?; + let eq_result = arrow::compute::kernels::cmp::eq(array, &arrow_value).ok()?; + match acc { + None => Some(Some(eq_result)), + Some(acc_array) => { + arrow::compute::kernels::boolean::and(&acc_array, &eq_result) + .map(Some) + .ok() + } + } + })??; + // If the boolean array is empty or all null values, return None + if boolean_array.is_empty() || boolean_array.null_count() == boolean_array.len() { + None + } else { + Some(boolean_array) + } + } +} + +/// Prune a set of containers represented by their statistics. +/// +/// Each [`Statistics`] represents a "container" -- some collection of data +/// that has statistics of its columns. +/// +/// It is up to the caller to decide what each container represents. For +/// example, they can come from a file (e.g. [`PartitionedFile`]) or a set of of +/// files (e.g. [`FileGroup`]) +/// +/// [`PartitionedFile`]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.PartitionedFile.html +/// [`FileGroup`]: https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.FileGroup.html +#[derive(Clone)] +pub struct PrunableStatistics { + /// Statistics for each container. + /// These are taken as a reference since they may be rather large / expensive to clone + /// and we often won't return all of them as ArrayRefs (we only return the columns the predicate requests). + statistics: Vec>, + /// The schema of the file these statistics are for. + schema: SchemaRef, +} + +impl PrunableStatistics { + /// Create a new instance of [`PrunableStatistics`]. + /// Each [`Statistics`] represents a container (e.g. a file or a partition of files). + /// The `schema` is the schema of the data in the containers and should apply to all files. 
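A hedged sketch of building `PrunableStatistics` from per-file `Statistics`, mirroring the shapes used in the unit tests further below; the column name and values are illustrative only:

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::pruning::{PrunableStatistics, PruningStatistics};
use datafusion::common::stats::Precision;
use datafusion::common::{Column, ColumnStatistics, ScalarValue, Statistics};

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // One Statistics value per container (e.g. per file).
    let file_stats = vec![Arc::new(
        Statistics::default()
            .add_column_statistics(
                ColumnStatistics::new_unknown()
                    .with_min_value(Precision::Exact(ScalarValue::from(0i32)))
                    .with_max_value(Precision::Exact(ScalarValue::from(100i32))),
            )
            .with_num_rows(Precision::Exact(1000)),
    )];
    let pruning = PrunableStatistics::new(file_stats, schema);

    let col = Column::new_unqualified("a");
    // min/max come back as one-row Arrow arrays (one row per container).
    assert_eq!(pruning.num_containers(), 1);
    assert!(pruning.min_values(&col).is_some());
    assert!(pruning.max_values(&col).is_some());
}
```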
+ pub fn new(statistics: Vec>, schema: SchemaRef) -> Self { + Self { statistics, schema } + } + + fn get_exact_column_statistics( + &self, + column: &Column, + get_stat: impl Fn(&ColumnStatistics) -> &Precision, + ) -> Option { + let index = self.schema.index_of(column.name()).ok()?; + let mut has_value = false; + match ScalarValue::iter_to_array(self.statistics.iter().map(|s| { + s.column_statistics + .get(index) + .and_then(|stat| { + if let Precision::Exact(min) = get_stat(stat) { + has_value = true; + Some(min.clone()) + } else { + None + } + }) + .unwrap_or(ScalarValue::Null) + })) { + // If there is any non-null value and no errors, return the array + Ok(array) => has_value.then_some(array), + Err(_) => { + log::warn!( + "Failed to convert min values to array for column {}", + column.name() + ); + None + } + } + } +} + +impl PruningStatistics for PrunableStatistics { + fn min_values(&self, column: &Column) -> Option { + self.get_exact_column_statistics(column, |stat| &stat.min_value) + } + + fn max_values(&self, column: &Column) -> Option { + self.get_exact_column_statistics(column, |stat| &stat.max_value) + } + + fn num_containers(&self) -> usize { + self.statistics.len() + } + + fn null_counts(&self, column: &Column) -> Option { + let index = self.schema.index_of(column.name()).ok()?; + if self.statistics.iter().any(|s| { + s.column_statistics + .get(index) + .is_some_and(|stat| stat.null_count.is_exact().unwrap_or(false)) + }) { + Some(Arc::new( + self.statistics + .iter() + .map(|s| { + s.column_statistics.get(index).and_then(|stat| { + if let Precision::Exact(null_count) = &stat.null_count { + u64::try_from(*null_count).ok() + } else { + None + } + }) + }) + .collect::(), + )) + } else { + None + } + } + + fn row_counts(&self, column: &Column) -> Option { + // If the column does not exist in the schema, return None + if self.schema.index_of(column.name()).is_err() { + return None; + } + if self + .statistics + .iter() + .any(|s| s.num_rows.is_exact().unwrap_or(false)) + { + Some(Arc::new( + self.statistics + .iter() + .map(|s| { + if let Precision::Exact(row_count) = &s.num_rows { + u64::try_from(*row_count).ok() + } else { + None + } + }) + .collect::(), + )) + } else { + None + } + } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } +} + +/// Combine multiple [`PruningStatistics`] into a single +/// [`CompositePruningStatistics`]. +/// This can be used to combine statistics from different sources, +/// for example partition values and file statistics. +/// This allows pruning with filters that depend on multiple sources of statistics, +/// such as `WHERE partition_col = data_col`. +/// This is done by iterating over the statistics and returning the first +/// one that has information for the requested column. +/// If multiple statistics have information for the same column, +/// the first one is returned without any regard for completeness or accuracy. +/// That is: if the first statistics has information for a column, even if it is incomplete, +/// that is returned even if a later statistics has more complete information. +pub struct CompositePruningStatistics { + pub statistics: Vec>, +} + +impl CompositePruningStatistics { + /// Create a new instance of [`CompositePruningStatistics`] from + /// a vector of [`PruningStatistics`]. 
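A hedged sketch of combining several sources: per the documentation above, sources are consulted in order and the first one that knows about a column wins. The shapes mirror those used in the tests below and are otherwise illustrative:

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::pruning::{
    CompositePruningStatistics, PartitionPruningStatistics, PrunableStatistics,
    PruningStatistics,
};
use datafusion::common::{Column, ScalarValue, Statistics};

fn main() -> datafusion::common::Result<()> {
    // One container (file) with partition column `part` = 1 and a data column `x`.
    let partition_stats = PartitionPruningStatistics::try_new(
        vec![vec![ScalarValue::from(1i32)]],
        vec![Arc::new(Field::new("part", DataType::Int32, false))],
    )?;
    let file_schema =
        Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
    let file_stats =
        PrunableStatistics::new(vec![Arc::new(Statistics::default())], file_schema);

    // Partition statistics are consulted first, then file statistics.
    let composite = CompositePruningStatistics::new(vec![
        Box::new(partition_stats),
        Box::new(file_stats),
    ]);

    assert_eq!(composite.num_containers(), 1);
    // `part` is answered by the partition source; `x` falls through to the file
    // source, which has no exact column statistics here, so it returns None.
    assert!(composite.min_values(&Column::new_unqualified("part")).is_some());
    assert!(composite.min_values(&Column::new_unqualified("x")).is_none());
    Ok(())
}
```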
+ pub fn new(statistics: Vec>) -> Self { + assert!(!statistics.is_empty()); + // Check that all statistics have the same number of containers + let num_containers = statistics[0].num_containers(); + for stats in &statistics { + assert_eq!(num_containers, stats.num_containers()); + } + Self { statistics } + } +} + +impl PruningStatistics for CompositePruningStatistics { + fn min_values(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.min_values(column) { + return Some(array); + } + } + None + } + + fn max_values(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.max_values(column) { + return Some(array); + } + } + None + } + + fn num_containers(&self) -> usize { + self.statistics[0].num_containers() + } + + fn null_counts(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.null_counts(column) { + return Some(array); + } + } + None + } + + fn row_counts(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.row_counts(column) { + return Some(array); + } + } + None + } + + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.contained(column, values) { + return Some(array); + } + } + None + } +} + +#[cfg(test)] +mod tests { + use crate::{ + cast::{as_int32_array, as_uint64_array}, + ColumnStatistics, + }; + + use super::*; + use arrow::datatypes::{DataType, Field}; + use std::sync::Arc; + + #[test] + fn test_partition_pruning_statistics() { + let partition_values = vec![ + vec![ScalarValue::from(1i32), ScalarValue::from(2i32)], + vec![ScalarValue::from(3i32), ScalarValue::from(4i32)], + ]; + let partition_fields = vec![ + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), + ]; + let partition_stats = + PartitionPruningStatistics::try_new(partition_values, partition_fields) + .unwrap(); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Partition values don't know anything about nulls or row counts + assert!(partition_stats.null_counts(&column_a).is_none()); + assert!(partition_stats.row_counts(&column_a).is_none()); + assert!(partition_stats.null_counts(&column_b).is_none()); + assert!(partition_stats.row_counts(&column_b).is_none()); + + // Min/max values are the same as the partition values + let min_values_a = + as_int32_array(&partition_stats.min_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(1), Some(3)]; + assert_eq!(min_values_a, expected_values_a); + let max_values_a = + as_int32_array(&partition_stats.max_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(1), Some(3)]; + assert_eq!(max_values_a, expected_values_a); + + let min_values_b = + as_int32_array(&partition_stats.min_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(2), Some(4)]; + assert_eq!(min_values_b, expected_values_b); + let max_values_b = + as_int32_array(&partition_stats.max_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(2), Some(4)]; + assert_eq!(max_values_b, expected_values_b); + + // Contained values are only true for the partition values + let values = HashSet::from([ScalarValue::from(1i32)]); + let 
contained_a = partition_stats.contained(&column_a, &values).unwrap(); + let expected_contained_a = BooleanArray::from(vec![true, false]); + assert_eq!(contained_a, expected_contained_a); + let contained_b = partition_stats.contained(&column_b, &values).unwrap(); + let expected_contained_b = BooleanArray::from(vec![false, false]); + assert_eq!(contained_b, expected_contained_b); + + // The number of containers is the length of the partition values + assert_eq!(partition_stats.num_containers(), 2); + } + + #[test] + fn test_partition_pruning_statistics_empty() { + let partition_values = vec![]; + let partition_fields = vec![ + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), + ]; + let partition_stats = + PartitionPruningStatistics::try_new(partition_values, partition_fields) + .unwrap(); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Partition values don't know anything about nulls or row counts + assert!(partition_stats.null_counts(&column_a).is_none()); + assert!(partition_stats.row_counts(&column_a).is_none()); + assert!(partition_stats.null_counts(&column_b).is_none()); + assert!(partition_stats.row_counts(&column_b).is_none()); + + // Min/max values are all missing + assert!(partition_stats.min_values(&column_a).is_none()); + assert!(partition_stats.max_values(&column_a).is_none()); + assert!(partition_stats.min_values(&column_b).is_none()); + assert!(partition_stats.max_values(&column_b).is_none()); + + // Contained values are all empty + let values = HashSet::from([ScalarValue::from(1i32)]); + assert!(partition_stats.contained(&column_a, &values).is_none()); + } + + #[test] + fn test_statistics_pruning_statistics() { + let statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(0i32))) + .with_max_value(Precision::Exact(ScalarValue::from(100i32))) + .with_null_count(Precision::Exact(0)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(100i32))) + .with_max_value(Precision::Exact(ScalarValue::from(200i32))) + .with_null_count(Precision::Exact(5)), + ) + .with_num_rows(Precision::Exact(100)), + ), + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(50i32))) + .with_max_value(Precision::Exact(ScalarValue::from(300i32))) + .with_null_count(Precision::Exact(10)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(200i32))) + .with_max_value(Precision::Exact(ScalarValue::from(400i32))) + .with_null_count(Precision::Exact(0)), + ) + .with_num_rows(Precision::Exact(200)), + ), + ]; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let pruning_stats = PrunableStatistics::new(statistics, schema); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Min/max values are the same as the statistics + let min_values_a = as_int32_array(&pruning_stats.min_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(0), Some(50)]; + assert_eq!(min_values_a, expected_values_a); + let max_values_a = 
as_int32_array(&pruning_stats.max_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(100), Some(300)]; + assert_eq!(max_values_a, expected_values_a); + let min_values_b = as_int32_array(&pruning_stats.min_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(100), Some(200)]; + assert_eq!(min_values_b, expected_values_b); + let max_values_b = as_int32_array(&pruning_stats.max_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(200), Some(400)]; + assert_eq!(max_values_b, expected_values_b); + + // Null counts are the same as the statistics + let null_counts_a = + as_uint64_array(&pruning_stats.null_counts(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts_a = vec![Some(0), Some(10)]; + assert_eq!(null_counts_a, expected_null_counts_a); + let null_counts_b = + as_uint64_array(&pruning_stats.null_counts(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts_b = vec![Some(5), Some(0)]; + assert_eq!(null_counts_b, expected_null_counts_b); + + // Row counts are the same as the statistics + let row_counts_a = as_uint64_array(&pruning_stats.row_counts(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts_a = vec![Some(100), Some(200)]; + assert_eq!(row_counts_a, expected_row_counts_a); + let row_counts_b = as_uint64_array(&pruning_stats.row_counts(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts_b = vec![Some(100), Some(200)]; + assert_eq!(row_counts_b, expected_row_counts_b); + + // Contained values are all null/missing (we can't know this just from statistics) + let values = HashSet::from([ScalarValue::from(0i32)]); + assert!(pruning_stats.contained(&column_a, &values).is_none()); + assert!(pruning_stats.contained(&column_b, &values).is_none()); + + // The number of containers is the length of the statistics + assert_eq!(pruning_stats.num_containers(), 2); + + // Test with a column that has no statistics + let column_c = Column::new_unqualified("c"); + assert!(pruning_stats.min_values(&column_c).is_none()); + assert!(pruning_stats.max_values(&column_c).is_none()); + assert!(pruning_stats.null_counts(&column_c).is_none()); + // Since row counts uses the first column that has row counts we get them back even + // if this columns does not have them set. + // This is debatable, personally I think `row_count` should not take a `Column` as an argument + // at all since all columns should have the same number of rows. + // But for now we just document the current behavior in this test. 
+ let row_counts_c = as_uint64_array(&pruning_stats.row_counts(&column_c).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts_c = vec![Some(100), Some(200)]; + assert_eq!(row_counts_c, expected_row_counts_c); + assert!(pruning_stats.contained(&column_c, &values).is_none()); + + // Test with a column that doesn't exist + let column_d = Column::new_unqualified("d"); + assert!(pruning_stats.min_values(&column_d).is_none()); + assert!(pruning_stats.max_values(&column_d).is_none()); + assert!(pruning_stats.null_counts(&column_d).is_none()); + assert!(pruning_stats.row_counts(&column_d).is_none()); + assert!(pruning_stats.contained(&column_d, &values).is_none()); + } + + #[test] + fn test_statistics_pruning_statistics_empty() { + let statistics = vec![]; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let pruning_stats = PrunableStatistics::new(statistics, schema); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Min/max values are all missing + assert!(pruning_stats.min_values(&column_a).is_none()); + assert!(pruning_stats.max_values(&column_a).is_none()); + assert!(pruning_stats.min_values(&column_b).is_none()); + assert!(pruning_stats.max_values(&column_b).is_none()); + + // Null counts are all missing + assert!(pruning_stats.null_counts(&column_a).is_none()); + assert!(pruning_stats.null_counts(&column_b).is_none()); + + // Row counts are all missing + assert!(pruning_stats.row_counts(&column_a).is_none()); + assert!(pruning_stats.row_counts(&column_b).is_none()); + + // Contained values are all empty + let values = HashSet::from([ScalarValue::from(1i32)]); + assert!(pruning_stats.contained(&column_a, &values).is_none()); + } + + #[test] + fn test_composite_pruning_statistics_partition_and_file() { + // Create partition statistics + let partition_values = vec![ + vec![ScalarValue::from(1i32), ScalarValue::from(10i32)], + vec![ScalarValue::from(2i32), ScalarValue::from(20i32)], + ]; + let partition_fields = vec![ + Arc::new(Field::new("part_a", DataType::Int32, false)), + Arc::new(Field::new("part_b", DataType::Int32, false)), + ]; + let partition_stats = + PartitionPruningStatistics::try_new(partition_values, partition_fields) + .unwrap(); + + // Create file statistics + let file_statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(100i32))) + .with_max_value(Precision::Exact(ScalarValue::from(200i32))) + .with_null_count(Precision::Exact(0)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(300i32))) + .with_max_value(Precision::Exact(ScalarValue::from(400i32))) + .with_null_count(Precision::Exact(5)), + ) + .with_num_rows(Precision::Exact(100)), + ), + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(500i32))) + .with_max_value(Precision::Exact(ScalarValue::from(600i32))) + .with_null_count(Precision::Exact(10)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(700i32))) + .with_max_value(Precision::Exact(ScalarValue::from(800i32))) + .with_null_count(Precision::Exact(0)), + ) + .with_num_rows(Precision::Exact(200)), + ), + ]; + + let 
file_schema = Arc::new(Schema::new(vec![ + Field::new("col_x", DataType::Int32, false), + Field::new("col_y", DataType::Int32, false), + ])); + let file_stats = PrunableStatistics::new(file_statistics, file_schema); + + // Create composite statistics + let composite_stats = CompositePruningStatistics::new(vec![ + Box::new(partition_stats), + Box::new(file_stats), + ]); + + // Test accessing columns that are only in partition statistics + let part_a = Column::new_unqualified("part_a"); + let part_b = Column::new_unqualified("part_b"); + + // Test accessing columns that are only in file statistics + let col_x = Column::new_unqualified("col_x"); + let col_y = Column::new_unqualified("col_y"); + + // For partition columns, should get values from partition statistics + let min_values_part_a = + as_int32_array(&composite_stats.min_values(&part_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_part_a = vec![Some(1), Some(2)]; + assert_eq!(min_values_part_a, expected_values_part_a); + + let max_values_part_a = + as_int32_array(&composite_stats.max_values(&part_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + // For partition values, min and max are the same + assert_eq!(max_values_part_a, expected_values_part_a); + + let min_values_part_b = + as_int32_array(&composite_stats.min_values(&part_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_part_b = vec![Some(10), Some(20)]; + assert_eq!(min_values_part_b, expected_values_part_b); + + // For file columns, should get values from file statistics + let min_values_col_x = + as_int32_array(&composite_stats.min_values(&col_x).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_col_x = vec![Some(100), Some(500)]; + assert_eq!(min_values_col_x, expected_values_col_x); + + let max_values_col_x = + as_int32_array(&composite_stats.max_values(&col_x).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_max_values_col_x = vec![Some(200), Some(600)]; + assert_eq!(max_values_col_x, expected_max_values_col_x); + + let min_values_col_y = + as_int32_array(&composite_stats.min_values(&col_y).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_col_y = vec![Some(300), Some(700)]; + assert_eq!(min_values_col_y, expected_values_col_y); + + // Test null counts - only available from file statistics + assert!(composite_stats.null_counts(&part_a).is_none()); + assert!(composite_stats.null_counts(&part_b).is_none()); + + let null_counts_col_x = + as_uint64_array(&composite_stats.null_counts(&col_x).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts_col_x = vec![Some(0), Some(10)]; + assert_eq!(null_counts_col_x, expected_null_counts_col_x); + + // Test row counts - only available from file statistics + assert!(composite_stats.row_counts(&part_a).is_none()); + let row_counts_col_x = + as_uint64_array(&composite_stats.row_counts(&col_x).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts = vec![Some(100), Some(200)]; + assert_eq!(row_counts_col_x, expected_row_counts); + + // Test contained values - only available from partition statistics + let values = HashSet::from([ScalarValue::from(1i32)]); + let contained_part_a = composite_stats.contained(&part_a, &values).unwrap(); + let expected_contained_part_a = BooleanArray::from(vec![true, false]); + assert_eq!(contained_part_a, expected_contained_part_a); + + // File statistics don't implement contained + 
assert!(composite_stats.contained(&col_x, &values).is_none()); + + // Non-existent column should return None for everything + let non_existent = Column::new_unqualified("non_existent"); + assert!(composite_stats.min_values(&non_existent).is_none()); + assert!(composite_stats.max_values(&non_existent).is_none()); + assert!(composite_stats.null_counts(&non_existent).is_none()); + assert!(composite_stats.row_counts(&non_existent).is_none()); + assert!(composite_stats.contained(&non_existent, &values).is_none()); + + // Verify num_containers matches + assert_eq!(composite_stats.num_containers(), 2); + } + + #[test] + fn test_composite_pruning_statistics_priority() { + // Create two sets of file statistics with the same column names + // but different values to test that the first one gets priority + + // First set of statistics + let first_statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(100i32))) + .with_max_value(Precision::Exact(ScalarValue::from(200i32))) + .with_null_count(Precision::Exact(0)), + ) + .with_num_rows(Precision::Exact(100)), + ), + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(300i32))) + .with_max_value(Precision::Exact(ScalarValue::from(400i32))) + .with_null_count(Precision::Exact(5)), + ) + .with_num_rows(Precision::Exact(200)), + ), + ]; + + let first_schema = Arc::new(Schema::new(vec![Field::new( + "col_a", + DataType::Int32, + false, + )])); + let first_stats = PrunableStatistics::new(first_statistics, first_schema); + + // Second set of statistics with the same column name but different values + let second_statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(1000i32))) + .with_max_value(Precision::Exact(ScalarValue::from(2000i32))) + .with_null_count(Precision::Exact(10)), + ) + .with_num_rows(Precision::Exact(1000)), + ), + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(3000i32))) + .with_max_value(Precision::Exact(ScalarValue::from(4000i32))) + .with_null_count(Precision::Exact(20)), + ) + .with_num_rows(Precision::Exact(2000)), + ), + ]; + + let second_schema = Arc::new(Schema::new(vec![Field::new( + "col_a", + DataType::Int32, + false, + )])); + let second_stats = PrunableStatistics::new(second_statistics, second_schema); + + // Create composite statistics with first stats having priority + let composite_stats = CompositePruningStatistics::new(vec![ + Box::new(first_stats.clone()), + Box::new(second_stats.clone()), + ]); + + let col_a = Column::new_unqualified("col_a"); + + // Should get values from first statistics since it has priority + let min_values = as_int32_array(&composite_stats.min_values(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_min_values = vec![Some(100), Some(300)]; + assert_eq!(min_values, expected_min_values); + + let max_values = as_int32_array(&composite_stats.max_values(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_max_values = vec![Some(200), Some(400)]; + assert_eq!(max_values, expected_max_values); + + let null_counts = as_uint64_array(&composite_stats.null_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let 
expected_null_counts = vec![Some(0), Some(5)]; + assert_eq!(null_counts, expected_null_counts); + + let row_counts = as_uint64_array(&composite_stats.row_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::<Vec<_>>(); + let expected_row_counts = vec![Some(100), Some(200)]; + assert_eq!(row_counts, expected_row_counts); + + // Create composite statistics with second stats having priority + // Now that we've added the Clone trait to PrunableStatistics, we can just clone them + + let composite_stats_reversed = CompositePruningStatistics::new(vec![ + Box::new(second_stats.clone()), + Box::new(first_stats.clone()), + ]); + + // Should get values from second statistics since it now has priority + let min_values = + as_int32_array(&composite_stats_reversed.min_values(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::<Vec<_>>(); + let expected_min_values = vec![Some(1000), Some(3000)]; + assert_eq!(min_values, expected_min_values); + + let max_values = + as_int32_array(&composite_stats_reversed.max_values(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::<Vec<_>>(); + let expected_max_values = vec![Some(2000), Some(4000)]; + assert_eq!(max_values, expected_max_values); + + let null_counts = + as_uint64_array(&composite_stats_reversed.null_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::<Vec<_>>(); + let expected_null_counts = vec![Some(10), Some(20)]; + assert_eq!(null_counts, expected_null_counts); + + let row_counts = + as_uint64_array(&composite_stats_reversed.row_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::<Vec<_>>(); + let expected_row_counts = vec![Some(1000), Some(2000)]; + assert_eq!(row_counts, expected_row_counts); + } + + #[test] + fn test_composite_pruning_statistics_empty_and_mismatched_containers() { + // Test with empty statistics vector + // This should never happen, so we panic instead of returning a Result, which would burden callers + let result = std::panic::catch_unwind(|| { + CompositePruningStatistics::new(vec![]); + }); + assert!(result.is_err()); + + // We should panic here because the number of containers is different + let result = std::panic::catch_unwind(|| { + // Create statistics with different number of containers + // Use partition stats for the test + let partition_values_1 = vec![ + vec![ScalarValue::from(1i32), ScalarValue::from(10i32)], + vec![ScalarValue::from(2i32), ScalarValue::from(20i32)], + ]; + let partition_fields_1 = vec![ + Arc::new(Field::new("part_a", DataType::Int32, false)), + Arc::new(Field::new("part_b", DataType::Int32, false)), + ]; + let partition_stats_1 = PartitionPruningStatistics::try_new( + partition_values_1, + partition_fields_1, + ) + .unwrap(); + let partition_values_2 = vec![ + vec![ScalarValue::from(3i32), ScalarValue::from(30i32)], + vec![ScalarValue::from(4i32), ScalarValue::from(40i32)], + vec![ScalarValue::from(5i32), ScalarValue::from(50i32)], + ]; + let partition_fields_2 = vec![ + Arc::new(Field::new("part_x", DataType::Int32, false)), + Arc::new(Field::new("part_y", DataType::Int32, false)), + ]; + let partition_stats_2 = PartitionPruningStatistics::try_new( + partition_values_2, + partition_fields_2, + ) + .unwrap(); + + CompositePruningStatistics::new(vec![ + Box::new(partition_stats_1), + Box::new(partition_stats_2), + ]); + }); + assert!(result.is_err()); + } +} diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index b8d9aea810f0..3d4aa78b6da6 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -506,7 +506,7 @@ impl
PartialOrd for ScalarValue { } (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Struct(struct_arr1), Struct(struct_arr2)) => { - partial_cmp_struct(struct_arr1, struct_arr2) + partial_cmp_struct(struct_arr1.as_ref(), struct_arr2.as_ref()) } (Struct(_), _) => None, (Map(map_arr1), Map(map_arr2)) => partial_cmp_map(map_arr1, map_arr2), @@ -597,10 +597,28 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { let arr1 = first_array_for_list(arr1); let arr2 = first_array_for_list(arr2); - let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + let min_length = arr1.len().min(arr2.len()); + let arr1_trimmed = arr1.slice(0, min_length); + let arr2_trimmed = arr2.slice(0, min_length); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1_trimmed, &arr2_trimmed).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1_trimmed, &arr2_trimmed).ok()?; for j in 0..lt_res.len() { + // In Postgres, NULL values in lists are always considered to be greater than non-NULL values: + // + // $ SELECT ARRAY[NULL]::integer[] > ARRAY[1] + // true + // + // These next two if statements are introduced for replicating Postgres behavior, as + // arrow::compute does not account for this. + if arr1_trimmed.is_null(j) && !arr2_trimmed.is_null(j) { + return Some(Ordering::Greater); + } + if !arr1_trimmed.is_null(j) && arr2_trimmed.is_null(j) { + return Some(Ordering::Less); + } + if lt_res.is_valid(j) && lt_res.value(j) { return Some(Ordering::Less); } @@ -609,10 +627,23 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { } } - Some(Ordering::Equal) + Some(arr1.len().cmp(&arr2.len())) +} + +fn flatten<'a>(array: &'a StructArray, columns: &mut Vec<&'a ArrayRef>) { + for i in 0..array.num_columns() { + let column = array.column(i); + if let Some(nested_struct) = column.as_any().downcast_ref::() { + // If it's a nested struct, recursively expand + flatten(nested_struct, columns); + } else { + // If it's a primitive type, add directly + columns.push(column); + } + } } -fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option { +pub fn partial_cmp_struct(s1: &StructArray, s2: &StructArray) -> Option { if s1.len() != s2.len() { return None; } @@ -621,9 +652,15 @@ fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option() } + + /// Compacts the allocation referenced by `self` to the minimum, copying the data if + /// necessary. + /// + /// This can be relevant when `self` is a list or contains a list as a nested value, as + /// a single list holds an Arc to its entire original array buffer. 
+ pub fn compact(&mut self) { + match self { + ScalarValue::Null + | ScalarValue::Boolean(_) + | ScalarValue::Float16(_) + | ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) + | ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Date32(_) + | ScalarValue::Date64(_) + | ScalarValue::Time32Second(_) + | ScalarValue::Time32Millisecond(_) + | ScalarValue::Time64Microsecond(_) + | ScalarValue::Time64Nanosecond(_) + | ScalarValue::IntervalYearMonth(_) + | ScalarValue::IntervalDayTime(_) + | ScalarValue::IntervalMonthDayNano(_) + | ScalarValue::DurationSecond(_) + | ScalarValue::DurationMillisecond(_) + | ScalarValue::DurationMicrosecond(_) + | ScalarValue::DurationNanosecond(_) + | ScalarValue::Utf8(_) + | ScalarValue::LargeUtf8(_) + | ScalarValue::Utf8View(_) + | ScalarValue::TimestampSecond(_, _) + | ScalarValue::TimestampMillisecond(_, _) + | ScalarValue::TimestampMicrosecond(_, _) + | ScalarValue::TimestampNanosecond(_, _) + | ScalarValue::Binary(_) + | ScalarValue::FixedSizeBinary(_, _) + | ScalarValue::LargeBinary(_) + | ScalarValue::BinaryView(_) => (), + ScalarValue::FixedSizeList(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = FixedSizeListArray::from(array); + } + ScalarValue::List(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = ListArray::from(array); + } + ScalarValue::LargeList(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = LargeListArray::from(array) + } + ScalarValue::Struct(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = StructArray::from(array); + } + ScalarValue::Map(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = MapArray::from(array); + } + ScalarValue::Union(val, _, _) => { + if let Some((_, value)) = val.as_mut() { + value.compact(); + } + } + ScalarValue::Dictionary(_, value) => { + value.compact(); + } + } + } +} + +pub fn copy_array_data(data: &ArrayData) -> ArrayData { + let mut copy = MutableArrayData::new(vec![&data], true, data.len()); + copy.extend(0, 0, data.len()); + copy.freeze() } macro_rules! impl_scalar { @@ -3739,7 +3859,7 @@ impl fmt::Display for ScalarValue { array_value_to_string(arr.column(0), i).unwrap(); let value = array_value_to_string(arr.column(1), i).unwrap(); - buffer.push_back(format!("{}:{}", key, value)); + buffer.push_back(format!("{key}:{value}")); } format!( "{{{}}}", @@ -3758,7 +3878,7 @@ impl fmt::Display for ScalarValue { )? } ScalarValue::Union(val, _fields, _mode) => match val { - Some((id, val)) => write!(f, "{}:{}", id, val)?, + Some((id, val)) => write!(f, "{id}:{val}")?, None => write!(f, "NULL")?, }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, @@ -3935,7 +4055,7 @@ impl fmt::Debug for ScalarValue { write!(f, "DurationNanosecond(\"{self}\")") } ScalarValue::Union(val, _fields, _mode) => match val { - Some((id, val)) => write!(f, "Union {}:{}", id, val), + Some((id, val)) => write!(f, "Union {id}:{val}"), None => write!(f, "Union(NULL)"), }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), @@ -4059,7 +4179,7 @@ mod tests { #[test] #[should_panic( - expected = "Error building ScalarValue::Struct. 
Expected array with exactly one element, found array with 4 elements" + expected = "InvalidArgumentError(\"Incorrect array length for StructArray field \\\"bool\\\", expected 1 got 4\")" )] fn test_scalar_value_from_for_struct_should_panic() { let _ = ScalarStructBuilder::new() @@ -4752,6 +4872,109 @@ mod tests { ])]), )); assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(2), + Some(3), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(2), + Some(3), + Some(4), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + None, + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = ScalarValue::LargeList(Arc::new(LargeListArray::from_iter_primitive::< + Int64Type, + _, + _, + >(vec![Some(vec![ + None, + Some(2), + Some(3), + ])]))); + let b = ScalarValue::LargeList(Arc::new(LargeListArray::from_iter_primitive::< + Int64Type, + _, + _, + >(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]))); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![None, Some(2), Some(3)])], + 3, + ), + )); + let b = ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![Some(1), Some(2), Some(3)])], + 3, + ), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); } #[test] @@ -7162,14 +7385,14 @@ mod tests { fn get_random_timestamps(sample_size: u64) -> Vec { let vector_size = sample_size; let mut timestamp = vec![]; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for i in 0..vector_size { - let year = rng.gen_range(1995..=2050); - let month = rng.gen_range(1..=12); - let day = rng.gen_range(1..=28); // to exclude invalid dates - let hour = rng.gen_range(0..=23); - let minute = rng.gen_range(0..=59); - let second = rng.gen_range(0..=59); + let year = rng.random_range(1995..=2050); + let month = rng.random_range(1..=12); + let day = rng.random_range(1..=28); // to exclude invalid dates + let hour = rng.random_range(0..=23); + let minute = rng.random_range(0..=59); + let second = rng.random_range(0..=59); if i % 4 == 0 { timestamp.push(ScalarValue::TimestampSecond( Some( @@ -7183,7 +7406,7 @@ mod tests { None, )) } else if i % 4 == 1 { - let millisec = rng.gen_range(0..=999); + let millisec = rng.random_range(0..=999); timestamp.push(ScalarValue::TimestampMillisecond( Some( NaiveDate::from_ymd_opt(year, month, day) @@ -7196,7 +7419,7 @@ mod 
tests { None, )) } else if i % 4 == 2 { - let microsec = rng.gen_range(0..=999_999); + let microsec = rng.random_range(0..=999_999); timestamp.push(ScalarValue::TimestampMicrosecond( Some( NaiveDate::from_ymd_opt(year, month, day) @@ -7209,7 +7432,7 @@ mod tests { None, )) } else if i % 4 == 3 { - let nanosec = rng.gen_range(0..=999_999_999); + let nanosec = rng.random_range(0..=999_999_999); timestamp.push(ScalarValue::TimestampNanosecond( Some( NaiveDate::from_ymd_opt(year, month, day) @@ -7233,27 +7456,27 @@ mod tests { let vector_size = sample_size; let mut intervals = vec![]; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); const SECS_IN_ONE_DAY: i32 = 86_400; const MICROSECS_IN_ONE_DAY: i64 = 86_400_000_000; for i in 0..vector_size { if i % 4 == 0 { - let days = rng.gen_range(0..5000); + let days = rng.random_range(0..5000); // to not break second precision - let millis = rng.gen_range(0..SECS_IN_ONE_DAY) * 1000; + let millis = rng.random_range(0..SECS_IN_ONE_DAY) * 1000; intervals.push(ScalarValue::new_interval_dt(days, millis)); } else if i % 4 == 1 { - let days = rng.gen_range(0..5000); - let millisec = rng.gen_range(0..(MILLISECS_IN_ONE_DAY as i32)); + let days = rng.random_range(0..5000); + let millisec = rng.random_range(0..(MILLISECS_IN_ONE_DAY as i32)); intervals.push(ScalarValue::new_interval_dt(days, millisec)); } else if i % 4 == 2 { - let days = rng.gen_range(0..5000); + let days = rng.random_range(0..5000); // to not break microsec precision - let nanosec = rng.gen_range(0..MICROSECS_IN_ONE_DAY) * 1000; + let nanosec = rng.random_range(0..MICROSECS_IN_ONE_DAY) * 1000; intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); } else { - let days = rng.gen_range(0..5000); - let nanosec = rng.gen_range(0..NANOSECS_IN_ONE_DAY); + let days = rng.random_range(0..5000); + let nanosec = rng.random_range(0..NANOSECS_IN_ONE_DAY); intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); } } diff --git a/datafusion/common/src/scalar/struct_builder.rs b/datafusion/common/src/scalar/struct_builder.rs index 5ed464018401..fd19dccf8963 100644 --- a/datafusion/common/src/scalar/struct_builder.rs +++ b/datafusion/common/src/scalar/struct_builder.rs @@ -17,7 +17,6 @@ //! [`ScalarStructBuilder`] for building [`ScalarValue::Struct`] -use crate::error::_internal_err; use crate::{Result, ScalarValue}; use arrow::array::{ArrayRef, StructArray}; use arrow::datatypes::{DataType, Field, FieldRef, Fields}; @@ -109,17 +108,8 @@ impl ScalarStructBuilder { pub fn build(self) -> Result { let Self { fields, arrays } = self; - for array in &arrays { - if array.len() != 1 { - return _internal_err!( - "Error building ScalarValue::Struct. 
\ - Expected array with exactly one element, found array with {} elements", - array.len() - ); - } - } - - let struct_array = StructArray::try_new(Fields::from(fields), arrays, None)?; + let struct_array = + StructArray::try_new_with_length(Fields::from(fields), arrays, None, 1)?; Ok(ScalarValue::Struct(Arc::new(struct_array))) } } @@ -181,3 +171,15 @@ impl IntoFields for Vec { Fields::from(self) } } + +#[cfg(test)] +mod tests { + use super::*; + + // Other cases are tested by doc tests + #[test] + fn test_empty_struct() { + let sv = ScalarStructBuilder::new().build().unwrap(); + assert_eq!(format!("{sv}"), "{}"); + } +} diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 807d885b3a4d..a6d132ef51f6 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -233,8 +233,8 @@ impl Precision { impl Debug for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Precision::Exact(inner) => write!(f, "Exact({:?})", inner), - Precision::Inexact(inner) => write!(f, "Inexact({:?})", inner), + Precision::Exact(inner) => write!(f, "Exact({inner:?})"), + Precision::Inexact(inner) => write!(f, "Inexact({inner:?})"), Precision::Absent => write!(f, "Absent"), } } @@ -243,8 +243,8 @@ impl Debug for Precision { impl Display for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Precision::Exact(inner) => write!(f, "Exact({:?})", inner), - Precision::Inexact(inner) => write!(f, "Inexact({:?})", inner), + Precision::Exact(inner) => write!(f, "Exact({inner:?})"), + Precision::Inexact(inner) => write!(f, "Inexact({inner:?})"), Precision::Absent => write!(f, "Absent"), } } @@ -352,6 +352,7 @@ impl Statistics { return self; }; + #[allow(clippy::large_enum_variant)] enum Slot { /// The column is taken and put into the specified statistics location Taken(usize), @@ -451,6 +452,9 @@ impl Statistics { /// Summarize zero or more statistics into a single `Statistics` instance. /// + /// The method assumes that all statistics are for the same schema. + /// If not, maybe you can call `SchemaMapper::map_column_statistics` to make them consistent. + /// /// Returns an error if the statistics do not match the specified schemas. pub fn try_merge_iter<'a, I>(items: I, schema: &Schema) -> Result where @@ -569,7 +573,7 @@ impl Display for Statistics { .iter() .enumerate() .map(|(i, cs)| { - let s = format!("(Col[{}]:", i); + let s = format!("(Col[{i}]:"); let s = if cs.min_value != Precision::Absent { format!("{} Min={}", s, cs.min_value) } else { diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index b801c452af2c..820a230bf6e1 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -18,10 +18,25 @@ //! Utility functions to make testing DataFusion based crates easier use crate::arrow::util::pretty::pretty_format_batches_with_options; -use crate::format::DEFAULT_FORMAT_OPTIONS; -use arrow::array::RecordBatch; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::error::ArrowError; +use std::fmt::Display; use std::{error::Error, path::PathBuf}; +/// Converts a vector or array into an ArrayRef. 
+pub trait IntoArrayRef { + fn into_array_ref(self) -> ArrayRef; +} + +pub fn format_batches(results: &[RecordBatch]) -> Result { + let datafusion_format_options = crate::config::FormatOptions::default(); + + let arrow_format_options: arrow::util::display::FormatOptions = + (&datafusion_format_options).try_into().unwrap(); + + pretty_format_batches_with_options(results, &arrow_format_options) +} + /// Compares formatted output of a record batch with an expected /// vector of strings, with the result of pretty formatting record /// batches. This is a macro so errors appear on the correct line @@ -59,12 +74,9 @@ macro_rules! assert_batches_eq { let expected_lines: Vec = $EXPECTED_LINES.iter().map(|&s| s.into()).collect(); - let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options( - $CHUNKS, - &$crate::format::DEFAULT_FORMAT_OPTIONS, - ) - .unwrap() - .to_string(); + let formatted = $crate::test_util::format_batches($CHUNKS) + .unwrap() + .to_string(); let actual_lines: Vec<&str> = formatted.trim().lines().collect(); @@ -77,18 +89,13 @@ macro_rules! assert_batches_eq { } pub fn batches_to_string(batches: &[RecordBatch]) -> String { - let actual = pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS) - .unwrap() - .to_string(); + let actual = format_batches(batches).unwrap().to_string(); actual.trim().to_string() } pub fn batches_to_sort_string(batches: &[RecordBatch]) -> String { - let actual_lines = - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS) - .unwrap() - .to_string(); + let actual_lines = format_batches(batches).unwrap().to_string(); let mut actual_lines: Vec<&str> = actual_lines.trim().lines().collect(); @@ -122,12 +129,9 @@ macro_rules! assert_batches_sorted_eq { expected_lines.as_mut_slice()[2..num_lines - 1].sort_unstable() } - let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options( - $CHUNKS, - &$crate::format::DEFAULT_FORMAT_OPTIONS, - ) - .unwrap() - .to_string(); + let formatted = $crate::test_util::format_batches($CHUNKS) + .unwrap() + .to_string(); // fix for windows: \r\n --> let mut actual_lines: Vec<&str> = formatted.trim().lines().collect(); @@ -384,6 +388,326 @@ macro_rules! 
record_batch { } } +pub mod array_conversion { + use arrow::array::ArrayRef; + + use super::IntoArrayRef; + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self) + } + } + + impl IntoArrayRef for &[bool] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self) + } + } + + impl IntoArrayRef for &[i8] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self) + } + } + + impl IntoArrayRef for &[i16] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self) + } + } + + impl IntoArrayRef for &[i32] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self) + } + } + + impl IntoArrayRef for &[i64] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self) + } + } + + impl IntoArrayRef for &[u8] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self) + } + } + + impl IntoArrayRef for &[u16] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt32, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + 
create_array!(UInt32, self) + } + } + + impl IntoArrayRef for &[u32] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt32, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt32, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self) + } + } + + impl IntoArrayRef for &[u64] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self.to_vec()) + } + } + + //#TODO add impl for f16 + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self) + } + } + + impl IntoArrayRef for &[f32] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self) + } + } + + impl IntoArrayRef for &[f64] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self.to_vec()) + } + } + + impl IntoArrayRef for Vec<&str> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for &[&str] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option<&str>] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for &[String] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } +} + #[cfg(test)] mod tests { use crate::cast::{as_float64_array, as_int32_array, as_string_array}; diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index c70389b63177..cf51dadf6b4a 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -2354,7 +2354,7 @@ pub(crate) mod tests { fn test_large_tree() { let mut item = TestTreeNode::new_leaf("initial".to_string()); for i in 0..3000 { - item = TestTreeNode::new(vec![item], format!("parent-{}", i)); + item = TestTreeNode::new(vec![item], format!("parent-{i}")); } let mut visitor = diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index edc0d34b539a..61995d6707d4 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -95,7 +95,7 @@ extended_tests = [] [dependencies] arrow = { workspace = true } arrow-ipc = { workspace = true } 
-arrow-schema = { workspace = true } +arrow-schema = { workspace = true, features = ["canonical_extension_types"] } async-trait = { workspace = true } bytes = { workspace = true } bzip2 = { version = "0.5.2", optional = true } @@ -117,7 +117,6 @@ datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true, optional = true } datafusion-functions-table = { workspace = true } datafusion-functions-window = { workspace = true } -datafusion-macros = { workspace = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } @@ -139,7 +138,7 @@ sqlparser = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true } url = { workspace = true } -uuid = { version = "1.16", features = ["v4", "js"] } +uuid = { version = "1.17", features = ["v4", "js"] } xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } @@ -150,22 +149,23 @@ ctor = { workspace = true } dashmap = "6.1.0" datafusion-doc = { workspace = true } datafusion-functions-window-common = { workspace = true } +datafusion-macros = { workspace = true } datafusion-physical-optimizer = { workspace = true } doc-comment = { workspace = true } env_logger = { workspace = true } insta = { workspace = true } paste = "^1.0" rand = { workspace = true, features = ["small_rng"] } -rand_distr = "0.4.3" +rand_distr = "0.5" regex = { workspace = true } rstest = { workspace = true } serde_json = { workspace = true } -sysinfo = "0.34.2" +sysinfo = "0.35.2" test-utils = { path = "../../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.29.0", features = ["fs"] } +nix = { version = "0.30.1", features = ["fs"] } [[bench]] harness = false @@ -179,6 +179,10 @@ name = "csv_load" harness = false name = "distinct_query_sql" +[[bench]] +harness = false +name = "push_down_filter" + [[bench]] harness = false name = "sort_limit_query_sql" diff --git a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs index b29bfc487340..057a0e1d1b54 100644 --- a/datafusion/core/benches/aggregate_query_sql.rs +++ b/datafusion/core/benches/aggregate_query_sql.rs @@ -158,7 +158,7 @@ fn criterion_benchmark(c: &mut Criterion) { query( ctx.clone(), &rt, - "SELECT utf8, approx_percentile_cont(u64_wide, 0.5, 2500) \ + "SELECT utf8, approx_percentile_cont(0.5, 2500) WITHIN GROUP (ORDER BY u64_wide) \ FROM t GROUP BY utf8", ) }) @@ -169,7 +169,7 @@ fn criterion_benchmark(c: &mut Criterion) { query( ctx.clone(), &rt, - "SELECT utf8, approx_percentile_cont(f32, 0.5, 2500) \ + "SELECT utf8, approx_percentile_cont(0.5, 2500) WITHIN GROUP (ORDER BY f32) \ FROM t GROUP BY utf8", ) }) diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index fc5f8945c439..c0477b1306f7 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -26,8 +26,8 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::MemTable; use datafusion::error::Result; use datafusion_common::DataFusionError; +use rand::prelude::IndexedRandom; use rand::rngs::StdRng; -use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; use rand_distr::Distribution; use rand_distr::{Normal, Pareto}; @@ -49,11 +49,6 @@ pub fn 
create_table_provider( MemTable::try_new(schema, partitions).map(Arc::new) } -/// create a seedable [`StdRng`](rand::StdRng) -fn seedable_rng() -> StdRng { - StdRng::seed_from_u64(42) -} - /// Create test data schema pub fn create_schema() -> Schema { Schema::new(vec![ @@ -73,14 +68,14 @@ pub fn create_schema() -> Schema { fn create_data(size: usize, null_density: f64) -> Vec> { // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); (0..size) .map(|_| { - if rng.gen::() > null_density { + if rng.random::() > null_density { None } else { - Some(rng.gen::()) + Some(rng.random::()) } }) .collect() @@ -88,14 +83,14 @@ fn create_data(size: usize, null_density: f64) -> Vec> { fn create_integer_data(size: usize, value_density: f64) -> Vec> { // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); (0..size) .map(|_| { - if rng.gen::() > value_density { + if rng.random::() > value_density { None } else { - Some(rng.gen::()) + Some(rng.random::()) } }) .collect() @@ -125,7 +120,7 @@ fn create_record_batch( // Integer values between [0, 9]. let integer_values_narrow = (0..batch_size) - .map(|_| rng.gen_range(0_u64..10)) + .map(|_| rng.random_range(0_u64..10)) .collect::>(); RecordBatch::try_new( @@ -149,7 +144,7 @@ pub fn create_record_batches( partitions_len: usize, batch_size: usize, ) -> Vec> { - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); (0..partitions_len) .map(|_| { (0..array_len / batch_size / partitions_len) @@ -217,7 +212,7 @@ pub(crate) fn make_data( let mut ts_builder = Int64Builder::new(); let gen_id = |rng: &mut rand::rngs::SmallRng| { - rng.gen::<[u8; 16]>() + rng.random::<[u8; 16]>() .iter() .fold(String::new(), |mut output, b| { let _ = write!(output, "{b:02X}"); @@ -233,7 +228,7 @@ pub(crate) fn make_data( .map(|_| gen_sample_cnt(&mut rng)) .collect::>(); for _ in 0..sample_cnt { - let random_index = rng.gen_range(0..simultaneous_group_cnt); + let random_index = rng.random_range(0..simultaneous_group_cnt); let trace_id = &mut group_ids[random_index]; let sample_cnt = &mut group_sample_cnts[random_index]; *sample_cnt -= 1; diff --git a/datafusion/core/benches/dataframe.rs b/datafusion/core/benches/dataframe.rs index 832553ebed82..12eb34719e4b 100644 --- a/datafusion/core/benches/dataframe.rs +++ b/datafusion/core/benches/dataframe.rs @@ -32,7 +32,7 @@ use tokio::runtime::Runtime; fn create_context(field_count: u32) -> datafusion_common::Result> { let mut fields = vec![]; for i in 0..field_count { - fields.push(Field::new(format!("str{}", i), DataType::Utf8, true)) + fields.push(Field::new(format!("str{i}"), DataType::Utf8, true)) } let schema = Arc::new(Schema::new(fields)); @@ -49,8 +49,8 @@ fn run(column_count: u32, ctx: Arc, rt: &Runtime) { let mut data_frame = ctx.table("t").await.unwrap(); for i in 0..column_count { - let field_name = &format!("str{}", i); - let new_field_name = &format!("newstr{}", i); + let field_name = &format!("str{i}"); + let new_field_name = &format!("newstr{i}"); data_frame = data_frame .with_column_renamed(field_name, new_field_name) diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs index c7056aab8689..c1ef55992689 100644 --- a/datafusion/core/benches/distinct_query_sql.rs +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -154,7 +154,7 @@ fn 
criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { let sql = format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); c.bench_function( - format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + format!("distinct query with {partitions} partitions and {samples} samples per partition with limit {limit}").as_str(), |b| b.iter(|| { let (plan, ctx) = rt.block_on( create_context_sampled_data(sql.as_str(), partitions, samples) @@ -168,7 +168,7 @@ fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { let sql = format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); c.bench_function( - format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + format!("distinct query with {partitions} partitions and {samples} samples per partition with limit {limit}").as_str(), |b| b.iter(|| { let (plan, ctx) = rt.block_on( create_context_sampled_data(sql.as_str(), partitions, samples) @@ -182,7 +182,7 @@ fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { let sql = format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); c.bench_function( - format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + format!("distinct query with {partitions} partitions and {samples} samples per partition with limit {limit}").as_str(), |b| b.iter(|| { let (plan, ctx) = rt.block_on( create_context_sampled_data(sql.as_str(), partitions, samples) diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs index 79229dfc2fbd..063b8e6c86bb 100644 --- a/datafusion/core/benches/map_query_sql.rs +++ b/datafusion/core/benches/map_query_sql.rs @@ -34,7 +34,7 @@ mod data_utils; fn build_keys(rng: &mut ThreadRng) -> Vec { let mut keys = vec![]; for _ in 0..1000 { - keys.push(rng.gen_range(0..9999).to_string()); + keys.push(rng.random_range(0..9999).to_string()); } keys } @@ -42,7 +42,7 @@ fn build_keys(rng: &mut ThreadRng) -> Vec { fn build_values(rng: &mut ThreadRng) -> Vec { let mut values = vec![]; for _ in 0..1000 { - values.push(rng.gen_range(0..9999)); + values.push(rng.random_range(0..9999)); } values } @@ -64,15 +64,18 @@ fn criterion_benchmark(c: &mut Criterion) { let rt = Runtime::new().unwrap(); let df = rt.block_on(ctx.lock().table("t")).unwrap(); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let keys = build_keys(&mut rng); let values = build_values(&mut rng); let mut key_buffer = Vec::new(); let mut value_buffer = Vec::new(); for i in 0..1000 { - key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + key_buffer.push(Expr::Literal( + ScalarValue::Utf8(Some(keys[i].clone())), + None, + )); + value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])), None)); } c.bench_function("map_1000_1", |b| { b.iter(|| { diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index f82a126c5652..14dcdf15f173 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -29,9 +29,10 @@ use datafusion_common::instant::Instant; use futures::stream::StreamExt; use parquet::arrow::ArrowWriter; use parquet::file::properties::{WriterProperties, WriterVersion}; -use 
rand::distributions::uniform::SampleUniform; -use rand::distributions::Alphanumeric; +use rand::distr::uniform::SampleUniform; +use rand::distr::Alphanumeric; use rand::prelude::*; +use rand::rng; use std::fs::File; use std::io::Read; use std::ops::Range; @@ -97,13 +98,13 @@ fn generate_string_dictionary( len: usize, valid_percent: f64, ) -> ArrayRef { - let mut rng = thread_rng(); + let mut rng = rng(); let strings: Vec<_> = (0..cardinality).map(|x| format!("{prefix}#{x}")).collect(); Arc::new(DictionaryArray::::from_iter((0..len).map( |_| { - rng.gen_bool(valid_percent) - .then(|| strings[rng.gen_range(0..cardinality)].as_str()) + rng.random_bool(valid_percent) + .then(|| strings[rng.random_range(0..cardinality)].as_str()) }, ))) } @@ -113,10 +114,10 @@ fn generate_strings( len: usize, valid_percent: f64, ) -> ArrayRef { - let mut rng = thread_rng(); + let mut rng = rng(); Arc::new(StringArray::from_iter((0..len).map(|_| { - rng.gen_bool(valid_percent).then(|| { - let string_len = rng.gen_range(string_length_range.clone()); + rng.random_bool(valid_percent).then(|| { + let string_len = rng.random_range(string_length_range.clone()); (0..string_len) .map(|_| char::from(rng.sample(Alphanumeric))) .collect::() @@ -133,10 +134,10 @@ where T: ArrowPrimitiveType, T::Native: SampleUniform, { - let mut rng = thread_rng(); + let mut rng = rng(); Arc::new(PrimitiveArray::::from_iter((0..len).map(|_| { - rng.gen_bool(valid_percent) - .then(|| rng.gen_range(range.clone())) + rng.random_bool(valid_percent) + .then(|| rng.random_range(range.clone())) }))) } diff --git a/datafusion/core/benches/push_down_filter.rs b/datafusion/core/benches/push_down_filter.rs new file mode 100644 index 000000000000..139fb12c3094 --- /dev/null +++ b/datafusion/core/benches/push_down_filter.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, Schema}; +use bytes::{BufMut, BytesMut}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::config::ConfigOptions; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::ExecutionPlan; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::ObjectStore; +use parquet::arrow::ArrowWriter; +use std::sync::Arc; + +async fn create_plan() -> Arc { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + Field::new("age", DataType::UInt16, true), + Field::new("salary", DataType::Float64, true), + ])); + let batch = RecordBatch::new_empty(schema); + + let store = Arc::new(InMemory::new()) as Arc; + let mut out = BytesMut::new().writer(); + { + let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + let data = out.into_inner().freeze(); + store + .put(&Path::from("test.parquet"), data.into()) + .await + .unwrap(); + ctx.register_object_store( + ObjectStoreUrl::parse("memory://").unwrap().as_ref(), + store, + ); + + ctx.register_parquet("t", "memory:///", ParquetReadOptions::default()) + .await + .unwrap(); + + let df = ctx + .sql( + r" + WITH brackets AS ( + SELECT age % 10 AS age_bracket + FROM t + GROUP BY age % 10 + HAVING COUNT(*) > 10 + ) + SELECT id, name, age, salary + FROM t + JOIN brackets ON t.age % 10 = brackets.age_bracket + WHERE age > 20 AND t.salary > 1000 + ORDER BY t.salary DESC + LIMIT 100 + ", + ) + .await + .unwrap(); + + df.create_physical_plan().await.unwrap() +} + +#[derive(Clone)] +struct BenchmarkPlan { + plan: Arc, + config: ConfigOptions, +} + +impl std::fmt::Display for BenchmarkPlan { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "BenchmarkPlan") + } +} + +fn bench_push_down_filter(c: &mut Criterion) { + // Create a relatively complex plan + let plan = tokio::runtime::Runtime::new() + .unwrap() + .block_on(create_plan()); + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + let plan = BenchmarkPlan { plan, config }; + let optimizer = FilterPushdown::new(); + + c.bench_function("push_down_filter", |b| { + b.iter(|| { + optimizer + .optimize(Arc::clone(&plan.plan), &plan.config) + .unwrap(); + }); + }); +} + +// It's a bit absurd that it's this complicated but to generate a flamegraph you can run: +// `cargo flamegraph -p datafusion --bench push_down_filter --flamechart --root --profile profiling --freq 1000 -- --bench` +// See https://github.com/flamegraph-rs/flamegraph +criterion_group!(benches, bench_push_down_filter); +criterion_main!(benches); diff --git a/datafusion/core/benches/sort.rs b/datafusion/core/benches/sort.rs index 85f456ce5dc2..e1bc478b36f0 100644 --- a/datafusion/core/benches/sort.rs +++ b/datafusion/core/benches/sort.rs @@ -595,7 +595,7 @@ impl DataGenerator { /// Create an array of i64 sorted values (where approximately 1/3 values is repeated) fn i64_values(&mut self) -> Vec { let mut vec: Vec<_> = (0..INPUT_SIZE) - .map(|_| self.rng.gen_range(0..INPUT_SIZE as i64)) + .map(|_| self.rng.random_range(0..INPUT_SIZE as i64)) 
.collect(); vec.sort_unstable(); @@ -620,7 +620,7 @@ impl DataGenerator { // pick from the 100 strings randomly let mut input = (0..INPUT_SIZE) .map(|_| { - let idx = self.rng.gen_range(0..strings.len()); + let idx = self.rng.random_range(0..strings.len()); let s = Arc::clone(&strings[idx]); Some(s) }) @@ -643,7 +643,7 @@ impl DataGenerator { fn random_string(&mut self) -> String { let rng = &mut self.rng; - rng.sample_iter(rand::distributions::Alphanumeric) + rng.sample_iter(rand::distr::Alphanumeric) .filter(|c| c.is_ascii_alphabetic()) .take(20) .map(char::from) @@ -665,7 +665,7 @@ where let mut outputs: Vec>> = (0..NUM_STREAMS).map(|_| Vec::new()).collect(); for i in input { - let stream_idx = rng.gen_range(0..NUM_STREAMS); + let stream_idx = rng.random_range(0..NUM_STREAMS); let stream = &mut outputs[stream_idx]; match stream.last_mut() { Some(x) if x.len() < BATCH_SIZE => x.push(i), diff --git a/datafusion/core/benches/spm.rs b/datafusion/core/benches/spm.rs index 63b06f20cd86..8613525cb248 100644 --- a/datafusion/core/benches/spm.rs +++ b/datafusion/core/benches/spm.rs @@ -125,8 +125,7 @@ fn criterion_benchmark(c: &mut Criterion) { for &batch_count in &batch_counts { for &partition_count in &partition_counts { let description = format!( - "{}_batch_count_{}_partition_count_{}", - cardinality_label, batch_count, partition_count + "{cardinality_label}_batch_count_{batch_count}_partition_count_{partition_count}" ); run_bench( c, diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 49cc830d58bc..6dc953f56b43 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -136,10 +136,10 @@ fn benchmark_with_param_values_many_columns( if i > 0 { aggregates.push_str(", "); } - aggregates.push_str(format!("MAX(a{})", i).as_str()); + aggregates.push_str(format!("MAX(a{i})").as_str()); } // SELECT max(attr0), ..., max(attrN) FROM t1. 
- let query = format!("SELECT {} FROM t1", aggregates); + let query = format!("SELECT {aggregates} FROM t1"); let statement = ctx.state().sql_to_statement(&query, "Generic").unwrap(); let plan = rt.block_on(async { ctx.state().statement_to_plan(statement).await.unwrap() }); @@ -164,7 +164,7 @@ fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows .map(|j| j as u64 * 100 + i) .collect::>(), )); - (format!("c{}", i), array) + (format!("c{i}"), array) }); let batch = RecordBatch::try_from_iter(iter).unwrap(); let schema = batch.schema(); @@ -172,7 +172,7 @@ fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows // tell DataFusion that the table is sorted by all columns let sort_order = (0..num_columns) - .map(|i| col(format!("c{}", i)).sort(true, true)) + .map(|i| col(format!("c{i}")).sort(true, true)) .collect::>(); // create the table @@ -208,12 +208,12 @@ fn union_orderby_query(n: usize) -> String { }) .collect::>() .join(", "); - query.push_str(&format!("(SELECT {} FROM t ORDER BY c{})", select_list, i)); + query.push_str(&format!("(SELECT {select_list} FROM t ORDER BY c{i})")); } query.push_str(&format!( "\nORDER BY {}", (0..n) - .map(|i| format!("c{}", i)) + .map(|i| format!("c{i}")) .collect::>() .join(", ") )); @@ -293,9 +293,9 @@ fn criterion_benchmark(c: &mut Criterion) { if i > 0 { aggregates.push_str(", "); } - aggregates.push_str(format!("MAX(a{})", i).as_str()); + aggregates.push_str(format!("MAX(a{i})").as_str()); } - let query = format!("SELECT {} FROM t1", aggregates); + let query = format!("SELECT {aggregates} FROM t1"); b.iter(|| { physical_plan(&ctx, &rt, &query); }); @@ -402,7 +402,7 @@ fn criterion_benchmark(c: &mut Criterion) { for q in tpch_queries { let sql = std::fs::read_to_string(format!("{benchmarks_path}queries/{q}.sql")).unwrap(); - c.bench_function(&format!("physical_plan_tpch_{}", q), |b| { + c.bench_function(&format!("physical_plan_tpch_{q}"), |b| { b.iter(|| physical_plan(&tpch_ctx, &rt, &sql)) }); } diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 58d71ee5b2eb..58797dfed6b6 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -66,7 +66,7 @@ fn create_parquet_file(rng: &mut StdRng, id_offset: usize) -> Bytes { let mut payload_builder = Int64Builder::new(); for row in 0..FILE_ROWS { id_builder.append_value((row + id_offset) as u64); - payload_builder.append_value(rng.gen()); + payload_builder.append_value(rng.random()); } let batch = RecordBatch::try_new( Arc::clone(&schema), diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 7afb90282a80..1044717aaffb 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -46,7 +46,7 @@ fn main() -> Result<()> { "scalar" => print_scalar_docs(), "window" => print_window_docs(), _ => { - panic!("Unknown function type: {}", function_type) + panic!("Unknown function type: {function_type}") } }?; @@ -92,7 +92,7 @@ fn print_window_docs() -> Result { fn save_doc_code_text(documentation: &Documentation, name: &str) { let attr_text = documentation.to_doc_attribute(); - let file_path = format!("{}.txt", name); + let file_path = format!("{name}.txt"); if std::path::Path::new(&file_path).exists() { std::fs::remove_file(&file_path).unwrap(); } @@ -215,16 +215,15 @@ fn print_docs( r#" #### Example -{} -"#, - example +{example} +"# ); } if 
let Some(alt_syntax) = &documentation.alternative_syntax { let _ = writeln!(docs, "#### Alternative Syntax\n"); for syntax in alt_syntax { - let _ = writeln!(docs, "```sql\n{}\n```", syntax); + let _ = writeln!(docs, "```sql\n{syntax}\n```"); } } diff --git a/datafusion/core/src/bin/print_runtime_config_docs.rs b/datafusion/core/src/bin/print_runtime_config_docs.rs new file mode 100644 index 000000000000..31425da73d35 --- /dev/null +++ b/datafusion/core/src/bin/print_runtime_config_docs.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_execution::runtime_env::RuntimeEnvBuilder; + +fn main() { + let docs = RuntimeEnvBuilder::generate_config_markdown(); + println!("{docs}"); +} diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 9a70f8f43fb6..02a18f22c916 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -33,8 +33,8 @@ use crate::execution::context::{SessionState, TaskContext}; use crate::execution::FunctionRegistry; use crate::logical_expr::utils::find_window_exprs; use crate::logical_expr::{ - col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions, - Partitioning, TableType, + col, ident, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, + LogicalPlanBuilderOptions, Partitioning, TableType, }; use crate::physical_plan::{ collect, collect_partitioned, execute_stream, execute_stream_partitioned, @@ -166,9 +166,12 @@ impl Default for DataFrameWriteOptions { /// /// # Example /// ``` +/// # use std::sync::Arc; /// # use datafusion::prelude::*; /// # use datafusion::error::Result; /// # use datafusion::functions_aggregate::expr_fn::min; +/// # use datafusion::arrow::array::{Int32Array, RecordBatch, StringArray}; +/// # use datafusion::arrow::datatypes::{DataType, Field, Schema}; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); @@ -181,6 +184,28 @@ impl Default for DataFrameWriteOptions { /// .limit(0, Some(100))?; /// // Perform the actual computation /// let results = df.collect(); +/// +/// // Create a new dataframe with in-memory data +/// let schema = Schema::new(vec![ +/// Field::new("id", DataType::Int32, true), +/// Field::new("name", DataType::Utf8, true), +/// ]); +/// let batch = RecordBatch::try_new( +/// Arc::new(schema), +/// vec![ +/// Arc::new(Int32Array::from(vec![1, 2, 3])), +/// Arc::new(StringArray::from(vec!["foo", "bar", "baz"])), +/// ], +/// )?; +/// let df = ctx.read_batch(batch)?; +/// df.show().await?; +/// +/// // Create a new dataframe with in-memory data using macro +/// let df = dataframe!( +/// "id" => [1, 2, 3], +/// "name" => ["foo", "bar", "baz"] +/// )?; +/// df.show().await?; /// # Ok(()) /// # } /// ``` @@ -934,7 
+959,7 @@ impl DataFrame { vec![], original_schema_fields .clone() - .map(|f| count(col(f.name())).alias(f.name())) + .map(|f| count(ident(f.name())).alias(f.name())) .collect::>(), ), // null_count aggregation @@ -943,7 +968,7 @@ impl DataFrame { original_schema_fields .clone() .map(|f| { - sum(case(is_null(col(f.name()))) + sum(case(is_null(ident(f.name()))) .when(lit(true), lit(1)) .otherwise(lit(0)) .unwrap()) @@ -957,7 +982,7 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| f.data_type().is_numeric()) - .map(|f| avg(col(f.name())).alias(f.name())) + .map(|f| avg(ident(f.name())).alias(f.name())) .collect::>(), ), // std aggregation @@ -966,7 +991,7 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| f.data_type().is_numeric()) - .map(|f| stddev(col(f.name())).alias(f.name())) + .map(|f| stddev(ident(f.name())).alias(f.name())) .collect::>(), ), // min aggregation @@ -977,7 +1002,7 @@ impl DataFrame { .filter(|f| { !matches!(f.data_type(), DataType::Binary | DataType::Boolean) }) - .map(|f| min(col(f.name())).alias(f.name())) + .map(|f| min(ident(f.name())).alias(f.name())) .collect::>(), ), // max aggregation @@ -988,7 +1013,7 @@ impl DataFrame { .filter(|f| { !matches!(f.data_type(), DataType::Binary | DataType::Boolean) }) - .map(|f| max(col(f.name())).alias(f.name())) + .map(|f| max(ident(f.name())).alias(f.name())) .collect::>(), ), // median aggregation @@ -997,7 +1022,7 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| f.data_type().is_numeric()) - .map(|f| median(col(f.name())).alias(f.name())) + .map(|f| median(ident(f.name())).alias(f.name())) .collect::>(), ), ]; @@ -1312,7 +1337,10 @@ impl DataFrame { /// ``` pub async fn count(self) -> Result { let rows = self - .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])? + .aggregate( + vec![], + vec![count(Expr::Literal(COUNT_STAR_EXPANSION, None))], + )? .collect() .await?; let len = *rows @@ -1366,8 +1394,47 @@ impl DataFrame { /// # } /// ``` pub async fn show(self) -> Result<()> { + println!("{}", self.to_string().await?); + Ok(()) + } + + /// Execute the `DataFrame` and return a string representation of the results. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion::execution::SessionStateBuilder; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let cfg = SessionConfig::new() + /// .set_str("datafusion.format.null", "no-value"); + /// let session_state = SessionStateBuilder::new() + /// .with_config(cfg) + /// .with_default_features() + /// .build(); + /// let ctx = SessionContext::new_with_state(session_state); + /// let df = ctx.sql("select null as 'null-column'").await?; + /// let result = df.to_string().await?; + /// assert_eq!(result, + /// "+-------------+ + /// | null-column | + /// +-------------+ + /// | no-value | + /// +-------------+" + /// ); + /// # Ok(()) + /// # } + pub async fn to_string(self) -> Result { + let options = self.session_state.config().options().format.clone(); + let arrow_options: arrow::util::display::FormatOptions = (&options).try_into()?; + let results = self.collect().await?; - Ok(pretty::print_batches(&results)?) + Ok( + pretty::pretty_format_batches_with_options(&results, &arrow_options)? + .to_string(), + ) } /// Execute the `DataFrame` and print only the first `num` rows of the @@ -2160,6 +2227,94 @@ impl DataFrame { }) .collect() } + + /// Helper for creating DataFrame. 
+ /// # Example + /// ``` + /// use std::sync::Arc; + /// use arrow::array::{ArrayRef, Int32Array, StringArray}; + /// use datafusion::prelude::DataFrame; + /// let id: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + /// let name: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "baz"])); + /// let df = DataFrame::from_columns(vec![("id", id), ("name", name)]).unwrap(); + /// // +----+------+, + /// // | id | name |, + /// // +----+------+, + /// // | 1 | foo |, + /// // | 2 | bar |, + /// // | 3 | baz |, + /// // +----+------+, + /// ``` + pub fn from_columns(columns: Vec<(&str, ArrayRef)>) -> Result { + let fields = columns + .iter() + .map(|(name, array)| Field::new(*name, array.data_type().clone(), true)) + .collect::>(); + + let arrays = columns + .into_iter() + .map(|(_, array)| array) + .collect::>(); + + let schema = Arc::new(Schema::new(fields)); + let batch = RecordBatch::try_new(schema, arrays)?; + let ctx = SessionContext::new(); + let df = ctx.read_batch(batch)?; + Ok(df) + } +} + +/// Macro for creating DataFrame. +/// # Example +/// ``` +/// use datafusion::prelude::dataframe; +/// # use datafusion::error::Result; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// let df = dataframe!( +/// "id" => [1, 2, 3], +/// "name" => ["foo", "bar", "baz"] +/// )?; +/// df.show().await?; +/// // +----+------+, +/// // | id | name |, +/// // +----+------+, +/// // | 1 | foo |, +/// // | 2 | bar |, +/// // | 3 | baz |, +/// // +----+------+, +/// let df_empty = dataframe!()?; // empty DataFrame +/// assert_eq!(df_empty.schema().fields().len(), 0); +/// assert_eq!(df_empty.count().await?, 0); +/// # Ok(()) +/// # } +/// ``` +#[macro_export] +macro_rules! dataframe { + () => {{ + use std::sync::Arc; + + use datafusion::prelude::SessionContext; + use datafusion::arrow::array::RecordBatch; + use datafusion::arrow::datatypes::Schema; + + let ctx = SessionContext::new(); + let batch = RecordBatch::new_empty(Arc::new(Schema::empty())); + ctx.read_batch(batch) + }}; + + ($($name:expr => $data:expr),+ $(,)?) 
=> {{ + use datafusion::prelude::DataFrame; + use datafusion::common::test_util::IntoArrayRef; + + let columns = vec![ + $( + ($name, $data.into_array_ref()), + )+ + ]; + + DataFrame::from_columns(columns) + }}; } #[derive(Debug)] diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 7fc27453d1ad..fc63591c5b69 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -27,7 +27,7 @@ use std::sync::Arc; use super::file_compression_type::FileCompressionType; use super::write::demux::DemuxedStreamReceiver; -use super::write::{create_writer, SharedBuffer}; +use super::write::SharedBuffer; use super::FileFormatFactory; use crate::datasource::file_format::write::get_writer_schema; use crate::datasource::file_format::FileFormat; @@ -51,9 +51,9 @@ use datafusion_datasource::display::FileGroupDisplay; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::sink::{DataSink, DataSinkExec}; +use datafusion_datasource::write::ObjectWriterBuilder; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; use async_trait::async_trait; @@ -63,6 +63,7 @@ use futures::stream::BoxStream; use futures::StreamExt; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; use tokio::io::AsyncWriteExt; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// Initial writing buffer size. Note this is just a size hint for efficiency. It /// will grow beyond the set value if needed. 
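The hunk below swaps the Arrow file sink from the old `create_writer` helper to `ObjectWriterBuilder`, so the object-store write buffer size can be taken from the session's `execution.objectstore_writer_buffer_size` option. A minimal sketch of that builder call shape, assuming the built writer behaves like the async writer the old helper returned; the in-memory store, path, and 10 MB buffer size are placeholders for illustration:

```rust
use std::sync::Arc;

use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion_datasource::write::ObjectWriterBuilder;
use object_store::{memory::InMemory, path::Path, ObjectStore};
use tokio::io::AsyncWriteExt;

// Sketch: build an object-store writer with an explicit buffer size, mirroring the
// `ObjectWriterBuilder` usage introduced for `ArrowFileSink` in the hunk below.
async fn write_bytes(data: &[u8]) -> datafusion::error::Result<()> {
    let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
    let path = Path::from("example.arrow");
    let mut writer = ObjectWriterBuilder::new(
        FileCompressionType::UNCOMPRESSED, // Arrow IPC bytes are written uncompressed
        &path,
        store,
    )
    // In the sink this value comes from the session config; 10 MB is a placeholder.
    .with_buffer_size(Some(10 * 1024 * 1024))
    .build()?;
    writer.write_all(data).await?;
    writer.shutdown().await?;
    Ok(())
}
```

The builder form lets the sink thread a session-level buffer size through to the writer without reintroducing a separate helper function.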
@@ -223,7 +224,7 @@ impl FileSink for ArrowFileSink { async fn spawn_writer_tasks_and_join( &self, - _context: &Arc, + context: &Arc, demux_task: SpawnedTask>, mut file_stream_rx: DemuxedStreamReceiver, object_store: Arc, @@ -241,12 +242,19 @@ impl FileSink for ArrowFileSink { &get_writer_schema(&self.config), ipc_options.clone(), )?; - let mut object_store_writer = create_writer( + let mut object_store_writer = ObjectWriterBuilder::new( FileCompressionType::UNCOMPRESSED, &path, Arc::clone(&object_store), ) - .await?; + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; file_write_tasks.spawn(async move { let mut row_count = 0; while let Some(batch) = rx.recv().await { diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 323bc28057d4..efec07abbca0 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -217,8 +217,11 @@ mod tests { assert_eq!(tt_batches, 50 /* 100/2 */); // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Absent); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); + assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); Ok(()) } @@ -1023,7 +1026,7 @@ mod tests { for _ in 0..batch_count { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])?; } @@ -1061,7 +1064,7 @@ mod tests { for _ in 0..batch_count { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])?; } @@ -1142,18 +1145,14 @@ mod tests { fn csv_line(line_number: usize) -> Bytes { let (int_value, float_value, bool_value, char_value) = csv_values(line_number); - format!( - "{},{},{},{}\n", - int_value, float_value, bool_value, char_value - ) - .into() + format!("{int_value},{float_value},{bool_value},{char_value}\n").into() } fn csv_values(line_number: usize) -> (i32, f64, bool, String) { let int_value = line_number as i32; let float_value = line_number as f64; let bool_value = line_number % 2 == 0; - let char_value = format!("{}-string", line_number); + let char_value = format!("{line_number}-string"); (int_value, float_value, bool_value, char_value) } diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index a70a0f51d330..34d3d64f07fb 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -75,8 +75,11 @@ mod tests { assert_eq!(tt_batches, 6 /* 12/2 */); // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Absent); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); + assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); Ok(()) } @@ -275,7 +278,7 @@ mod tests { for _ in 0..3 { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - 
panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])? } @@ -315,7 +318,7 @@ mod tests { for _ in 0..2 { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])? } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 3a098301f14e..015a9512a968 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -93,7 +93,7 @@ pub(crate) mod test_util { .with_projection(projection) .with_limit(limit) .build(), - None, + None ) .await?; Ok(exec) diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 08e9a628dd61..9aaf1cf59811 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -550,7 +550,7 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) } @@ -585,9 +585,9 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) + .with_session_config_options(config) } async fn get_resolved_schema( @@ -615,7 +615,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) } @@ -643,7 +643,7 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) } @@ -669,7 +669,7 @@ impl ReadOptions<'_> for ArrowReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 7b8b99273f4e..6a5c19829c1c 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -27,7 +27,10 @@ pub(crate) mod test_util { use crate::test::object_store::local_unpartitioned_file; - /// Writes `batches` to a temporary parquet file + /// Writes each `batch` to at least one temporary parquet file + /// + /// For example, if `batches` contains 2 batches, the function will create + /// 2 temporary files, each containing the contents of one batch 
/// /// If multi_page is set to `true`, the parquet file(s) are written /// with 2 rows per data page (used to test page filtering and @@ -52,7 +55,7 @@ pub(crate) mod test_util { } } - // we need the tmp files to be sorted as some tests rely on the how the returning files are ordered + // we need the tmp files to be sorted as some tests rely on the returned file ordering // https://github.com/apache/datafusion/pull/6629 let tmp_files = { let mut tmp_files: Vec<_> = (0..batches.len()) @@ -104,10 +107,8 @@ pub(crate) mod test_util { mod tests { use std::fmt::{self, Display, Formatter}; - use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use std::task::{Context, Poll}; use std::time::Duration; use crate::datasource::file_format::parquet::test_util::store_parquet; @@ -117,7 +118,7 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use arrow::array::RecordBatch; - use arrow_schema::{Schema, SchemaRef}; + use arrow_schema::Schema; use datafusion_catalog::Session; use datafusion_common::cast::{ as_binary_array, as_binary_view_array, as_boolean_array, as_float32_array, @@ -137,7 +138,7 @@ mod tests { }; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; - use datafusion_execution::{RecordBatchStream, TaskContext}; + use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{collect, ExecutionPlan}; @@ -150,7 +151,7 @@ mod tests { use async_trait::async_trait; use datafusion_datasource::file_groups::FileGroup; use futures::stream::BoxStream; - use futures::{Stream, StreamExt}; + use futures::StreamExt; use insta::assert_snapshot; use log::error; use object_store::local::LocalFileSystem; @@ -166,6 +167,8 @@ mod tests { use parquet::format::FileMetaData; use tokio::fs::File; + use crate::test_util::bounded_stream; + enum ForceViews { Yes, No, @@ -616,9 +619,15 @@ mod tests { assert_eq!(tt_batches, 4 /* 8/2 */); // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!( + exec.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); Ok(()) } @@ -659,9 +668,15 @@ mod tests { get_exec(&state, "alltypes_plain.parquet", projection, Some(1)).await?; // note: even if the limit is set, the executor rounds up to the batch size - assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!( + exec.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); @@ -1073,7 +1088,7 @@ mod tests { let format = state .get_file_format_factory("parquet") .map(|factory| factory.create(state, &Default::default()).unwrap()) - .unwrap_or(Arc::new(ParquetFormat::new())); + .unwrap_or_else(|| Arc::new(ParquetFormat::new())); scan_format( state, &*format, None, &testdata, file_name, projection, limit, @@ -1308,7 +1323,7 @@ mod 
tests { #[tokio::test] async fn parquet_sink_write_with_extension() -> Result<()> { let filename = "test_file.custom_ext"; - let file_path = format!("file:///path/to/{}", filename); + let file_path = format!("file:///path/to/{filename}"); let parquet_sink = create_written_parquet_sink(file_path.as_str()).await?; // assert written to proper path @@ -1523,8 +1538,7 @@ mod tests { let prefix = path_parts[0].as_ref(); assert!( expected_partitions.contains(prefix), - "expected path prefix to match partition, instead found {:?}", - prefix + "expected path prefix to match partition, instead found {prefix:?}" ); expected_partitions.remove(prefix); @@ -1648,43 +1662,4 @@ mod tests { Ok(()) } - - /// Creates an bounded stream for testing purposes. - fn bounded_stream( - batch: RecordBatch, - limit: usize, - ) -> datafusion_execution::SendableRecordBatchStream { - Box::pin(BoundedStream { - count: 0, - limit, - batch, - }) - } - - struct BoundedStream { - limit: usize, - count: usize, - batch: RecordBatch, - } - - impl Stream for BoundedStream { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - if self.count >= self.limit { - return Poll::Ready(None); - } - self.count += 1; - Poll::Ready(Some(Ok(self.batch.clone()))) - } - } - - impl RecordBatchStream for BoundedStream { - fn schema(&self) -> SchemaRef { - self.batch.schema() - } - } } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index a490315f0d5d..a91480fda430 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -24,34 +24,31 @@ use std::{any::Any, str::FromStr, sync::Arc}; use crate::datasource::{ create_ordering, - file_format::{ - file_compression_type::FileCompressionType, FileFormat, FilePushdownSupport, - }, + file_format::{file_compression_type::FileCompressionType, FileFormat}, physical_plan::FileSinkConfig, }; use crate::execution::context::SessionState; use datafusion_catalog::TableProvider; use datafusion_catalog_listing::metadata::apply_metadata_filters; -use datafusion_common::{config_err, DataFusionError, Result}; +use datafusion_common::{config_err, DataFusionError, Result, ToDFSchema}; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::metadata::MetadataColumn; +use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory; +use datafusion_execution::config::SessionConfig; use datafusion_expr::dml::InsertOp; -use datafusion_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}; +use datafusion_expr::{Expr, TableProviderFilterPushDown}; use datafusion_expr::{SortExpr, TableType}; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::{ExecutionPlan, Statistics}; use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, SchemaRef}; use datafusion_common::{ - config_datafusion_err, internal_err, plan_err, project_schema, Constraints, - SchemaExt, ToDFSchema, + config_datafusion_err, internal_err, plan_err, project_schema, Constraints, SchemaExt, }; use datafusion_execution::cache::{ cache_manager::FileStatisticsCache, cache_unit::DefaultFileStatisticsCache, }; -use datafusion_physical_expr::{ - create_physical_expr, LexOrdering, PhysicalSortRequirement, -}; +use datafusion_physical_expr::{create_physical_expr, LexOrdering, PhysicalSortRequirement}; use async_trait::async_trait; use datafusion_catalog::Session; @@ -62,24 +59,28 @@ use 
datafusion_physical_expr_common::sort_expr::LexRequirement; use futures::{future, stream, Stream, StreamExt, TryStreamExt}; use itertools::Itertools; use object_store::ObjectStore; +use datafusion_expr::utils::conjunction; /// Configuration for creating a [`ListingTable`] +/// +/// #[derive(Debug, Clone)] pub struct ListingTableConfig { /// Paths on the `ObjectStore` for creating `ListingTable`. /// They should share the same schema and object store. pub table_paths: Vec, /// Optional `SchemaRef` for the to be created `ListingTable`. + /// + /// See details on [`ListingTableConfig::with_schema`] pub file_schema: Option, - /// Optional `ListingOptions` for the to be created `ListingTable`. + /// Optional [`ListingOptions`] for the to be created [`ListingTable`]. + /// + /// See details on [`ListingTableConfig::with_listing_options`] pub options: Option, } impl ListingTableConfig { - /// Creates new [`ListingTableConfig`]. - /// - /// The [`SchemaRef`] and [`ListingOptions`] are inferred based on - /// the suffix of the provided `table_paths` first element. + /// Creates new [`ListingTableConfig`] for reading the specified URL pub fn new(table_path: ListingTableUrl) -> Self { let table_paths = vec![table_path]; Self { @@ -91,8 +92,7 @@ impl ListingTableConfig { /// Creates new [`ListingTableConfig`] with multiple table paths. /// - /// The [`SchemaRef`] and [`ListingOptions`] are inferred based on - /// the suffix of the provided `table_paths` first element. + /// See [`Self::infer_options`] for details on what happens with multiple paths pub fn new_with_multi_paths(table_paths: Vec) -> Self { Self { table_paths, @@ -100,7 +100,16 @@ impl ListingTableConfig { options: None, } } - /// Add `schema` to [`ListingTableConfig`] + /// Set the `schema` for the overall [`ListingTable`] + /// + /// [`ListingTable`] will automatically coerce, when possible, the schema + /// for individual files to match this schema. + /// + /// If a schema is not provided, it is inferred using + /// [`Self::infer_schema`]. + /// + /// If the schema is provided, it must contain only the fields in the file + /// without the table partitioning columns. pub fn with_schema(self, schema: SchemaRef) -> Self { Self { table_paths: self.table_paths, @@ -110,6 +119,9 @@ impl ListingTableConfig { } /// Add `listing_options` to [`ListingTableConfig`] + /// + /// If not provided, format and other options are inferred via + /// [`Self::infer_options`]. pub fn with_listing_options(self, listing_options: ListingOptions) -> Self { Self { table_paths: self.table_paths, @@ -118,7 +130,7 @@ impl ListingTableConfig { } } - ///Returns a tupe of (file_extension, optional compression_extension) + /// Returns a tuple of `(file_extension, optional compression_extension)` /// /// For example a path ending with blah.test.csv.gz returns `("csv", Some("gz"))` /// For example a path ending with blah.test.csv returns `("csv", None)` @@ -140,7 +152,9 @@ impl ListingTableConfig { } } - /// Infer `ListingOptions` based on `table_path` suffix. + /// Infer `ListingOptions` based on `table_path` and file suffix. + /// + /// The format is inferred based on the first `table_path`. pub async fn infer_options(self, state: &dyn Session) -> Result { let store = if let Some(url) = self.table_paths.first() { state.runtime_env().object_store(url)? 
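The updated `ListingTableConfig` docs above describe the intended flow: parse a URL, attach `ListingOptions` (now synchronized with the session config via the new `with_session_config_options`), infer or supply a schema, then build the table. A minimal sketch of that flow, assuming the parquet feature is enabled; the local directory path is a placeholder:

```rust
use std::sync::Arc;

use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::{
    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// Sketch of config -> options -> schema -> table, using a placeholder path.
async fn build_listing_table(ctx: &SessionContext) -> Result<Arc<ListingTable>> {
    let table_path = ListingTableUrl::parse("file:///path/to/parquet/dir/")?;
    let state = ctx.state();

    // One-step alternative: let everything be inferred from the first table path.
    // let config = ListingTableConfig::new(table_path.clone()).infer(&state).await?;

    // Explicit form, as the updated docs describe: options first, then schema.
    let options = ListingOptions::new(Arc::new(ParquetFormat::default()))
        .with_session_config_options(state.config());
    let schema = options.infer_schema(&state, &table_path).await?;
    let config = ListingTableConfig::new(table_path)
        .with_listing_options(options)
        .with_schema(schema);

    Ok(Arc::new(ListingTable::try_new(config)?))
}
```

The commented-out `infer` call is the one-step path mentioned in the docs, which infers both the options and the schema from the first table path.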
@@ -185,7 +199,8 @@ impl ListingTableConfig { let listing_options = ListingOptions::new(file_format) .with_file_extension(listing_file_extension) - .with_target_partitions(state.config().target_partitions()); + .with_target_partitions(state.config().target_partitions()) + .with_collect_stat(state.config().collect_statistics()); Ok(Self { table_paths: self.table_paths, @@ -194,7 +209,13 @@ }) } - /// Infer the [`SchemaRef`] based on `table_path` suffix. Requires `self.options` to be set prior to using. + /// Infer the [`SchemaRef`] based on `table_path`s. + /// + /// This method infers the table schema using the first `table_path`. + /// See [`ListingOptions::infer_schema`] for more details + /// + /// # Errors + /// * if `self.options` is not set. See [`Self::with_listing_options`] pub async fn infer_schema(self, state: &dyn Session) -> Result<Self> { match self.options { Some(options) => { @@ -214,12 +235,15 @@ } } - /// Convenience wrapper for calling `infer_options` and `infer_schema` + /// Convenience method to call both [`Self::infer_options`] and [`Self::infer_schema`] pub async fn infer(self, state: &dyn Session) -> Result<Self> { self.infer_options(state).await?.infer_schema(state).await } - /// Infer the partition columns from the path. Requires `self.options` to be set prior to using. + /// Infer the partition columns from `table_paths`. + /// + /// # Errors + /// * if `self.options` is not set. See [`Self::with_listing_options`] pub async fn infer_partitions_from_path(self, state: &dyn Session) -> Result<Self> { match self.options { Some(options) => { @@ -279,6 +303,7 @@ pub struct ListingOptions { /// parquet metadata. /// /// See + /// /// NOTE: This attribute stores all equivalent orderings (the outer `Vec`) /// where each ordering consists of an individual lexicographic /// ordering (encapsulated by a `Vec`). If there aren't @@ -296,19 +321,30 @@ impl ListingOptions { /// - use default file extension filter /// - no input partition to discover /// - one target partition - /// - stat collection + /// - do not collect statistics pub fn new(format: Arc<dyn FileFormat>) -> Self { Self { file_extension: format.get_ext(), format, table_partition_cols: vec![], - collect_stat: true, + collect_stat: false, target_partitions: 1, file_sort_order: vec![], metadata_cols: vec![], } } + /// Set options from [`SessionConfig`] and returns self. + /// + /// Currently this sets `target_partitions` and `collect_stat` + /// but if more options are added in the future that need to be coordinated + /// they will be synchronized through this method. + pub fn with_session_config_options(mut self, config: &SessionConfig) -> Self { + self = self.with_target_partitions(config.target_partitions()); + self = self.with_collect_stat(config.collect_statistics()); + self + } + /// Set file extension on [`ListingOptions`] and returns self. /// /// # Example @@ -530,11 +566,13 @@ } /// Infer the schema of the files at the given path on the provided object store. - /// The inferred schema does not include the partitioning columns. /// - /// This method will not be called by the table itself but before creating it. - /// This way when creating the logical plan we can decide to resolve the schema - /// locally or ask a remote service to do it (e.g a scheduler). + /// If the table_path contains one or more files (i.e.
it is a directory / + /// prefix of files) their schema is merged by calling [`FileFormat::infer_schema`] + /// + /// Note: The inferred schema does not include any partitioning columns. + /// + /// This method is called as part of creating a [`ListingTable`]. pub async fn infer_schema<'a>( &'a self, state: &dyn Session, @@ -727,16 +765,14 @@ impl ListingOptions { /// `ListingTable` also supports limit, filter and projection pushdown for formats that /// support it as such as Parquet. /// -/// # Implementation +/// # See Also /// -/// `ListingTable` Uses [`DataSourceExec`] to execute the data. See that struct -/// for more details. +/// 1. [`ListingTableConfig`]: Configuration options +/// 1. [`DataSourceExec`]: `ExecutionPlan` used by `ListingTable` /// /// [`DataSourceExec`]: crate::datasource::source::DataSourceExec /// -/// # Example -/// -/// To read a directory of parquet files using a [`ListingTable`]: +/// # Example: Read a directory of parquet files using a [`ListingTable`] /// /// ```no_run /// # use datafusion::prelude::SessionContext; @@ -783,7 +819,7 @@ impl ListingOptions { /// # Ok(()) /// # } /// ``` -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ListingTable { table_paths: Vec, /// `file_schema` contains only the columns physically stored in the data files themselves. @@ -802,8 +838,7 @@ pub struct ListingTable { } impl ListingTable { - /// Create new [`ListingTable`] that lists the FS to get the files - /// to scan. See [`ListingTable`] for and example. + /// Create new [`ListingTable`] /// /// Takes a `ListingTableConfig` as input which requires an `ObjectStore` and `table_path`. /// `ListingOptions` and `SchemaRef` are optional. If they are not @@ -811,6 +846,7 @@ impl ListingTable { /// If the schema is provided then it must be resolved before creating the table /// and should contain the fields of the file without the extended table columns, /// i.e. the partitioning and metadata columns. + /// See documentation and example on [`ListingTable`] and [`ListingTableConfig`] pub fn try_new(config: ListingTableConfig) -> Result { let file_schema = config .file_schema @@ -875,7 +911,7 @@ impl ListingTable { /// If `None`, creates a new [`DefaultFileStatisticsCache`] scoped to this query. 
pub fn with_cache(mut self, cache: Option) -> Self { self.collected_statistics = - cache.unwrap_or(Arc::new(DefaultFileStatisticsCache::default())); + cache.unwrap_or_else(|| Arc::new(DefaultFileStatisticsCache::default())); self } @@ -1069,7 +1105,7 @@ impl TableProvider for ListingTable { .with_table_partition_cols(table_partition_cols) .with_metadata_cols(metadata_cols) .build(), - filters.as_ref(), + filters.as_ref() ) .await } @@ -1094,18 +1130,6 @@ impl TableProvider for ListingTable { return Ok(TableProviderFilterPushDown::Exact); } - // if we can't push it down completely with only the filename-based/path-based - // column names, then we should check if we can do parquet predicate pushdown - let supports_pushdown = self.options.format.supports_filters_pushdown( - &self.file_schema, - &self.table_schema, - &[filter], - )?; - - if supports_pushdown == FilePushdownSupport::Supported { - return Ok(TableProviderFilterPushDown::Exact); - } - Ok(TableProviderFilterPushDown::Inexact) }) .collect() @@ -1258,12 +1282,24 @@ impl ListingTable { get_files_with_limit(files, limit, self.options.collect_stat).await?; let file_groups = file_group.split_files(self.options.target_partitions); - compute_all_files_statistics( + let (mut file_groups, mut stats) = compute_all_files_statistics( file_groups, self.schema(), self.options.collect_stat, inexact_stats, - ) + )?; + let (schema_mapper, _) = DefaultSchemaAdapterFactory::from_schema(self.schema()) + .map_schema(self.file_schema.as_ref())?; + stats.column_statistics = + schema_mapper.map_column_statistics(&stats.column_statistics)?; + file_groups.iter_mut().try_for_each(|file_group| { + if let Some(stat) = file_group.statistics_mut() { + stat.column_statistics = + schema_mapper.map_column_statistics(&stat.column_statistics)?; + } + Ok::<_, DataFusionError>(()) + })?; + Ok((file_groups, stats)) } /// Collects statistics for a given partitioned file. 
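Because `ListingOptions::new` now defaults to not collecting statistics, code that depends on exact row counts has to opt back in, either per table with `with_collect_stat(true)` or session-wide, as the updated tests below exercise. A minimal sketch of the session-wide form, with a placeholder table path:

```rust
use datafusion::error::Result;
use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};

// Sketch: opt back into statistics collection at the session level. The per-table
// alternative is `ListingOptions::new(..).with_collect_stat(true)`, shown in the
// updated tests below.
async fn context_with_stats() -> Result<SessionContext> {
    let config = SessionConfig::new().with_collect_statistics(true);
    let ctx = SessionContext::new_with_config(config);
    // Placeholder path; any listing table registered on this context will have
    // file statistics collected when it is scanned.
    ctx.register_parquet("t", "/path/to/table/", ParquetReadOptions::default())
        .await?;
    Ok(ctx)
}
```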
@@ -1409,7 +1445,9 @@ mod tests { #[tokio::test] async fn read_single_file() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_collect_statistics(true), + ); let table = load_table(&ctx, "alltypes_plain.parquet").await?; let projection = None; @@ -1422,15 +1460,21 @@ mod tests { assert_eq!(exec.output_partitioning().partition_count(), 1); // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + assert_eq!( + exec.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); Ok(()) } #[cfg(feature = "parquet")] #[tokio::test] - async fn load_table_stats_by_default() -> Result<()> { + async fn do_not_load_table_stats_by_default() -> Result<()> { use crate::datasource::file_format::parquet::ParquetFormat; let testdata = crate::test_util::parquet_test_data(); @@ -1442,15 +1486,37 @@ mod tests { let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); let schema = opt.infer_schema(&state, &table_path).await?; + let config = ListingTableConfig::new(table_path.clone()) + .with_listing_options(opt) + .with_schema(schema); + let table = ListingTable::try_new(config)?; + + let exec = table.scan(&state, None, &[], None).await?; + assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent); + // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); + + let opt = ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(true); + let schema = opt.infer_schema(&state, &table_path).await?; let config = ListingTableConfig::new(table_path) .with_listing_options(opt) .with_schema(schema); let table = ListingTable::try_new(config)?; let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!( + exec.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); Ok(()) } @@ -1476,8 +1542,11 @@ mod tests { let table = ListingTable::try_new(config)?; let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.statistics()?.num_rows, Precision::Absent); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); + assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); Ok(()) } @@ -1817,7 +1886,7 @@ mod tests { // more files than meta_fetch_concurrency (32) let files: Vec = - (0..64).map(|i| format!("bucket/key1/file{}", i)).collect(); + (0..64).map(|i| format!("bucket/key1/file{i}")).collect(); // Collect references to each string let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); assert_list_files_for_exact_paths(file_refs.as_slice(), 5, 5, Some("")).await?; @@ -1994,7 +2063,7 @@ mod tests { let table_paths = files .iter() - .map(|t| ListingTableUrl::parse(format!("test:///{}", t)).unwrap()) + .map(|t| ListingTableUrl::parse(format!("test:///{t}")).unwrap()) .collect(); let config = 
ListingTableConfig::new_with_multi_paths(table_paths) .with_listing_options(opt) @@ -2315,7 +2384,7 @@ mod tests { let filter_predicate = Expr::BinaryExpr(BinaryExpr::new( Box::new(Expr::Column("column1".into())), Operator::GtEq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))), + Box::new(Expr::Literal(ScalarValue::Int32(Some(0)), None)), )); // Create a new batch of data to insert into the table @@ -2499,8 +2568,10 @@ mod tests { // create table let tmp_dir = TempDir::new()?; - let tmp_path = tmp_dir.into_path(); - let str_path = tmp_path.to_str().expect("Temp path should convert to &str"); + let str_path = tmp_dir + .path() + .to_str() + .expect("Temp path should convert to &str"); session_ctx .sql(&format!( "create external table foo(a varchar, b varchar, c int) \ @@ -2541,7 +2612,7 @@ mod tests { #[tokio::test] async fn test_infer_options_compressed_csv() -> Result<()> { let testdata = crate::test_util::arrow_test_data(); - let filename = format!("{}/csv/aggregate_test_100.csv.gz", testdata); + let filename = format!("{testdata}/csv/aggregate_test_100.csv.gz"); let table_path = ListingTableUrl::parse(filename).unwrap(); let ctx = SessionContext::new(); diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 636d1623c5e9..71686c61a8f7 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -111,9 +111,8 @@ impl TableProviderFactory for ListingTableFactory { let table_path = ListingTableUrl::parse(&cmd.location)?; let options = ListingOptions::new(file_format) - .with_collect_stat(state.config().collect_statistics()) .with_file_extension(file_extension) - .with_target_partitions(state.config().target_partitions()) + .with_session_config_options(session_state.config()) .with_table_partition_cols(table_partition_cols); options diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 25a89644cd2a..b3d69064ff15 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -51,28 +51,26 @@ pub use datafusion_physical_expr::create_ordering; #[cfg(all(test, feature = "parquet"))] mod tests { - use crate::prelude::SessionContext; - - use std::fs; - use std::sync::Arc; - - use arrow::array::{Int32Array, StringArray}; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use arrow::record_batch::RecordBatch; - use datafusion_common::test_util::batches_to_sort_string; - use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::schema_adapter::{ DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper, }; - use datafusion_datasource::PartitionedFile; - use datafusion_datasource_parquet::source::ParquetSource; - - use datafusion_common::record_batch; - use ::object_store::path::Path; - use ::object_store::ObjectMeta; - use datafusion_datasource::source::DataSourceExec; + use crate::prelude::SessionContext; + use arrow::{ + array::{Int32Array, StringArray}, + datatypes::{DataType, Field, Schema, SchemaRef}, + record_batch::RecordBatch, + }; + use datafusion_common::{record_batch, test_util::batches_to_sort_string}; + use datafusion_datasource::{ + file::FileSource, file_scan_config::FileScanConfigBuilder, + source::DataSourceExec, PartitionedFile, + }; + use datafusion_datasource_parquet::source::ParquetSource; + use datafusion_execution::object_store::ObjectStoreUrl; use 
datafusion_physical_plan::collect; + use object_store::{path::Path, ObjectMeta}; + use std::{fs, sync::Arc}; use tempfile::TempDir; #[tokio::test] @@ -81,7 +79,6 @@ mod tests { // record batches returned from parquet. This can be useful for schema evolution // where older files may not have all columns. - use datafusion_execution::object_store::ObjectStoreUrl; let tmp_dir = TempDir::new().unwrap(); let table_dir = tmp_dir.path().join("parquet_test"); fs::DirBuilder::new().create(table_dir.as_path()).unwrap(); @@ -124,10 +121,9 @@ mod tests { let f2 = Field::new("extra_column", DataType::Utf8, true); let schema = Arc::new(Schema::new(vec![f1.clone(), f2.clone()])); - let source = Arc::new( - ParquetSource::default() - .with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {})), - ); + let source = ParquetSource::default() + .with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {})) + .unwrap(); let base_conf = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), schema, @@ -264,5 +260,12 @@ mod tests { Ok(RecordBatch::try_new(schema, new_columns).unwrap()) } + + fn map_column_statistics( + &self, + _file_col_statistics: &[datafusion_common::ColumnStatistics], + ) -> datafusion_common::Result> { + unimplemented!() + } } } diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index f0a1f94d87e1..5728746e904b 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -15,196 +15,39 @@ // specific language governing permissions and limitations // under the License. -//! Execution plan for reading Arrow files - use std::any::Any; use std::sync::Arc; use crate::datasource::physical_plan::{FileMeta, FileOpenFuture, FileOpener}; use crate::error::Result; +use datafusion_datasource::as_file_source; +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use arrow::buffer::Buffer; use arrow::datatypes::SchemaRef; use arrow_ipc::reader::FileDecoder; -use datafusion_common::config::ConfigOptions; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::Statistics; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::source::DataSourceExec; -use datafusion_datasource_json::source::JsonSource; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, -}; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; -use datafusion_datasource::file_groups::FileGroup; use futures::StreamExt; use itertools::Itertools; use object_store::{GetOptions, GetRange, GetResultPayload, ObjectStore}; -/// Execution plan for scanning Arrow data source -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct ArrowExec { - inner: DataSourceExec, - base_config: FileScanConfig, -} - -#[allow(unused, deprecated)] -impl ArrowExec { - /// Create a new Arrow reader execution plan provided base configurations - pub fn new(base_config: FileScanConfig) -> Self { - let ( - projected_schema, - 
projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = Self::compute_properties( - Arc::clone(&projected_schema), - &projected_output_ordering, - projected_constraints, - &base_config, - ); - let arrow = ArrowSource::default(); - let base_config = base_config.with_source(Arc::new(arrow)); - Self { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - } - } - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn json_source(&self) -> JsonSource { - self.file_scan_config() - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - output_ordering: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = - EquivalenceProperties::new_with_orderings(schema, output_ordering) - .with_constraints(constraints); - - PlanProperties::new( - eq_properties, - Self::output_partitioning_helper(file_scan_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.base_config.file_groups = file_groups.clone(); - let mut file_source = self.file_scan_config(); - file_source = file_source.with_file_groups(file_groups); - self.inner = self.inner.with_data_source(Arc::new(file_source)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for ArrowExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for ArrowExec { - fn name(&self) -> &'static str { - "ArrowExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - fn children(&self) -> Vec<&Arc> { - Vec::new() - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - /// Redistribute files across partitions according to their size - /// See comments on `FileGroupPartitioner` for more detail. 
- fn repartitioned( - &self, - target_partitions: usize, - config: &ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - fn metrics(&self) -> Option { - self.inner.metrics() - } - fn statistics(&self) -> Result { - self.inner.statistics() - } - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - /// Arrow configuration struct that is given to DataSourceExec /// Does not hold anything special, since [`FileScanConfig`] is sufficient for arrow #[derive(Clone, Default)] pub struct ArrowSource { metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, +} + +impl From for Arc { + fn from(source: ArrowSource) -> Self { + as_file_source(source) + } } impl FileSource for ArrowSource { @@ -255,6 +98,20 @@ impl FileSource for ArrowSource { fn file_type(&self) -> &str { "arrow" } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } /// The struct arrow that implements `[FileOpener]` trait diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 736248fbd95d..0d45711c76fb 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -19,7 +19,6 @@ //! //! [`FileSource`]: datafusion_datasource::file::FileSource -#[allow(deprecated)] pub use datafusion_datasource_json::source::*; #[cfg(test)] diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index e3f237803b34..3f71b253d969 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -27,30 +27,18 @@ pub mod parquet; #[cfg(feature = "avro")] pub mod avro; -#[allow(deprecated)] #[cfg(feature = "avro")] -pub use avro::{AvroExec, AvroSource}; +pub use avro::AvroSource; #[cfg(feature = "parquet")] pub use datafusion_datasource_parquet::source::ParquetSource; #[cfg(feature = "parquet")] -#[allow(deprecated)] -pub use datafusion_datasource_parquet::{ - ParquetExec, ParquetExecBuilder, ParquetFileMetrics, ParquetFileReaderFactory, -}; +pub use datafusion_datasource_parquet::{ParquetFileMetrics, ParquetFileReaderFactory}; -#[allow(deprecated)] -pub use arrow_file::ArrowExec; pub use arrow_file::ArrowSource; -#[allow(deprecated)] -pub use json::NdJsonExec; - pub use json::{JsonOpener, JsonSource}; -#[allow(deprecated)] -pub use csv::{CsvExec, CsvExecBuilder}; - pub use csv::{CsvOpener, CsvSource}; pub use datafusion_datasource::file::FileSource; pub use datafusion_datasource::file_groups::FileGroup; diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index a44060b16999..61e44e5b45bb 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -44,7 +44,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; - use 
arrow_schema::SchemaRef; + use arrow_schema::{SchemaRef, TimeUnit}; use bytes::{BufMut, BytesMut}; use datafusion_common::config::TableParquetOptions; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; @@ -54,6 +54,7 @@ mod tests { use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; + use datafusion_datasource::file::FileSource; use datafusion_datasource::{FileRange, PartitionedFile}; use datafusion_datasource_parquet::source::ParquetSource; use datafusion_datasource_parquet::{ @@ -95,7 +96,11 @@ mod tests { #[derive(Debug, Default)] struct RoundTrip { projection: Option>, - schema: Option, + /// Optional logical table schema to use when reading the parquet files + /// + /// If None, the logical schema to use will be inferred from the + /// original data via [`Schema::try_merge`] + table_schema: Option, predicate: Option, pushdown_predicate: bool, page_index_predicate: bool, @@ -112,8 +117,11 @@ mod tests { self } - fn with_schema(mut self, schema: SchemaRef) -> Self { - self.schema = Some(schema); + /// Specify table schema. + /// + ///See [`Self::table_schema`] for more details + fn with_table_schema(mut self, schema: SchemaRef) -> Self { + self.table_schema = Some(schema); self } @@ -145,16 +153,16 @@ mod tests { self.round_trip(batches).await.batches } - fn build_file_source(&self, file_schema: SchemaRef) -> Arc { + fn build_file_source(&self, table_schema: SchemaRef) -> Arc { // set up predicate (this is normally done by a layer higher up) let predicate = self .predicate .as_ref() - .map(|p| logical2physical(p, &file_schema)); + .map(|p| logical2physical(p, &table_schema)); let mut source = ParquetSource::default(); if let Some(predicate) = predicate { - source = source.with_predicate(Arc::clone(&file_schema), predicate); + source = source.with_predicate(Arc::clone(&table_schema), predicate); } if self.pushdown_predicate { @@ -177,14 +185,14 @@ mod tests { source = source.with_bloom_filter_on_read(false); } - Arc::new(source) + source.with_schema(Arc::clone(&table_schema)) } fn build_parquet_exec( &self, file_schema: SchemaRef, file_group: FileGroup, - source: Arc, + source: Arc, ) -> Arc { let base_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), @@ -198,8 +206,14 @@ mod tests { } /// run the test, returning the `RoundTripResult` + /// + /// Each input batch is written into one or more parquet files (and thus + /// they could potentially have different schemas). The resulting + /// parquet files are then read back and filters are applied to the async fn round_trip(&self, batches: Vec) -> RoundTripResult { - let file_schema = match &self.schema { + // If table_schema is not set, we need to merge the schema of the + // input batches to get a unified schema. 
+ let table_schema = match &self.table_schema { Some(schema) => schema, None => &Arc::new( Schema::try_merge( @@ -208,7 +222,6 @@ mod tests { .unwrap(), ), }; - let file_schema = Arc::clone(file_schema); // If testing with page_index_predicate, write parquet // files with multiple pages let multi_page = self.page_index_predicate; @@ -216,9 +229,9 @@ mod tests { let file_group: FileGroup = meta.into_iter().map(Into::into).collect(); // build a ParquetExec to return the results - let parquet_source = self.build_file_source(file_schema.clone()); + let parquet_source = self.build_file_source(Arc::clone(table_schema)); let parquet_exec = self.build_parquet_exec( - file_schema.clone(), + Arc::clone(table_schema), file_group.clone(), Arc::clone(&parquet_source), ); @@ -228,9 +241,9 @@ mod tests { false, // use a new ParquetSource to avoid sharing execution metrics self.build_parquet_exec( - file_schema.clone(), + Arc::clone(table_schema), file_group.clone(), - self.build_file_source(file_schema.clone()), + self.build_file_source(Arc::clone(table_schema)), ), Arc::new(Schema::new(vec![ Field::new("plan_type", DataType::Utf8, true), @@ -303,7 +316,7 @@ mod tests { // Thus this predicate will come back as false. let filter = col("c2").eq(lit(1_i32)); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -322,7 +335,7 @@ mod tests { // If we excplicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -361,7 +374,7 @@ mod tests { // Thus this predicate will come back as false. let filter = col("c2").eq(lit("abc")); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -380,7 +393,7 @@ mod tests { // If we excplicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -423,7 +436,7 @@ mod tests { // Thus this predicate will come back as false. let filter = col("c2").eq(lit("abc")); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -442,7 +455,7 @@ mod tests { // If we excplicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -485,7 +498,7 @@ mod tests { // Thus this predicate will come back as false. 
let filter = col("c2").eq(lit("abc")); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -504,7 +517,7 @@ mod tests { // If we excplicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c3").eq(lit(7_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -552,7 +565,7 @@ mod tests { .and(col("c3").eq(lit(10_i32)).or(col("c2").is_null())); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -582,7 +595,7 @@ mod tests { .or(col("c3").gt(lit(20_i32)).and(col("c2").is_null())); let rt = RoundTrip::new() - .with_schema(table_schema) + .with_table_schema(table_schema) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch]) @@ -870,13 +883,15 @@ mod tests { Arc::new(StringViewArray::from(vec![Some("foo"), Some("bar")])); let batch = create_batch(vec![("c1", c1.clone())]); - let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); + // Table schema is Utf8 but file schema is StringView + let table_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); // Predicate should prune all row groups let filter = col("c1").eq(lit(ScalarValue::Utf8(Some("aaa".to_string())))); let rt = RoundTrip::new() .with_predicate(filter) - .with_schema(schema.clone()) + .with_table_schema(table_schema.clone()) .round_trip(vec![batch.clone()]) .await; // There should be no predicate evaluation errors @@ -889,7 +904,7 @@ mod tests { let filter = col("c1").eq(lit(ScalarValue::Utf8(Some("foo".to_string())))); let rt = RoundTrip::new() .with_predicate(filter) - .with_schema(schema) + .with_table_schema(table_schema) .round_trip(vec![batch]) .await; // There should be no predicate evaluation errors @@ -911,14 +926,14 @@ mod tests { let c1: ArrayRef = Arc::new(Int8Array::from(vec![Some(1), Some(2)])); let batch = create_batch(vec![("c1", c1.clone())]); - let schema = + let table_schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::UInt64, false)])); // Predicate should prune all row groups let filter = col("c1").eq(lit(ScalarValue::UInt64(Some(5)))); let rt = RoundTrip::new() .with_predicate(filter) - .with_schema(schema.clone()) + .with_table_schema(table_schema.clone()) .round_trip(vec![batch.clone()]) .await; // There should be no predicate evaluation errors @@ -930,7 +945,7 @@ mod tests { let filter = col("c1").eq(lit(ScalarValue::UInt64(Some(1)))); let rt = RoundTrip::new() .with_predicate(filter) - .with_schema(schema) + .with_table_schema(table_schema) .round_trip(vec![batch]) .await; // There should be no predicate evaluation errors @@ -1182,7 +1197,7 @@ mod tests { // batch2: c3(int8), c2(int64), c1(string), c4(string) let batch2 = create_batch(vec![("c3", c4), ("c2", c2), ("c1", c1)]); - let schema = Schema::new(vec![ + let table_schema = Schema::new(vec![ Field::new("c1", DataType::Utf8, true), Field::new("c2", DataType::Int64, true), Field::new("c3", DataType::Int8, true), @@ -1190,7 +1205,7 @@ mod tests { // read/write them files: let read = RoundTrip::new() - .with_schema(Arc::new(schema)) + .with_table_schema(Arc::new(table_schema)) 
.round_trip_to_batches(vec![batch1, batch2]) .await; assert_contains!(read.unwrap_err().to_string(), @@ -1326,6 +1341,124 @@ mod tests { Ok(()) } + #[tokio::test] + async fn parquet_exec_with_int96_nested() -> Result<()> { + // This test ensures that we maintain compatibility with coercing int96 to the desired + // resolution when they're within a nested type (e.g., struct, map, list). This file + // originates from a modified CometFuzzTestSuite ParquetGenerator to generate combinations + // of primitive and complex columns using int96. Other tests cover reading the data + // correctly with this coercion. Here we're only checking the coerced schema is correct. + let testdata = "../../datafusion/core/tests/data"; + let filename = "int96_nested.parquet"; + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + + let parquet_exec = scan_format( + &state, + &ParquetFormat::default().with_coerce_int96(Some("us".to_string())), + None, + testdata, + filename, + None, + None, + ) + .await + .unwrap(); + assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); + + let mut results = parquet_exec.execute(0, task_ctx.clone())?; + let batch = results.next().await.unwrap()?; + + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new_struct( + "c1", + vec![Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )], + true, + ), + Field::new_struct( + "c2", + vec![Field::new_list( + "c0", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + )], + true, + ), + Field::new_map( + "c3", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + Field::new_list( + "c4", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + ), + Field::new_list( + "c5", + Field::new_struct( + "element", + vec![Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )], + true, + ), + true, + ), + Field::new_list( + "c6", + Field::new_map( + "element", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + true, + ), + ])); + + assert_eq!(batch.schema(), expected_schema); + + Ok(()) + } + #[tokio::test] async fn parquet_exec_with_range() -> Result<()> { fn file_range(meta: &ObjectMeta, start: i64, end: i64) -> PartitionedFile { diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index fc110a0699df..dbe5c2c00f17 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -35,7 +35,11 @@ use crate::{ }, datasource::{provider_as_source, MemTable, ViewTable}, error::{DataFusionError, Result}, - execution::{options::ArrowReadOptions, runtime_env::RuntimeEnv, FunctionRegistry}, + execution::{ + options::ArrowReadOptions, + runtime_env::{RuntimeEnv, RuntimeEnvBuilder}, + FunctionRegistry, + }, logical_expr::AggregateUDF, logical_expr::ScalarUDF, logical_expr::{ @@ -1036,13 +1040,70 @@ impl SessionContext { variable, value, .. 
} = stmt; - let mut state = self.state.write(); - state.config_mut().options_mut().set(&variable, &value)?; - drop(state); + // Check if this is a runtime configuration + if variable.starts_with("datafusion.runtime.") { + self.set_runtime_variable(&variable, &value)?; + } else { + let mut state = self.state.write(); + state.config_mut().options_mut().set(&variable, &value)?; + drop(state); + } self.return_empty_dataframe() } + fn set_runtime_variable(&self, variable: &str, value: &str) -> Result<()> { + let key = variable.strip_prefix("datafusion.runtime.").unwrap(); + + match key { + "memory_limit" => { + let memory_limit = Self::parse_memory_limit(value)?; + + let mut state = self.state.write(); + let mut builder = + RuntimeEnvBuilder::from_runtime_env(state.runtime_env()); + builder = builder.with_memory_limit(memory_limit, 1.0); + *state = SessionStateBuilder::from(state.clone()) + .with_runtime_env(Arc::new(builder.build()?)) + .build(); + } + _ => { + return Err(DataFusionError::Plan(format!( + "Unknown runtime configuration: {variable}" + ))) + } + } + Ok(()) + } + + /// Parse memory limit from string to number of bytes + /// Supports formats like '1.5G', '100M', '512K' + /// + /// # Examples + /// ``` + /// use datafusion::execution::context::SessionContext; + /// + /// assert_eq!(SessionContext::parse_memory_limit("1M").unwrap(), 1024 * 1024); + /// assert_eq!(SessionContext::parse_memory_limit("1.5G").unwrap(), (1.5 * 1024.0 * 1024.0 * 1024.0) as usize); + /// ``` + pub fn parse_memory_limit(limit: &str) -> Result { + let (number, unit) = limit.split_at(limit.len() - 1); + let number: f64 = number.parse().map_err(|_| { + DataFusionError::Plan(format!( + "Failed to parse number from memory limit '{limit}'" + )) + })?; + + match unit { + "K" => Ok((number * 1024.0) as usize), + "M" => Ok((number * 1024.0 * 1024.0) as usize), + "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), + _ => Err(DataFusionError::Plan(format!( + "Unsupported unit '{unit}' in memory limit '{limit}'" + ))), + } + } + async fn create_custom_table( &self, cmd: &CreateExternalTable, @@ -1153,7 +1214,7 @@ impl SessionContext { let mut params: Vec = parameters .into_iter() .map(|e| match e { - Expr::Literal(scalar) => Ok(scalar), + Expr::Literal(scalar, _) => Ok(scalar), _ => not_impl_err!("Unsupported parameter type: {}", e), }) .collect::>()?; @@ -1647,7 +1708,7 @@ impl FunctionRegistry for SessionContext { } fn expr_planners(&self) -> Vec> { - self.state.read().expr_planners() + self.state.read().expr_planners().to_vec() } fn register_expr_planner( @@ -1833,7 +1894,6 @@ mod tests { use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; use arrow::datatypes::{DataType, TimeUnit}; - use std::env; use std::error::Error; use std::path::PathBuf; diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index 6ec9796fe90d..2fb763bee495 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -31,6 +31,21 @@ impl SessionContext { /// [`read_table`](Self::read_table) with a [`super::ListingTable`]. /// /// For an example, see [`read_csv`](Self::read_csv) + /// + /// # Note: Statistics + /// + /// NOTE: by default, statistics are collected when reading the Parquet + /// files This can slow down the initial DataFrame creation while + /// greatly accelerating queries with certain filters. 
+ /// + /// To disable statistics collection, set the [config option] + /// `datafusion.execution.collect_statistics` to `false`. See + /// [`ConfigOptions`] and [`ExecutionOptions::collect_statistics`] for more + /// details. + /// + /// [config option]: https://datafusion.apache.org/user-guide/configs.html + /// [`ConfigOptions`]: crate::config::ConfigOptions + /// [`ExecutionOptions::collect_statistics`]: crate::config::ExecutionOptions::collect_statistics pub async fn read_parquet( &self, table_paths: P, @@ -41,6 +56,13 @@ impl SessionContext { /// Registers a Parquet file as a table that can be referenced from SQL /// statements executed against this context. + /// + /// # Note: Statistics + /// + /// Statistics are not collected by default. See [`read_parquet`] for more + /// details and how to enable them. + /// + /// [`read_parquet`]: Self::read_parquet pub async fn register_parquet( &self, table_ref: impl Into, @@ -84,6 +106,8 @@ mod tests { use crate::parquet::basic::Compression; use crate::test_util::parquet_test_data; + use arrow::util::pretty::pretty_format_batches; + use datafusion_common::assert_contains; use datafusion_common::config::TableParquetOptions; use datafusion_execution::config::SessionConfig; @@ -129,6 +153,49 @@ mod tests { Ok(()) } + async fn explain_query_all_with_config(config: SessionConfig) -> Result { + let ctx = SessionContext::new_with_config(config); + + ctx.register_parquet( + "test", + &format!("{}/alltypes_plain*.parquet", parquet_test_data()), + ParquetReadOptions::default(), + ) + .await?; + let df = ctx.sql("EXPLAIN SELECT * FROM test").await?; + let results = df.collect().await?; + let content = pretty_format_batches(&results).unwrap().to_string(); + Ok(content) + } + + #[tokio::test] + async fn register_parquet_respects_collect_statistics_config() -> Result<()> { + // The default is true + let mut config = SessionConfig::new(); + config.options_mut().explain.physical_plan_only = true; + config.options_mut().explain.show_statistics = true; + let content = explain_query_all_with_config(config).await?; + assert_contains!(content, "statistics=[Rows=Exact("); + + // Explicitly set to true + let mut config = SessionConfig::new(); + config.options_mut().explain.physical_plan_only = true; + config.options_mut().explain.show_statistics = true; + config.options_mut().execution.collect_statistics = true; + let content = explain_query_all_with_config(config).await?; + assert_contains!(content, "statistics=[Rows=Exact("); + + // Explicitly set to false + let mut config = SessionConfig::new(); + config.options_mut().explain.physical_plan_only = true; + config.options_mut().explain.show_statistics = true; + config.options_mut().execution.collect_statistics = false; + let content = explain_query_all_with_config(config).await?; + assert_contains!(content, "statistics=[Rows=Absent,"); + + Ok(()) + } + #[tokio::test] async fn read_from_registered_table_with_glob_path() -> Result<()> { let ctx = SessionContext::new(); @@ -286,7 +353,7 @@ mod tests { let expected_path = binding[0].as_str(); assert_eq!( read_df.unwrap_err().strip_backtrace(), - format!("Execution error: File path '{}' does not match the expected extension '.parquet'", expected_path) + format!("Execution error: File path '{expected_path}' does not match the expected extension '.parquet'") ); // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension. 
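The statistics note and the `register_parquet_respects_collect_statistics_config` test above revolve around the `datafusion.execution.collect_statistics` option. As a hedged illustration only (not part of the patch), a caller could disable collection before registering a Parquet table roughly as follows; the `data/my_table.parquet` path and the `tokio` entry point are assumptions for the sketch:

```rust
use datafusion::error::Result;
use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    // Turn statistics collection off before building the context; EXPLAIN should
    // then report `statistics=[Rows=Absent, ...]` as asserted in the test above.
    let mut config = SessionConfig::new();
    config.options_mut().execution.collect_statistics = false;

    let ctx = SessionContext::new_with_config(config);
    ctx.register_parquet(
        "my_table",
        "data/my_table.parquet", // placeholder path for this sketch
        ParquetReadOptions::default(),
    )
    .await?;

    ctx.sql("EXPLAIN SELECT * FROM my_table").await?.show().await?;
    Ok(())
}
```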
diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 28f599304f8c..8aa812cc5258 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -552,6 +552,11 @@ impl SessionState { &self.optimizer } + /// Returns the [`ExprPlanner`]s for this session + pub fn expr_planners(&self) -> &[Arc] { + &self.expr_planners + } + /// Returns the [`QueryPlanner`] for this session pub fn query_planner(&self) -> &Arc { &self.query_planner @@ -1348,28 +1353,30 @@ impl SessionStateBuilder { } = self; let config = config.unwrap_or_default(); - let runtime_env = runtime_env.unwrap_or(Arc::new(RuntimeEnv::default())); + let runtime_env = runtime_env.unwrap_or_else(|| Arc::new(RuntimeEnv::default())); let mut state = SessionState { - session_id: session_id.unwrap_or(Uuid::new_v4().to_string()), + session_id: session_id.unwrap_or_else(|| Uuid::new_v4().to_string()), analyzer: analyzer.unwrap_or_default(), expr_planners: expr_planners.unwrap_or_default(), type_planner, optimizer: optimizer.unwrap_or_default(), physical_optimizers: physical_optimizers.unwrap_or_default(), - query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})), - catalog_list: catalog_list - .unwrap_or(Arc::new(MemoryCatalogProviderList::new()) - as Arc), + query_planner: query_planner + .unwrap_or_else(|| Arc::new(DefaultQueryPlanner {})), + catalog_list: catalog_list.unwrap_or_else(|| { + Arc::new(MemoryCatalogProviderList::new()) as Arc + }), table_functions: table_functions.unwrap_or_default(), scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), serializer_registry: serializer_registry - .unwrap_or(Arc::new(EmptySerializerRegistry)), + .unwrap_or_else(|| Arc::new(EmptySerializerRegistry)), file_formats: HashMap::new(), - table_options: table_options - .unwrap_or(TableOptions::default_from_session_config(config.options())), + table_options: table_options.unwrap_or_else(|| { + TableOptions::default_from_session_config(config.options()) + }), config, execution_props: execution_props.unwrap_or_default(), table_factories: table_factories.unwrap_or_default(), @@ -1635,7 +1642,7 @@ struct SessionContextProvider<'a> { impl ContextProvider for SessionContextProvider<'_> { fn get_expr_planners(&self) -> &[Arc] { - &self.state.expr_planners + self.state.expr_planners() } fn get_type_planner(&self) -> Option> { @@ -1751,7 +1758,7 @@ impl FunctionRegistry for SessionState { let result = self.scalar_functions.get(name); result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDF named \"{name}\" in the registry") + plan_datafusion_err!("There is no UDF named \"{name}\" in the registry. Use session context `register_udf` function to register a custom UDF") }) } @@ -1759,7 +1766,7 @@ impl FunctionRegistry for SessionState { let result = self.aggregate_functions.get(name); result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDAF named \"{name}\" in the registry") + plan_datafusion_err!("There is no UDAF named \"{name}\" in the registry. Use session context `register_udaf` function to register a custom UDAF") }) } @@ -1767,7 +1774,7 @@ impl FunctionRegistry for SessionState { let result = self.window_functions.get(name); result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry") + plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry. 
Use session context `register_udwf` function to register a custom UDWF") }) } @@ -1957,8 +1964,17 @@ pub(crate) struct PreparedPlan { #[cfg(test)] mod tests { use super::{SessionContextProvider, SessionStateBuilder}; + use crate::common::assert_contains; + use crate::config::ConfigOptions; + use crate::datasource::empty::EmptyTable; + use crate::datasource::provider_as_source; use crate::datasource::MemTable; use crate::execution::context::SessionState; + use crate::logical_expr::planner::ExprPlanner; + use crate::logical_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; + use crate::physical_plan::ExecutionPlan; + use crate::sql::planner::ContextProvider; + use crate::sql::{ResolvedTableReference, TableReference}; use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_catalog::MemoryCatalogProviderList; @@ -1968,6 +1984,7 @@ mod tests { use datafusion_expr::Expr; use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_optimizer::Optimizer; + use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use std::collections::HashMap; use std::sync::Arc; @@ -2125,4 +2142,148 @@ mod tests { Ok(()) } + + /// This test demonstrates why it's more convenient and somewhat necessary to provide + /// an `expr_planners` method for `SessionState`. + #[tokio::test] + async fn test_with_expr_planners() -> Result<()> { + // A helper method for planning count wildcard with or without expr planners. + async fn plan_count_wildcard( + with_expr_planners: bool, + ) -> Result> { + let mut context_provider = MyContextProvider::new().with_table( + "t", + provider_as_source(Arc::new(EmptyTable::new(Schema::empty().into()))), + ); + if with_expr_planners { + context_provider = context_provider.with_expr_planners(); + } + + let state = &context_provider.state; + let statement = state.sql_to_statement("select count(*) from t", "mysql")?; + let plan = SqlToRel::new(&context_provider).statement_to_plan(statement)?; + state.create_physical_plan(&plan).await + } + + // Planning count wildcard without expr planners should fail. + let got = plan_count_wildcard(false).await; + assert_contains!( + got.unwrap_err().to_string(), + "Physical plan does not support logical expression Wildcard" + ); + + // Planning count wildcard with expr planners should succeed. + let got = plan_count_wildcard(true).await?; + let displayable = DisplayableExecutionPlan::new(got.as_ref()); + assert_eq!( + displayable.indent(false).to_string(), + "ProjectionExec: expr=[0 as count(*)]\n PlaceholderRowExec\n" + ); + + Ok(()) + } + + /// A `ContextProvider` based on `SessionState`. + /// + /// Almost all planning context are retrieved from the `SessionState`. + struct MyContextProvider { + /// The session state. + state: SessionState, + /// Registered tables. + tables: HashMap>, + /// Controls whether to return expression planners when called `ContextProvider::expr_planners`. + return_expr_planners: bool, + } + + impl MyContextProvider { + /// Creates a new `SessionContextProvider`. + pub fn new() -> Self { + Self { + state: SessionStateBuilder::default() + .with_default_features() + .build(), + tables: HashMap::new(), + return_expr_planners: false, + } + } + + /// Registers a table. + /// + /// The catalog and schema are provided by default. 
+ pub fn with_table(mut self, table: &str, source: Arc) -> Self { + self.tables.insert( + ResolvedTableReference { + catalog: "default".to_string().into(), + schema: "public".to_string().into(), + table: table.to_string().into(), + }, + source, + ); + self + } + + /// Sets the `return_expr_planners` flag to true. + pub fn with_expr_planners(self) -> Self { + Self { + return_expr_planners: true, + ..self + } + } + } + + impl ContextProvider for MyContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + let resolved_table_ref = ResolvedTableReference { + catalog: "default".to_string().into(), + schema: "public".to_string().into(), + table: name.table().to_string().into(), + }; + let source = self.tables.get(&resolved_table_ref).cloned().unwrap(); + Ok(source) + } + + /// We use a `return_expr_planners` flag to demonstrate why it's necessary to + /// return the expression planners in the `SessionState`. + /// + /// Note, the default implementation returns an empty slice. + fn get_expr_planners(&self) -> &[Arc] { + if self.return_expr_planners { + self.state.expr_planners() + } else { + &[] + } + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.state.scalar_functions().get(name).cloned() + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions().get(name).cloned() + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option { + None + } + + fn options(&self) -> &ConfigOptions { + self.state.config_options() + } + + fn udf_names(&self) -> Vec { + self.state.scalar_functions().keys().cloned().collect() + } + + fn udaf_names(&self) -> Vec { + self.state.aggregate_functions().keys().cloned().collect() + } + + fn udwf_names(&self) -> Vec { + self.state.window_functions().keys().cloned().collect() + } + } } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index cc510bc81f1a..6956108e2df3 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -22,7 +22,18 @@ #![cfg_attr(docsrs, feature(doc_auto_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 -#![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] +// +// Eliminate unnecessary function calls(some may be not cheap) due to `xxx_or` +// for performance. Also avoid abusing `xxx_or_else` for readability: +// https://github.com/apache/datafusion/issues/15802 +#![cfg_attr( + not(test), + deny( + clippy::clone_on_ref_ptr, + clippy::or_fun_call, + clippy::unnecessary_lazy_evaluations + ) +)] #![warn(missing_docs, clippy::needless_borrow)] //! [DataFusion] is an extensible query engine written in Rust that @@ -300,14 +311,17 @@ //! ``` //! //! A [`TableProvider`] provides information for planning and -//! an [`ExecutionPlan`]s for execution. DataFusion includes [`ListingTable`] -//! which supports reading several common file formats, and you can support any -//! new file format by implementing the [`TableProvider`] trait. See also: +//! an [`ExecutionPlan`] for execution. DataFusion includes [`ListingTable`], +//! a [`TableProvider`] which reads individual files or directories of files +//! ("partitioned datasets") of the same file format. Users can add +//! support for new file formats by implementing the [`TableProvider`] +//! trait. //! -//! 1. [`ListingTable`]: Reads data from Parquet, JSON, CSV, or AVRO -//! files. 
Supports single files or multiple files with HIVE style -//! partitioning, optional compression, directly reading from remote -//! object store and more. +//! See also: +//! +//! 1. [`ListingTable`]: Reads data from one or more Parquet, JSON, CSV, or AVRO +//! files supporting HIVE style partitioning, optional compression, directly +//! reading from remote object store and more. //! //! 2. [`MemTable`]: Reads data from in memory [`RecordBatch`]es. //! @@ -326,11 +340,11 @@ //! A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! -//! `LogicalPlan`s can be rewritten with [`TreeNode`] API, see the +//! [`LogicalPlan`]s can be rewritten with [`TreeNode`] API, see the //! [`tree_node module`] for more details. //! //! [`Expr`]s can also be rewritten with [`TreeNode`] API and simplified using -//! [`ExprSimplifier`]. Examples of working with and executing `Expr`s can be +//! [`ExprSimplifier`]. Examples of working with and executing [`Expr`]s can be //! found in the [`expr_api`.rs] example //! //! [`TreeNode`]: datafusion_common::tree_node::TreeNode @@ -415,17 +429,17 @@ //! //! ## Streaming Execution //! -//! DataFusion is a "streaming" query engine which means `ExecutionPlan`s incrementally +//! DataFusion is a "streaming" query engine which means [`ExecutionPlan`]s incrementally //! read from their input(s) and compute output one [`RecordBatch`] at a time //! by continually polling [`SendableRecordBatchStream`]s. Output and -//! intermediate `RecordBatch`s each have approximately `batch_size` rows, +//! intermediate [`RecordBatch`]s each have approximately `batch_size` rows, //! which amortizes per-batch overhead of execution. //! //! Note that certain operations, sometimes called "pipeline breakers", //! (for example full sorts or hash aggregations) are fundamentally non streaming and //! must read their input fully before producing **any** output. As much as possible, //! other operators read a single [`RecordBatch`] from their input to produce a -//! single `RecordBatch` as output. +//! single [`RecordBatch`] as output. //! //! For example, given this SQL query: //! @@ -434,9 +448,9 @@ //! ``` //! //! The diagram below shows the call sequence when a consumer calls [`next()`] to -//! get the next `RecordBatch` of output. While it is possible that some +//! get the next [`RecordBatch`] of output. While it is possible that some //! steps run on different threads, typically tokio will use the same thread -//! that called `next()` to read from the input, apply the filter, and +//! that called [`next()`] to read from the input, apply the filter, and //! return the results without interleaving any other operations. This results //! in excellent cache locality as the same CPU core that produces the data often //! consumes it immediately as well. @@ -474,35 +488,35 @@ //! DataFusion automatically runs each plan with multiple CPU cores using //! a [Tokio] [`Runtime`] as a thread pool. While tokio is most commonly used //! for asynchronous network I/O, the combination of an efficient, work-stealing -//! scheduler and first class compiler support for automatic continuation -//! generation (`async`), also makes it a compelling choice for CPU intensive +//! scheduler, and first class compiler support for automatic continuation +//! generation (`async`) also makes it a compelling choice for CPU intensive //! applications as explained in the [Using Rustlang’s Async Tokio //! Runtime for CPU-Bound Tasks] blog. //! //! 
The number of cores used is determined by the `target_partitions` //! configuration setting, which defaults to the number of CPU cores. //! While preparing for execution, DataFusion tries to create this many distinct -//! `async` [`Stream`]s for each `ExecutionPlan`. -//! The `Stream`s for certain `ExecutionPlans`, such as as [`RepartitionExec`] +//! `async` [`Stream`]s for each [`ExecutionPlan`]. +//! The [`Stream`]s for certain [`ExecutionPlan`]s, such as [`RepartitionExec`] //! and [`CoalescePartitionsExec`], spawn [Tokio] [`task`]s, that are run by -//! threads managed by the `Runtime`. -//! Many DataFusion `Stream`s perform CPU intensive processing. +//! threads managed by the [`Runtime`]. +//! Many DataFusion [`Stream`]s perform CPU intensive processing. //! //! Using `async` for CPU intensive tasks makes it easy for [`TableProvider`]s //! to perform network I/O using standard Rust `async` during execution. //! However, this design also makes it very easy to mix CPU intensive and latency //! sensitive I/O work on the same thread pool ([`Runtime`]). -//! Using the same (default) `Runtime` is convenient, and often works well for +//! Using the same (default) [`Runtime`] is convenient, and often works well for //! initial development and processing local files, but it can lead to problems //! under load and/or when reading from network sources such as AWS S3. //! //! If your system does not fully utilize either the CPU or network bandwidth //! during execution, or you see significantly higher tail (e.g. p99) latencies //! responding to network requests, **it is likely you need to use a different -//! `Runtime` for CPU intensive DataFusion plans**. This effect can be especially +//! [`Runtime`] for CPU intensive DataFusion plans**. This effect can be especially //! pronounced when running several queries concurrently. //! -//! As shown in the following figure, using the same `Runtime` for both CPU +//! As shown in the following figure, using the same [`Runtime`] for both CPU //! intensive processing and network requests can introduce significant //! delays in responding to those network requests. Delays in processing network //! requests can and does lead network flow control to throttle the available @@ -603,8 +617,8 @@ //! The state required to execute queries is managed by the following //! structures: //! -//! 1. [`SessionContext`]: State needed for create [`LogicalPlan`]s such -//! as the table definitions, and the function registries. +//! 1. [`SessionContext`]: State needed to create [`LogicalPlan`]s such +//! as the table definitions and the function registries. //! //! 2. [`TaskContext`]: State needed for execution such as the //! [`MemoryPool`], [`DiskManager`], and [`ObjectStoreRegistry`]. 
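The runtime discussion above ties into the `SET datafusion.runtime.memory_limit` support added earlier in this diff (`set_runtime_variable` / `parse_memory_limit` on `SessionContext`). A minimal sketch of exercising it from SQL, assuming the key names exactly as they appear in the patch:

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Keys under `datafusion.runtime.` are routed to set_runtime_variable(),
    // which rebuilds the RuntimeEnv with a 1 GiB memory pool.
    ctx.sql("SET datafusion.runtime.memory_limit = '1G'").await?;

    // All other keys still go through ConfigOptions::set as before.
    ctx.sql("SET datafusion.execution.target_partitions = '4'").await?;

    Ok(())
}
```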
@@ -872,6 +886,12 @@ doc_comment::doctest!( user_guide_configs ); +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/runtime_configs.md", + user_guide_runtime_configs +); + #[cfg(doctest)] doc_comment::doctest!( "../../../docs/source/user-guide/crate-configuration.md", @@ -1021,8 +1041,8 @@ doc_comment::doctest!( #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/user-guide/sql/write_options.md", - user_guide_sql_write_options + "../../../docs/source/user-guide/sql/format_options.md", + user_guide_sql_format_options ); #[cfg(doctest)] diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index be24206c676c..c65fcb4c4c93 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -62,7 +62,9 @@ use arrow::array::{builder::StringBuilder, RecordBatch}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, +}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, ScalarValue, @@ -522,7 +524,7 @@ impl DefaultPhysicalPlanner { Some("true") => true, Some("false") => false, Some(value) => - return Err(DataFusionError::Configuration(format!("provided value for 'execution.keep_partition_by_columns' was not recognized: \"{}\"", value))), + return Err(DataFusionError::Configuration(format!("provided value for 'execution.keep_partition_by_columns' was not recognized: \"{value}\""))), }; let sink_format = file_type_to_format(file_type)? @@ -572,27 +574,25 @@ impl DefaultPhysicalPlanner { let input_exec = children.one()?; let get_sort_keys = |expr: &Expr| match expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + ref partition_by, + ref order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } Expr::Alias(Alias { expr, .. }) => { // Convert &Box to &T match &**expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + ref partition_by, + ref order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } _ => unreachable!(), } } @@ -694,7 +694,7 @@ impl DefaultPhysicalPlanner { } return internal_err!("Physical input schema should be the same as the one converted from logical input schema. 
Differences: {}", differences .iter() - .map(|s| format!("\n\t- {}", s)) + .map(|s| format!("\n\t- {s}")) .join("")); } @@ -1113,8 +1113,6 @@ impl DefaultPhysicalPlanner { && !prefer_hash_join { // Use SortMergeJoin if hash join is not preferred - // Sort-Merge join support currently is experimental - let join_on_len = join_on.len(); Arc::new(SortMergeJoinExec::try_new( physical_left, @@ -1506,17 +1504,18 @@ pub fn create_window_expr_with_name( let name = name.into(); let physical_schema: &Schema = &logical_schema.into(); match e { - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = window_fun.as_ref(); let physical_args = create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = @@ -1714,6 +1713,14 @@ impl DefaultPhysicalPlanner { let config = &session_state.config_options().explain; let explain_format = &e.explain_format; + if !e.logical_optimization_succeeded { + return Ok(Arc::new(ExplainExec::new( + Arc::clone(e.schema.inner()), + e.stringified_plans.clone(), + true, + ))); + } + match explain_format { ExplainFormat::Indent => { /* fall through */ } ExplainFormat::Tree => { @@ -1952,7 +1959,7 @@ impl DefaultPhysicalPlanner { "Optimized physical plan:\n{}\n", displayable(new_plan.as_ref()).indent(false) ); - trace!("Detailed optimized physical plan:\n{:?}", new_plan); + trace!("Detailed optimized physical plan:\n{new_plan:?}"); Ok(new_plan) } @@ -2069,29 +2076,36 @@ fn maybe_fix_physical_column_name( expr: Result>, input_physical_schema: &SchemaRef, ) -> Result> { - if let Ok(e) = &expr { - if let Some(column) = e.as_any().downcast_ref::() { - let physical_field = input_physical_schema.field(column.index()); + let Ok(expr) = expr else { return expr }; + expr.transform_down(|node| { + if let Some(column) = node.as_any().downcast_ref::() { + let idx = column.index(); + let physical_field = input_physical_schema.field(idx); let expr_col_name = column.name(); let physical_name = physical_field.name(); - if physical_name != expr_col_name { + if expr_col_name != physical_name { // handle edge cases where the physical_name contains ':'. 
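                        // For example, a physical field named "metric:avg" (one ':') may surface in the
                        // expression as "metric:avg:extra"; splitting the expression name at its second ':'
                        // recovers the base "metric:avg" so it can be compared against the physical name.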
let colon_count = physical_name.matches(':').count(); let mut splits = expr_col_name.match_indices(':'); let split_pos = splits.nth(colon_count); - if let Some((idx, _)) = split_pos { - let base_name = &expr_col_name[..idx]; + if let Some((i, _)) = split_pos { + let base_name = &expr_col_name[..i]; if base_name == physical_name { - let updated_column = Column::new(physical_name, column.index()); - return Ok(Arc::new(updated_column)); + let updated_column = Column::new(physical_name, idx); + return Ok(Transformed::yes(Arc::new(updated_column))); } } } + + // If names already match or fix is not possible, just leave it as it is + Ok(Transformed::no(node)) + } else { + Ok(Transformed::no(node)) } - } - expr + }) + .data() } struct OptimizationInvariantChecker<'a> { @@ -2192,11 +2206,16 @@ mod tests { use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::config::ConfigOptions; - use datafusion_common::{assert_contains, DFSchemaRef, TableReference}; + use datafusion_common::{ + assert_contains, DFSchemaRef, TableReference, ToDFSchema as _, + }; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; - use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore}; + use datafusion_expr::{ + col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore, + }; use datafusion_functions_aggregate::expr_fn::sum; + use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr}; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -2238,7 +2257,8 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5)) // the cast here is implicit so has CastOptions with safe=true - let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }"; + let expected = r#"BinaryExpr { left: Column { name: "c7", index: 2 }, op: Lt, right: Literal { value: Int64(5), field: Field { name: "5", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#; + assert!(format!("{exec_plan:?}").contains(expected)); Ok(()) } @@ -2263,7 +2283,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), field: Field { name: "NULL", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c1"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c2"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c3")], 
groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; assert_eq!(format!("{cube:?}"), expected); @@ -2290,7 +2310,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), field: Field { name: "NULL", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c1"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c2"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; assert_eq!(format!("{rollup:?}"), expected); @@ -2474,7 +2494,7 @@ mod tests { let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. - let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }"; + let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\"), field: Field { name: \"a\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\"), field: Field { name: \"1\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, fail_on_overflow: false }"; let actual = format!("{execution_plan:?}"); assert!(actual.contains(expected), "{}", actual); @@ -2689,6 +2709,54 @@ mod tests { } } + #[tokio::test] + async fn test_explain_indent_err() { + let planner = DefaultPhysicalPlanner::default(); + let ctx = SessionContext::new(); + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let plan = Arc::new( + scan_empty(Some("employee"), &schema, None) + .unwrap() + .explain(true, false) + .unwrap() + .build() + .unwrap(), + ); + + // Create a schema + let schema = Arc::new(Schema::new(vec![ + Field::new("plan_type", DataType::Utf8, false), + Field::new("plan", DataType::Utf8, false), + ])); + + // Create invalid indentation in the plan + let stringified_plans = + vec![StringifiedPlan::new(PlanType::FinalLogicalPlan, "Test Err")]; + + let explain = Explain { + verbose: false, + explain_format: ExplainFormat::Indent, + plan, + stringified_plans, + schema: schema.to_dfschema_ref().unwrap(), + 
logical_optimization_succeeded: false, + }; + let plan = planner + .handle_explain(&explain, &ctx.state()) + .await + .unwrap(); + if let Some(plan) = plan.as_any().downcast_ref::() { + let stringified_plans = plan.stringified_plans(); + assert_eq!(stringified_plans.len(), 1); + assert_eq!(stringified_plans[0].plan.as_str(), "Test Err"); + } else { + panic!( + "Plan was not an explain plan: {}", + displayable(plan.as_ref()).indent(true) + ); + } + } + #[tokio::test] async fn test_maybe_fix_colon_in_physical_name() { // The physical schema has a field name with a colon @@ -2713,6 +2781,47 @@ mod tests { assert_eq!(col.name(), "metric:avg"); } + + #[tokio::test] + async fn test_maybe_fix_nested_column_name_with_colon() { + let schema = Schema::new(vec![Field::new("column", DataType::Int32, false)]); + let schema_ref: SchemaRef = Arc::new(schema); + + // Construct the nested expr + let col_expr = Arc::new(Column::new("column:1", 0)) as Arc; + let is_not_null_expr = Arc::new(IsNotNullExpr::new(col_expr.clone())); + + // Create a binary expression and put the column inside + let binary_expr = Arc::new(BinaryExpr::new( + is_not_null_expr.clone(), + Operator::Or, + is_not_null_expr.clone(), + )) as Arc; + + let fixed_expr = + maybe_fix_physical_column_name(Ok(binary_expr), &schema_ref).unwrap(); + + let bin = fixed_expr + .as_any() + .downcast_ref::() + .expect("Expected BinaryExpr"); + + // Check that both sides where renamed + for expr in &[bin.left(), bin.right()] { + let is_not_null = expr + .as_any() + .downcast_ref::() + .expect("Expected IsNotNull"); + + let col = is_not_null + .arg() + .as_any() + .downcast_ref::() + .expect("Expected Column"); + + assert_eq!(col.name(), "column"); + } + } struct ErrorExtensionPlanner {} #[async_trait] diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index 9c9fcd04bf09..d723620d3232 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -25,6 +25,7 @@ //! use datafusion::prelude::*; //! 
``` +pub use crate::dataframe; pub use crate::dataframe::DataFrame; pub use crate::execution::context::{SQLOptions, SessionConfig, SessionContext}; pub use crate::execution::options::{ diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index 8b19658bb147..ed8474bbfc81 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -148,7 +148,7 @@ impl ObjectStore for BlockingObjectStore { "{} barrier wait timed out for {location}", BlockingObjectStore::NAME ); - log::error!("{}", error_message); + log::error!("{error_message}"); return Err(Error::Generic { store: BlockingObjectStore::NAME, source: error_message.into(), diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index d6865ca3d532..2f8e66a2bbfb 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -22,12 +22,14 @@ pub mod parquet; pub mod csv; +use futures::Stream; use std::any::Any; use std::collections::HashMap; use std::fs::File; use std::io::Write; use std::path::Path; use std::sync::Arc; +use std::task::{Context, Poll}; use crate::catalog::{TableProvider, TableProviderFactory}; use crate::dataframe::DataFrame; @@ -38,11 +40,13 @@ use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; use crate::prelude::{CsvReadOptions, SessionContext}; +use crate::execution::SendableRecordBatchStream; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_catalog::Session; use datafusion_common::TableReference; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; +use std::pin::Pin; use async_trait::async_trait; @@ -52,6 +56,8 @@ use tempfile::TempDir; pub use datafusion_common::test_util::parquet_test_data; pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; +use crate::execution::RecordBatchStream; + /// Scan an empty data source, mainly used in tests pub fn scan_empty( name: Option<&str>, @@ -234,3 +240,44 @@ pub fn register_unbounded_file_with_ordering( ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } + +/// Creates a bounded stream that emits the same record batch a specified number of times. +/// This is useful for testing purposes. 
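+///
+/// A possible usage sketch (the batch construction below is illustrative and not part of this change):
+///
+/// ```ignore
+/// use std::sync::Arc;
+/// use arrow::array::{Int32Array, RecordBatch};
+/// use futures::StreamExt;
+///
+/// let batch = RecordBatch::try_from_iter(vec![(
+///     "a",
+///     Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
+/// )])?;
+/// // Yields the same batch twice, then ends.
+/// let mut stream = bounded_stream(batch, 2);
+/// while let Some(item) = stream.next().await {
+///     assert_eq!(item?.num_rows(), 3);
+/// }
+/// ```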
+pub fn bounded_stream( + record_batch: RecordBatch, + limit: usize, +) -> SendableRecordBatchStream { + Box::pin(BoundedStream { + record_batch, + count: 0, + limit, + }) +} + +struct BoundedStream { + record_batch: RecordBatch, + count: usize, + limit: usize, +} + +impl Stream for BoundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.count >= self.limit { + Poll::Ready(None) + } else { + self.count += 1; + Poll::Ready(Some(Ok(self.record_batch.clone()))) + } + } +} + +impl RecordBatchStream for BoundedStream { + fn schema(&self) -> SchemaRef { + self.record_batch.schema() + } +} diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index f5753af64d93..ddd18e9ae2c2 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -37,6 +37,7 @@ use crate::physical_plan::metrics::MetricsSet; use crate::physical_plan::ExecutionPlan; use crate::prelude::{Expr, SessionConfig, SessionContext}; +use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use object_store::path::Path; @@ -82,23 +83,22 @@ impl TestParquetFile { props: WriterProperties, batches: impl IntoIterator, ) -> Result { - let file = File::create(&path).unwrap(); + let file = File::create(&path)?; let mut batches = batches.into_iter(); let first_batch = batches.next().expect("need at least one record batch"); let schema = first_batch.schema(); - let mut writer = - ArrowWriter::try_new(file, Arc::clone(&schema), Some(props)).unwrap(); + let mut writer = ArrowWriter::try_new(file, Arc::clone(&schema), Some(props))?; - writer.write(&first_batch).unwrap(); + writer.write(&first_batch)?; let mut num_rows = first_batch.num_rows(); for batch in batches { - writer.write(&batch).unwrap(); + writer.write(&batch)?; num_rows += batch.num_rows(); } - writer.close().unwrap(); + writer.close()?; println!("Generated test dataset with {num_rows} rows"); @@ -182,10 +182,14 @@ impl TestParquetFile { let physical_filter_expr = create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?; - let source = Arc::new(ParquetSource::new(parquet_options).with_predicate( - Arc::clone(&self.schema), - Arc::clone(&physical_filter_expr), - )); + let source = Arc::new( + ParquetSource::new(parquet_options) + .with_predicate( + Arc::clone(&self.schema), + Arc::clone(&physical_filter_expr) + ), + ) + .with_schema(Arc::clone(&self.schema)); let config = scan_config_builder.with_source(source).build(); let parquet_exec = DataSourceExec::from_data_source(config); diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index eb930b9a60bc..cbdc4a448ea4 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -180,6 +180,13 @@ impl ExecutionPlan for CustomExecutionPlan { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } let batch = TEST_CUSTOM_RECORD_BATCH!().unwrap(); Ok(Statistics { num_rows: Precision::Exact(batch.num_rows()), diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 
f68bcfaf1550..c80c0b4bf54b 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -179,12 +179,12 @@ impl TableProvider for CustomProvider { match &filters[0] { Expr::BinaryExpr(BinaryExpr { right, .. }) => { let int_value = match &**right { - Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int64(Some(i))) => *i, + Expr::Literal(ScalarValue::Int8(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i, Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { - Expr::Literal(lit_value) => match lit_value { + Expr::Literal(lit_value, _) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, ScalarValue::Int32(Some(v)) => *v as i64, diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index 66c886510e96..f9b0db0e808c 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -184,6 +184,14 @@ impl ExecutionPlan for StatisticsValidation { fn statistics(&self) -> Result { Ok(self.stats.clone()) } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + Ok(Statistics::new_unknown(&self.schema)) + } else { + Ok(self.stats.clone()) + } + } } fn init_ctx(stats: Statistics, schema: Schema) -> Result { @@ -232,7 +240,7 @@ async fn sql_basic() -> Result<()> { let physical_plan = df.create_physical_plan().await.unwrap(); // the statistics should be those of the source - assert_eq!(stats, physical_plan.statistics()?); + assert_eq!(stats, physical_plan.partition_statistics(None)?); Ok(()) } @@ -248,7 +256,7 @@ async fn sql_filter() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); - let stats = physical_plan.statistics()?; + let stats = physical_plan.partition_statistics(None)?; assert_eq!(stats.num_rows, Precision::Inexact(1)); Ok(()) @@ -270,7 +278,7 @@ async fn sql_limit() -> Result<()> { column_statistics: col_stats, total_byte_size: Precision::Absent }, - physical_plan.statistics()? + physical_plan.partition_statistics(None)? 
); let df = ctx @@ -279,7 +287,7 @@ async fn sql_limit() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); // when the limit is larger than the original number of lines, statistics remain unchanged - assert_eq!(stats, physical_plan.statistics()?); + assert_eq!(stats, physical_plan.partition_statistics(None)?); Ok(()) } @@ -296,7 +304,7 @@ async fn sql_window() -> Result<()> { let physical_plan = df.create_physical_plan().await.unwrap(); - let result = physical_plan.statistics()?; + let result = physical_plan.partition_statistics(None)?; assert_eq!(stats.num_rows, result.num_rows); let col_stats = result.column_statistics; diff --git a/datafusion/core/tests/data/filter_pushdown/single_file.gz.parquet b/datafusion/core/tests/data/filter_pushdown/single_file.gz.parquet new file mode 100644 index 000000000000..ed700576a5af Binary files /dev/null and b/datafusion/core/tests/data/filter_pushdown/single_file.gz.parquet differ diff --git a/datafusion/core/tests/data/filter_pushdown/single_file_small_pages.gz.parquet b/datafusion/core/tests/data/filter_pushdown/single_file_small_pages.gz.parquet new file mode 100644 index 000000000000..29282cfbb622 Binary files /dev/null and b/datafusion/core/tests/data/filter_pushdown/single_file_small_pages.gz.parquet differ diff --git a/datafusion/core/tests/data/int96_nested.parquet b/datafusion/core/tests/data/int96_nested.parquet new file mode 100644 index 000000000000..708823ded6fa Binary files /dev/null and b/datafusion/core/tests/data/int96_nested.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 000000000000..ec164c6df7b5 Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 000000000000..4b78cf963c11 Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 000000000000..09a01771d503 Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 000000000000..6398cc43a2f5 Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index c763d4c8de2d..40590d74ad91 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -384,7 +384,7 @@ async fn test_fn_approx_median() -> Result<()> { #[tokio::test] async fn 
test_fn_approx_percentile_cont() -> Result<()> { - let expr = approx_percentile_cont(col("b"), lit(0.5), None); + let expr = approx_percentile_cont(col("b").sort(true, false), lit(0.5), None); let df = create_test_table().await?; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; @@ -392,11 +392,26 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_snapshot!( batches_to_string(&batches), @r" - +---------------------------------------------+ - | approx_percentile_cont(test.b,Float64(0.5)) | - +---------------------------------------------+ - | 10 | - +---------------------------------------------+ + +---------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.5)) WITHIN GROUP [test.b ASC NULLS LAST] | + +---------------------------------------------------------------------------+ + | 10 | + +---------------------------------------------------------------------------+ + "); + + let expr = approx_percentile_cont(col("b").sort(false, false), lit(0.1), None); + + let df = create_test_table().await?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_snapshot!( + batches_to_string(&batches), + @r" + +----------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.1)) WITHIN GROUP [test.b DESC NULLS LAST] | + +----------------------------------------------------------------------------+ + | 100 | + +----------------------------------------------------------------------------+ "); // the arg2 parameter is a complex expr, but it can be evaluated to the literal value @@ -405,23 +420,59 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { None::<&str>, "arg_2".to_string(), )); - let expr = approx_percentile_cont(col("b"), alias_expr, None); + let expr = approx_percentile_cont(col("b").sort(true, false), alias_expr, None); let df = create_test_table().await?; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; assert_snapshot!( batches_to_string(&batches), @r" - +--------------------------------------+ - | approx_percentile_cont(test.b,arg_2) | - +--------------------------------------+ - | 10 | - +--------------------------------------+ + +--------------------------------------------------------------------+ + | approx_percentile_cont(arg_2) WITHIN GROUP [test.b ASC NULLS LAST] | + +--------------------------------------------------------------------+ + | 10 | + +--------------------------------------------------------------------+ + " + ); + + let alias_expr = Expr::Alias(Alias::new( + cast(lit(0.1), DataType::Float32), + None::<&str>, + "arg_2".to_string(), + )); + let expr = approx_percentile_cont(col("b").sort(false, false), alias_expr, None); + let df = create_test_table().await?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_snapshot!( + batches_to_string(&batches), + @r" + +---------------------------------------------------------------------+ + | approx_percentile_cont(arg_2) WITHIN GROUP [test.b DESC NULLS LAST] | + +---------------------------------------------------------------------+ + | 100 | + +---------------------------------------------------------------------+ " ); // with number of centroids set - let expr = approx_percentile_cont(col("b"), lit(0.5), Some(lit(2))); + let expr = approx_percentile_cont(col("b").sort(true, false), lit(0.5), Some(lit(2))); + + let df = create_test_table().await?; + let batches = df.aggregate(vec![], 
vec![expr]).unwrap().collect().await?; + + assert_snapshot!( + batches_to_string(&batches), + @r" + +------------------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.5),Int32(2)) WITHIN GROUP [test.b ASC NULLS LAST] | + +------------------------------------------------------------------------------------+ + | 30 | + +------------------------------------------------------------------------------------+ + "); + + let expr = + approx_percentile_cont(col("b").sort(false, false), lit(0.1), Some(lit(2))); let df = create_test_table().await?; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; @@ -429,11 +480,11 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_snapshot!( batches_to_string(&batches), @r" - +------------------------------------------------------+ - | approx_percentile_cont(test.b,Float64(0.5),Int32(2)) | - +------------------------------------------------------+ - | 30 | - +------------------------------------------------------+ + +-------------------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.1),Int32(2)) WITHIN GROUP [test.b DESC NULLS LAST] | + +-------------------------------------------------------------------------------------+ + | 69 | + +-------------------------------------------------------------------------------------+ "); Ok(()) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 1855a512048d..f91bc6eed23c 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -32,6 +32,7 @@ use arrow::datatypes::{ }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_batches; +use datafusion::{assert_batches_eq, dataframe}; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, count, count_distinct, max, median, min, sum, @@ -906,7 +907,7 @@ async fn window_using_aggregates() -> Result<()> { vec![col("c3")], ); - Expr::WindowFunction(w) + Expr::from(w) .null_treatment(NullTreatment::IgnoreNulls) .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) .window_frame(WindowFrame::new_bounds( @@ -1209,7 +1210,7 @@ async fn join_on_filter_datatype() -> Result<()> { let join = left.clone().join_on( right.clone(), JoinType::Inner, - Some(Expr::Literal(ScalarValue::Null)), + Some(Expr::Literal(ScalarValue::Null, None)), )?; assert_snapshot!(join.into_optimized_plan().unwrap(), @"EmptyRelation"); @@ -1852,6 +1853,56 @@ async fn with_column_renamed_case_sensitive() -> Result<()> { Ok(()) } +#[tokio::test] +async fn describe_lookup_via_quoted_identifier() -> Result<()> { + let ctx = SessionContext::new(); + let name = "aggregate_test_100"; + register_aggregate_csv(&ctx, name).await?; + let df = ctx.table(name); + + let df = df + .await? + .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? + .limit(0, Some(1))? + .sort(vec![ + // make the test deterministic + col("c1").sort(true, true), + col("c2").sort(true, true), + col("c3").sort(true, true), + ])? + .select_columns(&["c1"])?; + + let df_renamed = df.clone().with_column_renamed("c1", "CoLu.Mn[\"1\"]")?; + + let describe_result = df_renamed.describe().await?; + describe_result + .clone() + .sort(vec![ + col("describe").sort(true, true), + col("CoLu.Mn[\"1\"]").sort(true, true), + ])? 
+ .show() + .await?; + assert_snapshot!( + batches_to_sort_string(&describe_result.clone().collect().await?), + @r###" + +------------+--------------+ + | describe | CoLu.Mn["1"] | + +------------+--------------+ + | count | 1 | + | max | a | + | mean | null | + | median | null | + | min | a | + | null_count | 0 | + | std | null | + +------------+--------------+ + "### + ); + + Ok(()) +} + #[tokio::test] async fn cast_expr_test() -> Result<()> { let df = test_table() @@ -2454,6 +2505,11 @@ async fn write_table_with_order() -> Result<()> { write_df = write_df .with_column_renamed("column1", "tablecol1") .unwrap(); + + // Ensure the column type matches the target table + write_df = + write_df.with_column("tablecol1", cast(col("tablecol1"), DataType::Utf8View))?; + let sql_str = "create external table data(tablecol1 varchar) stored as parquet location '" .to_owned() @@ -2514,28 +2570,26 @@ async fn test_count_wildcard_on_sort() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), - @r###" - +---------------+------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: t1.b, count(*) | - | | Sort: count(Int64(1)) AS count(*) AS count(*) ASC NULLS LAST | - | | Projection: t1.b, count(Int64(1)) AS count(*), count(Int64(1)) | - | | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]] | - | | TableScan: t1 projection=[b] | - | physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as count(*)] | - | | SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] | - | | SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] | - | | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] | - | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+------------------------------------------------------------------------------------------------------------+ - "### + @r" + +---------------+------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+------------------------------------------------------------------------------------+ + | logical_plan | Sort: count(*) ASC NULLS LAST | + | | Projection: t1.b, count(Int64(1)) AS count(*) | + | | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]] | + | | TableScan: t1 projection=[b] | + | physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] | + | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | + | | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*)] | + | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] | + | | CoalesceBatchesExec: target_batch_size=8192 | + | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 | + | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] | + | | 
DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+------------------------------------------------------------------------------------+ + " ); assert_snapshot!( @@ -3570,16 +3624,15 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+------------------------------------------------+--------------------+ - | shape_id | points | tags | - +----------+------------------------------------------------+--------------------+ - | 1 | [{x: -3, y: -4}, {x: -3, y: 6}, {x: 2, y: -2}] | [tag1] | - | 2 | | [tag1, tag2] | - | 3 | [{x: -9, y: 2}, {x: -10, y: -4}] | | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | [tag1, tag2, tag3] | - +----------+------------------------------------------------+--------------------+ - "### - ); + +----------+---------------------------------+--------------------------+ + | shape_id | points | tags | + +----------+---------------------------------+--------------------------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | [tag1] | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | [tag1] | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | [tag1, tag2, tag3, tag4] | + | 4 | | [tag1, tag2, tag3] | + +----------+---------------------------------+--------------------------+ + "###); // Unnest tags let df = table_with_nested_types(NUM_ROWS).await?; @@ -3587,19 +3640,20 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+------------------------------------------------+------+ - | shape_id | points | tags | - +----------+------------------------------------------------+------+ - | 1 | [{x: -3, y: -4}, {x: -3, y: 6}, {x: 2, y: -2}] | tag1 | - | 2 | | tag1 | - | 2 | | tag2 | - | 3 | [{x: -9, y: 2}, {x: -10, y: -4}] | | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | tag1 | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | tag2 | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | tag3 | - +----------+------------------------------------------------+------+ - "### - ); + +----------+---------------------------------+------+ + | shape_id | points | tags | + +----------+---------------------------------+------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | tag1 | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | tag1 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag1 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag2 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag3 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+---------------------------------+------+ + "###); // Test aggregate results for tags. 
let df = table_with_nested_types(NUM_ROWS).await?; @@ -3612,20 +3666,18 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+-----------------+--------------------+ - | shape_id | points | tags | - +----------+-----------------+--------------------+ - | 1 | {x: -3, y: -4} | [tag1] | - | 1 | {x: -3, y: 6} | [tag1] | - | 1 | {x: 2, y: -2} | [tag1] | - | 2 | | [tag1, tag2] | - | 3 | {x: -10, y: -4} | | - | 3 | {x: -9, y: 2} | | - | 4 | {x: -3, y: 5} | [tag1, tag2, tag3] | - | 4 | {x: 2, y: -1} | [tag1, tag2, tag3] | - +----------+-----------------+--------------------+ - "### - ); + +----------+----------------+--------------------------+ + | shape_id | points | tags | + +----------+----------------+--------------------------+ + | 1 | {x: -3, y: -4} | [tag1] | + | 1 | {x: 5, y: -8} | [tag1] | + | 2 | {x: -2, y: -8} | [tag1] | + | 2 | {x: 6, y: 2} | [tag1] | + | 3 | {x: -2, y: 5} | [tag1, tag2, tag3, tag4] | + | 3 | {x: -9, y: -7} | [tag1, tag2, tag3, tag4] | + | 4 | | [tag1, tag2, tag3] | + +----------+----------------+--------------------------+ + "###); // Test aggregate results for points. let df = table_with_nested_types(NUM_ROWS).await?; @@ -3642,25 +3694,26 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+-----------------+------+ - | shape_id | points | tags | - +----------+-----------------+------+ - | 1 | {x: -3, y: -4} | tag1 | - | 1 | {x: -3, y: 6} | tag1 | - | 1 | {x: 2, y: -2} | tag1 | - | 2 | | tag1 | - | 2 | | tag2 | - | 3 | {x: -10, y: -4} | | - | 3 | {x: -9, y: 2} | | - | 4 | {x: -3, y: 5} | tag1 | - | 4 | {x: -3, y: 5} | tag2 | - | 4 | {x: -3, y: 5} | tag3 | - | 4 | {x: 2, y: -1} | tag1 | - | 4 | {x: 2, y: -1} | tag2 | - | 4 | {x: 2, y: -1} | tag3 | - +----------+-----------------+------+ - "### - ); + +----------+----------------+------+ + | shape_id | points | tags | + +----------+----------------+------+ + | 1 | {x: -3, y: -4} | tag1 | + | 1 | {x: 5, y: -8} | tag1 | + | 2 | {x: -2, y: -8} | tag1 | + | 2 | {x: 6, y: 2} | tag1 | + | 3 | {x: -2, y: 5} | tag1 | + | 3 | {x: -2, y: 5} | tag2 | + | 3 | {x: -2, y: 5} | tag3 | + | 3 | {x: -2, y: 5} | tag4 | + | 3 | {x: -9, y: -7} | tag1 | + | 3 | {x: -9, y: -7} | tag2 | + | 3 | {x: -9, y: -7} | tag3 | + | 3 | {x: -9, y: -7} | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+----------------+------+ + "###); // Test aggregate results for points and tags. 
let df = table_with_nested_types(NUM_ROWS).await?; @@ -3994,15 +4047,15 @@ async fn unnest_aggregate_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +--------------------+ - | tags | - +--------------------+ - | | - | [tag1, tag2, tag3] | - | [tag1, tag2, tag3] | - | [tag1, tag2] | - | [tag1] | - +--------------------+ + +--------------------------+ + | tags | + +--------------------------+ + | [tag1, tag2, tag3, tag4] | + | [tag1, tag2, tag3] | + | [tag1, tag2] | + | [tag1] | + | [tag1] | + +--------------------------+ "### ); @@ -4018,7 +4071,7 @@ async fn unnest_aggregate_columns() -> Result<()> { +-------------+ | count(tags) | +-------------+ - | 9 | + | 11 | +-------------+ "### ); @@ -4267,7 +4320,7 @@ async fn unnest_analyze_metrics() -> Result<()> { assert_contains!(&formatted, "elapsed_compute="); assert_contains!(&formatted, "input_batches=1"); assert_contains!(&formatted, "input_rows=5"); - assert_contains!(&formatted, "output_rows=10"); + assert_contains!(&formatted, "output_rows=11"); assert_contains!(&formatted, "output_batches=1"); Ok(()) @@ -4472,7 +4525,10 @@ async fn consecutive_projection_same_schema() -> Result<()> { // Add `t` column full of nulls let df = df - .with_column("t", cast(Expr::Literal(ScalarValue::Null), DataType::Int32)) + .with_column( + "t", + cast(Expr::Literal(ScalarValue::Null, None), DataType::Int32), + ) .unwrap(); df.clone().show().await.unwrap(); @@ -4614,7 +4670,7 @@ async fn table_with_nested_types(n: usize) -> Result { shape_id_builder.append_value(idx as u32 + 1); // Add a random number of points - let num_points: usize = rng.gen_range(0..4); + let num_points: usize = rng.random_range(0..4); if num_points > 0 { for _ in 0..num_points.max(2) { // Add x value @@ -4622,13 +4678,13 @@ async fn table_with_nested_types(n: usize) -> Result { .values() .field_builder::(0) .unwrap() - .append_value(rng.gen_range(-10..10)); + .append_value(rng.random_range(-10..10)); // Add y value points_builder .values() .field_builder::(1) .unwrap() - .append_value(rng.gen_range(-10..10)); + .append_value(rng.random_range(-10..10)); points_builder.values().append(true); } } @@ -4637,7 +4693,7 @@ async fn table_with_nested_types(n: usize) -> Result { points_builder.append(num_points > 0); // Append tags. 
- let num_tags: usize = rng.gen_range(0..5); + let num_tags: usize = rng.random_range(0..5); for id in 0..num_tags { tags_builder.values().append_value(format!("tag{}", id + 1)); } @@ -5079,7 +5135,7 @@ async fn write_partitioned_parquet_results() -> Result<()> { .await?; // Explicitly read the parquet file at c2=123 to verify the physical files are partitioned - let partitioned_file = format!("{out_dir}/c2=123", out_dir = out_dir); + let partitioned_file = format!("{out_dir}/c2=123"); let filter_df = ctx .read_parquet(&partitioned_file, ParquetReadOptions::default()) .await?; @@ -6017,3 +6073,62 @@ async fn test_insert_into_casting_support() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn test_dataframe_from_columns() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let b: ArrayRef = Arc::new(BooleanArray::from(vec![true, true, false])); + let c: ArrayRef = Arc::new(StringArray::from(vec![Some("foo"), Some("bar"), None])); + let df = DataFrame::from_columns(vec![("a", a), ("b", b), ("c", c)])?; + + assert_eq!(df.schema().fields().len(), 3); + assert_eq!(df.clone().count().await?, 3); + + let rows = df.sort(vec![col("a").sort(true, true)])?; + assert_batches_eq!( + &[ + "+---+-------+-----+", + "| a | b | c |", + "+---+-------+-----+", + "| 1 | true | foo |", + "| 2 | true | bar |", + "| 3 | false | |", + "+---+-------+-----+", + ], + &rows.collect().await? + ); + + Ok(()) +} + +#[tokio::test] +async fn test_dataframe_macro() -> Result<()> { + let df = dataframe!( + "a" => [1, 2, 3], + "b" => [true, true, false], + "c" => [Some("foo"), Some("bar"), None] + )?; + + assert_eq!(df.schema().fields().len(), 3); + assert_eq!(df.clone().count().await?, 3); + + let rows = df.sort(vec![col("a").sort(true, true)])?; + assert_batches_eq!( + &[ + "+---+-------+-----+", + "| a | b | c |", + "+---+-------+-----+", + "| 1 | true | foo |", + "| 2 | true | bar |", + "| 3 | false | |", + "+---+-------+-----+", + ], + &rows.collect().await? + ); + + let df_empty = dataframe!()?; + assert_eq!(df_empty.schema().fields().len(), 0); + assert_eq!(df_empty.count().await?, 0); + + Ok(()) +} diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index b30636ddf6a8..f5a8a30e0130 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -19,15 +19,19 @@ //! create them and depend on them. Test executable semantics of logical plans. 
use arrow::array::Int64Array; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::datasource::{provider_as_source, ViewTable}; use datafusion::execution::session_state::SessionStateBuilder; -use datafusion_common::{Column, DFSchema, Result, ScalarValue, Spans}; +use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; use datafusion_execution::TaskContext; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; use datafusion_expr::logical_plan::{LogicalPlan, Values}; -use datafusion_expr::{Aggregate, AggregateUDF, Expr}; +use datafusion_expr::{ + Aggregate, AggregateUDF, EmptyRelation, Expr, LogicalPlanBuilder, UNNAMED_TABLE, +}; use datafusion_functions_aggregate::count::Count; use datafusion_physical_plan::collect; +use insta::assert_snapshot; use std::collections::HashMap; use std::fmt::Debug; use std::ops::Deref; @@ -43,9 +47,9 @@ async fn count_only_nulls() -> Result<()> { let input = Arc::new(LogicalPlan::Values(Values { schema: input_schema, values: vec![ - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::Literal(ScalarValue::Null, None)], + vec![Expr::Literal(ScalarValue::Null, None)], + vec![Expr::Literal(ScalarValue::Null, None)], ], })); let input_col_ref = Expr::Column(Column { @@ -92,7 +96,41 @@ where T: Debug, { let [element] = elements else { - panic!("Expected exactly one element, got {:?}", elements); + panic!("Expected exactly one element, got {elements:?}"); }; element } + +#[test] +fn inline_scan_projection_test() -> Result<()> { + let name = UNNAMED_TABLE; + let column = "a"; + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let projection = vec![schema.index_of(column)?]; + + let provider = ViewTable::new( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(DFSchema::try_from(schema)?), + }), + None, + ); + let source = provider_as_source(Arc::new(provider)); + + let plan = LogicalPlanBuilder::scan(name, source, Some(projection))?.build()?; + + assert_snapshot!( + format!("{plan}"), + @r" + SubqueryAlias: ?table? 
+ Projection: a + EmptyRelation + " + ); + + Ok(()) +} diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index aef10379da07..a9cf7f04bb3a 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -358,8 +358,7 @@ async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) { assert_eq!( expected_lines, actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" ); } @@ -379,8 +378,7 @@ fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { assert_eq!( expected_lines, actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" ); } diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 7bb21725ef40..91a507bdf7f0 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -282,10 +282,13 @@ fn select_date_plus_interval() -> Result<()> { let date_plus_interval_expr = to_timestamp_expr(ts_string) .cast_to(&DataType::Date32, schema)? - + Expr::Literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 123, - milliseconds: 0, - }))); + + Expr::Literal( + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 123, + milliseconds: 0, + })), + None, + ); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![date_plus_interval_expr])? @@ -547,9 +550,9 @@ fn test_simplify_with_cycle_count( }; let simplifier = ExprSimplifier::new(info); let (simplified_expr, count) = simplifier - .simplify_with_cycle_count(input_expr.clone()) + .simplify_with_cycle_count_transformed(input_expr.clone()) .expect("successfully evaluated"); - + let simplified_expr = simplified_expr.data; assert_eq!( simplified_expr, expected_expr, "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index ff3b66986ced..940b85bb96cf 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -17,8 +17,9 @@ use std::sync::Arc; +use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ - AggregationFuzzerBuilder, DatasetGeneratorConfig, QueryBuilder, + AggregationFuzzerBuilder, DatasetGeneratorConfig, }; use arrow::array::{ @@ -54,7 +55,7 @@ use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; use datafusion_physical_plan::metrics::MetricValue; use rand::rngs::StdRng; -use rand::{random, thread_rng, Rng, SeedableRng}; +use rand::{random, rng, Rng, SeedableRng}; use super::record_batch_generator::get_supported_types_columns; @@ -85,6 +86,7 @@ async fn test_min() { .with_aggregate_function("min") // min works on all column types .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -111,6 +113,7 @@ async fn test_first_val() { .with_table_name("fuzz_table") .with_aggregate_function("first_value") .with_aggregate_arguments(data_gen_config.all_columns()) + 
.with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -137,6 +140,7 @@ async fn test_last_val() { .with_table_name("fuzz_table") .with_aggregate_function("last_value") .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -156,6 +160,7 @@ async fn test_max() { .with_aggregate_function("max") // max works on all column types .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -176,6 +181,7 @@ async fn test_sum() { .with_distinct_aggregate_function("sum") // sum only works on numeric columns .with_aggregate_arguments(data_gen_config.numeric_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -196,6 +202,7 @@ async fn test_count() { .with_distinct_aggregate_function("count") // count work for all arguments .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -216,6 +223,7 @@ async fn test_median() { .with_distinct_aggregate_function("median") // median only works on numeric columns .with_aggregate_arguments(data_gen_config.numeric_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -233,8 +241,8 @@ async fn test_median() { /// 1. Floating point numbers /// 1. 
structured types fn baseline_config() -> DatasetGeneratorConfig { - let mut rng = thread_rng(); - let columns = get_supported_types_columns(rng.gen()); + let mut rng = rng(); + let columns = get_supported_types_columns(rng.random()); let min_num_rows = 512; let max_num_rows = 1024; @@ -246,6 +254,12 @@ fn baseline_config() -> DatasetGeneratorConfig { // low cardinality to try and get many repeated runs vec![String::from("u8_low")], vec![String::from("utf8_low"), String::from("u8_low")], + vec![String::from("dictionary_utf8_low")], + vec![ + String::from("dictionary_utf8_low"), + String::from("utf8_low"), + String::from("u8_low"), + ], ], } } @@ -423,13 +437,13 @@ pub(crate) fn make_staggered_batches( let mut input4: Vec = vec![0; len]; input123.iter_mut().for_each(|v| { *v = ( - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, ) }); input4.iter_mut().for_each(|v| { - *v = rng.gen_range(0..n_distinct) as i64; + *v = rng.random_range(0..n_distinct) as i64; }); input123.sort(); let input1 = Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.0)); @@ -449,7 +463,7 @@ pub(crate) fn make_staggered_batches( let mut batches = vec![]; if STREAM { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..50); + let batch_size = rng.random_range(0..50); if remainder.num_rows() < batch_size { break; } @@ -458,7 +472,7 @@ pub(crate) fn make_staggered_batches( } } else { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + let batch_size = rng.random_range(0..remainder.num_rows() + 1); batches.push(remainder.slice(0, batch_size)); remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); } @@ -504,7 +518,9 @@ async fn group_by_string_test( let expected = compute_counts(&input, column_name); let schema = input[0].schema(); - let session_config = SessionConfig::new().with_batch_size(50); + let session_config = SessionConfig::new() + .with_batch_size(50) + .with_repartition_file_scans(false); let ctx = SessionContext::new_with_config(session_config); let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap(); @@ -682,8 +698,8 @@ async fn test_single_mode_aggregate_with_spill() -> Result<()> { Arc::new(StringArray::from( (0..1024) .map(|_| -> String { - thread_rng() - .sample_iter::(rand::distributions::Standard) + rng() + .sample_iter::(rand::distr::StandardUniform) .take(5) .collect() }) diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs index 3c9fe2917251..2abfcd8417cb 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -25,7 +25,7 @@ use datafusion_catalog::TableProvider; use datafusion_common::ScalarValue; use datafusion_common::{error::Result, utils::get_available_parallelism}; use datafusion_expr::col; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset; @@ -112,7 +112,7 @@ impl SessionContextGenerator { /// Randomly generate session context pub fn generate(&self) -> Result { - let mut rng = thread_rng(); + let mut rng = rng(); let schema = self.dataset.batches[0].schema(); let batches = self.dataset.batches.clone(); let 
provider = MemTable::try_new(schema, vec![batches])?; @@ -123,17 +123,17 @@ impl SessionContextGenerator { // - `skip_partial`, trigger or not trigger currently for simplicity // - `sorted`, if found a sorted dataset, will or will not push down this information // - `spilling`(TODO) - let batch_size = rng.gen_range(1..=self.max_batch_size); + let batch_size = rng.random_range(1..=self.max_batch_size); - let target_partitions = rng.gen_range(1..=self.max_target_partitions); + let target_partitions = rng.random_range(1..=self.max_target_partitions); let skip_partial_params_idx = - rng.gen_range(0..self.candidate_skip_partial_params.len()); + rng.random_range(0..self.candidate_skip_partial_params.len()); let skip_partial_params = self.candidate_skip_partial_params[skip_partial_params_idx]; let (provider, sort_hint) = - if rng.gen_bool(0.5) && !self.dataset.sort_keys.is_empty() { + if rng.random_bool(0.5) && !self.dataset.sort_keys.is_empty() { // Sort keys exist and random to push down let sort_exprs = self .dataset diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index 53e9288ab4af..cfb3c1c6a1b9 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -16,15 +16,14 @@ // under the License. use std::sync::Arc; -use std::{collections::HashSet, str::FromStr}; use arrow::array::RecordBatch; use arrow::util::pretty::pretty_format_batches; use datafusion_common::{DataFusionError, Result}; use datafusion_common_runtime::JoinSet; -use rand::seq::SliceRandom; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; +use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ check_equality_of_batches, context_generator::{SessionContextGenerator, SessionContextWithParams}, @@ -69,30 +68,16 @@ impl AggregationFuzzerBuilder { /// - 3 random queries /// - 3 random queries for each group by selected from the sort keys /// - 1 random query with no grouping - pub fn add_query_builder(mut self, mut query_builder: QueryBuilder) -> Self { - const NUM_QUERIES: usize = 3; - for _ in 0..NUM_QUERIES { - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); - } - // also add several queries limited to grouping on the group by columns only, if any - // So if the data is sorted on `a,b` only group by `a,b` or`a` or `b` - if let Some(data_gen_config) = &self.data_gen_config { - for sort_keys in &data_gen_config.sort_keys_set { - let group_by_columns = sort_keys.iter().map(|s| s.as_str()); - query_builder = query_builder.set_group_by_columns(group_by_columns); - for _ in 0..NUM_QUERIES { - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); - } - } - } - // also add a query with no grouping - query_builder = query_builder.set_group_by_columns(vec![]); - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); + pub fn add_query_builder(mut self, query_builder: QueryBuilder) -> Self { + self = self.table_name(query_builder.table_name()); + + let sqls = query_builder + .generate_queries() + .into_iter() + .map(|sql| Arc::from(sql.as_str())); + self.candidate_sqls.extend(sqls); - self.table_name(query_builder.table_name()) + self } pub fn table_name(mut self, table_name: &str) -> Self { @@ -178,7 +163,7 @@ impl AggregationFuzzer { async fn run_inner(&mut self) -> Result<()> { let mut join_set = 
JoinSet::new(); - let mut rng = thread_rng(); + let mut rng = rng(); // Loop to generate datasets and its query for _ in 0..self.data_gen_rounds { @@ -192,7 +177,7 @@ impl AggregationFuzzer { let query_groups = datasets .into_iter() .map(|dataset| { - let sql_idx = rng.gen_range(0..self.candidate_sqls.len()); + let sql_idx = rng.random_range(0..self.candidate_sqls.len()); let sql = self.candidate_sqls[sql_idx].clone(); QueryGroup { dataset, sql } @@ -212,10 +197,7 @@ impl AggregationFuzzer { while let Some(join_handle) = join_set.join_next().await { // propagate errors join_handle.map_err(|e| { - DataFusionError::Internal(format!( - "AggregationFuzzer task error: {:?}", - e - )) + DataFusionError::Internal(format!("AggregationFuzzer task error: {e:?}")) })??; } Ok(()) @@ -371,217 +353,3 @@ fn format_batches_with_limit(batches: &[RecordBatch]) -> impl std::fmt::Display pretty_format_batches(&to_print).unwrap() } - -/// Random aggregate query builder -/// -/// Creates queries like -/// ```sql -/// SELECT AGG(..) FROM table_name GROUP BY -///``` -#[derive(Debug, Default, Clone)] -pub struct QueryBuilder { - /// The name of the table to query - table_name: String, - /// Aggregate functions to be used in the query - /// (function_name, is_distinct) - aggregate_functions: Vec<(String, bool)>, - /// Columns to be used in group by - group_by_columns: Vec, - /// Possible columns for arguments in the aggregate functions - /// - /// Assumes each - arguments: Vec, -} -impl QueryBuilder { - pub fn new() -> Self { - Default::default() - } - - /// return the table name if any - pub fn table_name(&self) -> &str { - &self.table_name - } - - /// Set the table name for the query builder - pub fn with_table_name(mut self, table_name: impl Into) -> Self { - self.table_name = table_name.into(); - self - } - - /// Add a new possible aggregate function to the query builder - pub fn with_aggregate_function( - mut self, - aggregate_function: impl Into, - ) -> Self { - self.aggregate_functions - .push((aggregate_function.into(), false)); - self - } - - /// Add a new possible `DISTINCT` aggregate function to the query - /// - /// This is different than `with_aggregate_function` because only certain - /// aggregates support `DISTINCT` - pub fn with_distinct_aggregate_function( - mut self, - aggregate_function: impl Into, - ) -> Self { - self.aggregate_functions - .push((aggregate_function.into(), true)); - self - } - - /// Set the columns to be used in the group bys clauses - pub fn set_group_by_columns<'a>( - mut self, - group_by: impl IntoIterator, - ) -> Self { - self.group_by_columns = group_by.into_iter().map(String::from).collect(); - self - } - - /// Add one or more columns to be used as an argument in the aggregate functions - pub fn with_aggregate_arguments<'a>( - mut self, - arguments: impl IntoIterator, - ) -> Self { - let arguments = arguments.into_iter().map(String::from); - self.arguments.extend(arguments); - self - } - - pub fn generate_query(&self) -> String { - let group_by = self.random_group_by(); - let mut query = String::from("SELECT "); - query.push_str(&group_by.join(", ")); - if !group_by.is_empty() { - query.push_str(", "); - } - query.push_str(&self.random_aggregate_functions(&group_by).join(", ")); - query.push_str(" FROM "); - query.push_str(&self.table_name); - if !group_by.is_empty() { - query.push_str(" GROUP BY "); - query.push_str(&group_by.join(", ")); - } - query - } - - /// Generate a some random aggregate function invocations (potentially repeating). 
- /// - /// Each aggregate function invocation is of the form - /// - /// ```sql - /// function_name( argument) as alias - /// ``` - /// - /// where - /// * `function_names` are randomly selected from [`Self::aggregate_functions`] - /// * ` argument` is randomly selected from [`Self::arguments`] - /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names) - fn random_aggregate_functions(&self, group_by_cols: &[String]) -> Vec { - const MAX_NUM_FUNCTIONS: usize = 5; - let mut rng = thread_rng(); - let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS); - - let mut alias_gen = 1; - - let mut aggregate_functions = vec![]; - - let mut order_by_black_list: HashSet = - group_by_cols.iter().cloned().collect(); - // remove one random col - if let Some(first) = order_by_black_list.iter().next().cloned() { - order_by_black_list.remove(&first); - } - - while aggregate_functions.len() < num_aggregate_functions { - let idx = rng.gen_range(0..self.aggregate_functions.len()); - let (function_name, is_distinct) = &self.aggregate_functions[idx]; - let argument = self.random_argument(); - let alias = format!("col{}", alias_gen); - let distinct = if *is_distinct { "DISTINCT " } else { "" }; - alias_gen += 1; - - let (order_by, null_opt) = if function_name.eq("first_value") - || function_name.eq("last_value") - { - ( - self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */ - self.null_opt(), - ) - } else { - ("".to_string(), "".to_string()) - }; - - let function = format!( - "{function_name}({distinct}{argument}{order_by}) {null_opt} as {alias}" - ); - aggregate_functions.push(function); - } - aggregate_functions - } - - /// Pick a random aggregate function argument - fn random_argument(&self) -> String { - let mut rng = thread_rng(); - let idx = rng.gen_range(0..self.arguments.len()); - self.arguments[idx].clone() - } - - fn order_by(&self, black_list: &HashSet) -> String { - let mut available_columns: Vec = self - .arguments - .iter() - .filter(|col| !black_list.contains(*col)) - .cloned() - .collect(); - - available_columns.shuffle(&mut thread_rng()); - - let num_of_order_by_col = 12; - let column_count = std::cmp::min(num_of_order_by_col, available_columns.len()); - - let selected_columns = &available_columns[0..column_count]; - - let mut rng = thread_rng(); - let mut result = String::from_str(" order by ").unwrap(); - for col in selected_columns { - let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" }; - result.push_str(&format!("{} {},", col, order)); - } - - result.strip_suffix(",").unwrap().to_string() - } - - fn null_opt(&self) -> String { - if thread_rng().gen_bool(0.5) { - "RESPECT NULLS".to_string() - } else { - "IGNORE NULLS".to_string() - } - } - - /// Pick a random number of fields to group by (non-repeating) - /// - /// Limited to 3 group by columns to ensure coverage for large groups. With - /// larger numbers of columns, each group has many fewer values. 
- fn random_group_by(&self) -> Vec { - let mut rng = thread_rng(); - const MAX_GROUPS: usize = 3; - let max_groups = self.group_by_columns.len().max(MAX_GROUPS); - let num_group_by = rng.gen_range(1..max_groups); - - let mut already_used = HashSet::new(); - let mut group_by = vec![]; - while group_by.len() < num_group_by - && already_used.len() != self.group_by_columns.len() - { - let idx = rng.gen_range(0..self.group_by_columns.len()); - if already_used.insert(idx) { - group_by.push(self.group_by_columns[idx].clone()); - } - } - group_by - } -} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs index bfb3bb096326..04b764e46a96 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs @@ -43,6 +43,7 @@ use datafusion_common::error::Result; mod context_generator; mod data_generator; mod fuzzer; +pub mod query_builder; pub use crate::fuzz_cases::record_batch_generator::ColumnDescr; pub use data_generator::DatasetGeneratorConfig; diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs new file mode 100644 index 000000000000..209278385b7b --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs @@ -0,0 +1,384 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{collections::HashSet, str::FromStr}; + +use rand::{rng, seq::SliceRandom, Rng}; + +/// Random aggregate query builder +/// +/// Creates queries like +/// ```sql +/// SELECT AGG(..) FROM table_name GROUP BY +///``` +#[derive(Debug, Default, Clone)] +pub struct QueryBuilder { + // =================================== + // Table settings + // =================================== + /// The name of the table to query + table_name: String, + + // =================================== + // Grouping settings + // =================================== + /// Columns to be used in randomly generate `groupings` + /// + /// # Example + /// + /// Columns: + /// + /// ```text + /// [a,b,c,d] + /// ``` + /// + /// And randomly generated `groupings` (at least 1 column) + /// can be: + /// + /// ```text + /// [a] + /// [a,b] + /// [a,b,d] + /// ... + /// ``` + /// + /// So the finally generated sqls will be: + /// + /// ```text + /// SELECT aggr FROM t GROUP BY a; + /// SELECT aggr FROM t GROUP BY a,b; + /// SELECT aggr FROM t GROUP BY a,b,d; + /// ... 
+ /// ``` + group_by_columns: Vec, + + /// Max columns num in randomly generated `groupings` + max_group_by_columns: usize, + + /// Min columns num in randomly generated `groupings` + min_group_by_columns: usize, + + /// The sort keys of dataset + /// + /// Due to optimizations will be triggered when all or some + /// grouping columns are the sort keys of dataset. + /// So it is necessary to randomly generate some `groupings` basing on + /// dataset sort keys for test coverage. + /// + /// # Example + /// + /// Dataset including columns [a,b,c], and sorted by [a,b] + /// + /// And we may generate sqls to try covering the sort-optimization cases like: + /// + /// ```text + /// SELECT aggr FROM t GROUP BY b; // no permutation case + /// SELECT aggr FROM t GROUP BY a,c; // partial permutation case + /// SELECT aggr FROM t GROUP BY a,b,c; // full permutation case + /// ... + /// ``` + /// + /// More details can see [`GroupOrdering`]. + /// + /// [`GroupOrdering`]: datafusion_physical_plan::aggregates::order::GroupOrdering + /// + dataset_sort_keys: Vec>, + + /// If we will also test the no grouping case like: + /// + /// ```text + /// SELECT aggr FROM t; + /// ``` + /// + no_grouping: bool, + + // ==================================== + // Aggregation function settings + // ==================================== + /// Aggregate functions to be used in the query + /// (function_name, is_distinct) + aggregate_functions: Vec<(String, bool)>, + + /// Possible columns for arguments in the aggregate functions + /// + /// Assumes each + arguments: Vec, +} + +impl QueryBuilder { + pub fn new() -> Self { + Self { + no_grouping: true, + max_group_by_columns: 5, + min_group_by_columns: 1, + ..Default::default() + } + } + + /// return the table name if any + pub fn table_name(&self) -> &str { + &self.table_name + } + + /// Set the table name for the query builder + pub fn with_table_name(mut self, table_name: impl Into) -> Self { + self.table_name = table_name.into(); + self + } + + /// Add a new possible aggregate function to the query builder + pub fn with_aggregate_function( + mut self, + aggregate_function: impl Into, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), false)); + self + } + + /// Add a new possible `DISTINCT` aggregate function to the query + /// + /// This is different than `with_aggregate_function` because only certain + /// aggregates support `DISTINCT` + pub fn with_distinct_aggregate_function( + mut self, + aggregate_function: impl Into, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), true)); + self + } + + /// Set the columns to be used in the group bys clauses + pub fn set_group_by_columns<'a>( + mut self, + group_by: impl IntoIterator, + ) -> Self { + self.group_by_columns = group_by.into_iter().map(String::from).collect(); + self + } + + /// Add one or more columns to be used as an argument in the aggregate functions + pub fn with_aggregate_arguments<'a>( + mut self, + arguments: impl IntoIterator, + ) -> Self { + let arguments = arguments.into_iter().map(String::from); + self.arguments.extend(arguments); + self + } + + /// Add max columns num in group by(default: 3), for example if it is set to 1, + /// the generated sql will group by at most 1 column + #[allow(dead_code)] + pub fn with_max_group_by_columns(mut self, max_group_by_columns: usize) -> Self { + self.max_group_by_columns = max_group_by_columns; + self + } + + #[allow(dead_code)] + pub fn with_min_group_by_columns(mut self, min_group_by_columns: usize) -> Self { + 
self.min_group_by_columns = min_group_by_columns; + self + } + + /// Add sort keys of dataset if any, then the builder will generate queries basing on it + /// to cover the sort-optimization cases + pub fn with_dataset_sort_keys(mut self, dataset_sort_keys: Vec>) -> Self { + self.dataset_sort_keys = dataset_sort_keys; + self + } + + /// Add if also test the no grouping aggregation case(default: true) + #[allow(dead_code)] + pub fn with_no_grouping(mut self, no_grouping: bool) -> Self { + self.no_grouping = no_grouping; + self + } + + pub fn generate_queries(mut self) -> Vec { + const NUM_QUERIES: usize = 3; + let mut sqls = Vec::new(); + + // Add several queries group on randomly picked columns + for _ in 0..NUM_QUERIES { + let sql = self.generate_query(); + sqls.push(sql); + } + + // Also add several queries limited to grouping on the group by + // dataset sorted columns only, if any. + // So if the data is sorted on `a,b` only group by `a,b` or`a` or `b`. + if !self.dataset_sort_keys.is_empty() { + let dataset_sort_keys = self.dataset_sort_keys.clone(); + for sort_keys in dataset_sort_keys { + let group_by_columns = sort_keys.iter().map(|s| s.as_str()); + self = self.set_group_by_columns(group_by_columns); + for _ in 0..NUM_QUERIES { + let sql = self.generate_query(); + sqls.push(sql); + } + } + } + + // Also add a query with no grouping + if self.no_grouping { + self = self.set_group_by_columns(vec![]); + let sql = self.generate_query(); + sqls.push(sql); + } + + sqls + } + + fn generate_query(&self) -> String { + let group_by = self.random_group_by(); + dbg!(&group_by); + let mut query = String::from("SELECT "); + query.push_str(&group_by.join(", ")); + if !group_by.is_empty() { + query.push_str(", "); + } + query.push_str(&self.random_aggregate_functions(&group_by).join(", ")); + query.push_str(" FROM "); + query.push_str(&self.table_name); + if !group_by.is_empty() { + query.push_str(" GROUP BY "); + query.push_str(&group_by.join(", ")); + } + query + } + + /// Generate a some random aggregate function invocations (potentially repeating). 
+ /// + /// Each aggregate function invocation is of the form + /// + /// ```sql + /// function_name( argument) as alias + /// ``` + /// + /// where + /// * `function_names` are randomly selected from [`Self::aggregate_functions`] + /// * ` argument` is randomly selected from [`Self::arguments`] + /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names) + fn random_aggregate_functions(&self, group_by_cols: &[String]) -> Vec { + const MAX_NUM_FUNCTIONS: usize = 5; + let mut rng = rng(); + let num_aggregate_functions = rng.random_range(1..=MAX_NUM_FUNCTIONS); + + let mut alias_gen = 1; + + let mut aggregate_functions = vec![]; + + let mut order_by_black_list: HashSet = + group_by_cols.iter().cloned().collect(); + // remove one random col + if let Some(first) = order_by_black_list.iter().next().cloned() { + order_by_black_list.remove(&first); + } + + while aggregate_functions.len() < num_aggregate_functions { + let idx = rng.random_range(0..self.aggregate_functions.len()); + let (function_name, is_distinct) = &self.aggregate_functions[idx]; + let argument = self.random_argument(); + let alias = format!("col{alias_gen}"); + let distinct = if *is_distinct { "DISTINCT " } else { "" }; + alias_gen += 1; + + let (order_by, null_opt) = if function_name.eq("first_value") + || function_name.eq("last_value") + { + ( + self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */ + self.null_opt(), + ) + } else { + ("".to_string(), "".to_string()) + }; + + let function = format!( + "{function_name}({distinct}{argument}{order_by}) {null_opt} as {alias}" + ); + aggregate_functions.push(function); + } + aggregate_functions + } + + /// Pick a random aggregate function argument + fn random_argument(&self) -> String { + let mut rng = rng(); + let idx = rng.random_range(0..self.arguments.len()); + self.arguments[idx].clone() + } + + fn order_by(&self, black_list: &HashSet) -> String { + let mut available_columns: Vec = self + .arguments + .iter() + .filter(|col| !black_list.contains(*col)) + .cloned() + .collect(); + + available_columns.shuffle(&mut rng()); + + let num_of_order_by_col = 12; + let column_count = std::cmp::min(num_of_order_by_col, available_columns.len()); + + let selected_columns = &available_columns[0..column_count]; + + let mut rng = rng(); + let mut result = String::from_str(" order by ").unwrap(); + for col in selected_columns { + let order = if rng.random_bool(0.5) { "ASC" } else { "DESC" }; + result.push_str(&format!("{col} {order},")); + } + + result.strip_suffix(",").unwrap().to_string() + } + + fn null_opt(&self) -> String { + if rng().random_bool(0.5) { + "RESPECT NULLS".to_string() + } else { + "IGNORE NULLS".to_string() + } + } + + /// Pick a random number of fields to group by (non-repeating) + /// + /// Limited to `max_group_by_columns` group by columns to ensure coverage for large groups. + /// With larger numbers of columns, each group has many fewer values. 
+ fn random_group_by(&self) -> Vec { + let mut rng = rng(); + let min_groups = self.min_group_by_columns; + let max_groups = self.max_group_by_columns; + assert!(min_groups <= max_groups); + let num_group_by = rng.random_range(min_groups..=max_groups); + + let mut already_used = HashSet::new(); + let mut group_by = vec![]; + while group_by.len() < num_group_by + && already_used.len() != self.group_by_columns.len() + { + let idx = rng.random_range(0..self.group_by_columns.len()); + if already_used.insert(idx) { + group_by.push(self.group_by_columns[idx].clone()); + } + } + group_by + } +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs index 769deef1187d..d12d0a130c0c 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -69,16 +69,14 @@ fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { table_data_with_properties.clone(), )?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties {}", - requirement, expected, eq_properties + "Error in test case requirement:{requirement:?}, expected: {expected:?}, eq_properties {eq_properties}" ); // Check whether ordering_satisfy API result and // experimental result matches. assert_eq!( eq_properties.ordering_satisfy(requirement.as_ref()), expected, - "{}", - err_msg + "{err_msg}" ); } } @@ -141,8 +139,7 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { table_data_with_properties.clone(), )?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties: {}", - requirement, expected, eq_properties, + "Error in test case requirement:{requirement:?}, expected: {expected:?}, eq_properties: {eq_properties}", ); // Check whether ordering_satisfy API result and // experimental result matches. @@ -150,8 +147,7 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { assert_eq!( eq_properties.ordering_satisfy(requirement.as_ref()), (expected | false), - "{}", - err_msg + "{err_msg}" ); } } diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs index a3fa1157b38f..38e66387a02c 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -82,8 +82,7 @@ fn project_orderings_random() -> Result<()> { // Make sure each ordering after projection is valid. for ordering in projected_eq.oeq_class().iter() { let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties {}, proj_exprs: {:?}", - ordering, eq_properties, proj_exprs, + "Error in test case ordering:{ordering:?}, eq_properties {eq_properties}, proj_exprs: {proj_exprs:?}", ); // Since ordered section satisfies schema, we expect // that result will be same after sort (e.g sort was unnecessary). 
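The fuzz-test hunks above and below apply two mechanical migrations that recur throughout this patch: the rand 0.9 renames (`thread_rng()`, `gen_range`, and `gen_bool` become `rng()`, `random_range`, and `random_bool`) and Rust's inline format-args style. A minimal before/after sketch of those patterns, separate from the patch itself and with illustrative variable names only:

```rust
// Illustrative sketch of the two migrations applied in the surrounding hunks;
// not part of the patch. Assumes the rand 0.9 crate is a dependency.
use rand::{rng, Rng};

fn sketch() {
    // rand 0.8 style (removed by this patch):
    //   let mut r = rand::thread_rng();
    //   let n = r.gen_range(0..10);
    let mut r = rng(); // rand 0.9 replacement for thread_rng()
    let n = r.random_range(0..10); // rand 0.9 replacement for gen_range()

    // Old positional format args: format!("got: {:?}", n)
    // New inline (captured-identifier) format args:
    let err_msg = format!("got: {n:?}");
    assert!(!err_msg.is_empty());
}
```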
@@ -178,16 +177,14 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { projected_batch.clone(), )?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties: {}, projected_eq: {}, projection_mapping: {:?}", - requirement, expected, eq_properties, projected_eq, projection_mapping + "Error in test case requirement:{requirement:?}, expected: {expected:?}, eq_properties: {eq_properties}, projected_eq: {projected_eq}, projection_mapping: {projection_mapping:?}" ); // Check whether ordering_satisfy API result and // experimental result matches. assert_eq!( projected_eq.ordering_satisfy(requirement.as_ref()), expected, - "{}", - err_msg + "{err_msg}" ); } } diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs index 593e1c6c2dca..9a2146415749 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -84,10 +84,9 @@ fn test_find_longest_permutation_random() -> Result<()> { ); let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties: {}", - ordering, eq_properties + "Error in test case ordering:{ordering:?}, eq_properties: {eq_properties}" ); - assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + assert_eq!(ordering.len(), indices.len(), "{err_msg}"); // Since ordered section satisfies schema, we expect // that result will be same after sort (e.g sort was unnecessary). assert!( diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index d4b41b686631..a906648f872d 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -114,7 +114,7 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti }; while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + let n_sort_expr = rng.random_range(0..remaining_exprs.len() + 1); remaining_exprs.shuffle(&mut rng); let ordering = remaining_exprs @@ -369,7 +369,7 @@ pub fn generate_table_for_eq_properties( // Utility closure to generate random array let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .map(|_| rng.random_range(0..max_val) as f64 / 2.0) .collect(); Arc::new(Float64Array::from_iter_values(values)) }; @@ -524,7 +524,7 @@ fn generate_random_f64_array( rng: &mut StdRng, ) -> ArrayRef { let values: Vec = (0..n_elems) - .map(|_| rng.gen_range(0..n_distinct) as f64 / 2.0) + .map(|_| rng.random_range(0..n_distinct) as f64 / 2.0) .collect(); Arc::new(Float64Array::from_iter_values(values)) } diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index da93dd5edf29..82ee73b525cb 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -186,7 +186,7 @@ async fn test_full_join_1k_filtered() { } #[tokio::test] -async fn test_semi_join_1k() { +async fn test_left_semi_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), @@ -198,7 +198,7 @@ async fn test_semi_join_1k() { } #[tokio::test] -async fn test_semi_join_1k_filtered() { +async fn test_left_semi_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), @@ -209,6 
+209,30 @@ async fn test_semi_join_1k_filtered() { .await } +#[tokio::test] +async fn test_right_semi_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_semi_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + #[tokio::test] async fn test_left_anti_join_1k() { JoinFuzzTestCase::new( @@ -545,7 +569,7 @@ impl JoinFuzzTestCase { std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); std::fs::create_dir_all(fuzz_debug).unwrap(); let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}"); - println!("Test result data mismatch found. HJ rows {}, SMJ rows {}, NLJ rows {}", hj_rows, smj_rows, nlj_rows); + println!("Test result data mismatch found. HJ rows {hj_rows}, SMJ rows {smj_rows}, NLJ rows {nlj_rows}"); println!("The debug is ON. Input data will be saved to {out_dir_name}"); Self::save_partitioned_batches_as_parquet( @@ -561,9 +585,9 @@ impl JoinFuzzTestCase { if join_tests.contains(&NljHj) && nlj_rows != hj_rows { println!("=============== HashJoinExec =================="); - hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + hj_formatted_sorted.iter().for_each(|s| println!("{s}")); println!("=============== NestedLoopJoinExec =================="); - nlj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + nlj_formatted_sorted.iter().for_each(|s| println!("{s}")); Self::save_partitioned_batches_as_parquet( &nlj_collected, @@ -579,9 +603,9 @@ impl JoinFuzzTestCase { if join_tests.contains(&HjSmj) && smj_rows != hj_rows { println!("=============== HashJoinExec =================="); - hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + hj_formatted_sorted.iter().for_each(|s| println!("{s}")); println!("=============== SortMergeJoinExec =================="); - smj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + smj_formatted_sorted.iter().for_each(|s| println!("{s}")); Self::save_partitioned_batches_as_parquet( &hj_collected, @@ -597,10 +621,10 @@ impl JoinFuzzTestCase { } if join_tests.contains(&NljHj) { - let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size); + let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {batch_size}"); assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); - let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}", batch_size); + let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {batch_size}"); // row level compare if any of joins returns the result // the reason is different formatting when there is no rows for (i, (nlj_line, hj_line)) in nlj_formatted_sorted @@ -671,7 +695,7 @@ impl JoinFuzzTestCase { std::fs::create_dir_all(out_path).unwrap(); input.iter().enumerate().for_each(|(idx, batch)| { - let file_path = format!("{out_path}/file_{}.parquet", idx); + let file_path = format!("{out_path}/file_{idx}.parquet"); let mut file = std::fs::File::create(&file_path).unwrap(); println!( "{}: Saving batch idx {} rows {} to parquet {}", @@ -722,11 +746,9 @@ impl JoinFuzzTestCase { path.to_str().unwrap(), 
datafusion::prelude::ParquetReadOptions::default(), ) - .await - .unwrap() + .await? .collect() - .await - .unwrap(); + .await?; batches.append(&mut batch); } @@ -739,13 +761,13 @@ impl JoinFuzzTestCase { /// two sorted int32 columns 'a', 'b' ranged from 0..99 as join columns /// two random int32 columns 'x', 'y' as other columns fn make_staggered_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut input12: Vec<(i32, i32)> = vec![(0, 0); len]; let mut input3: Vec = vec![0; len]; let mut input4: Vec = vec![0; len]; input12 .iter_mut() - .for_each(|v| *v = (rng.gen_range(0..100), rng.gen_range(0..100))); + .for_each(|v| *v = (rng.random_range(0..100), rng.random_range(0..100))); rng.fill(&mut input3[..]); rng.fill(&mut input4[..]); input12.sort_unstable(); diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 987a732eb294..4c5ebf040241 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -24,7 +24,7 @@ use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_common::assert_contains; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use std::sync::Arc; use test_utils::stagger_batch; @@ -54,11 +54,11 @@ async fn run_limit_fuzz_test(make_data: F) where F: Fn(usize) -> SortedData, { - let mut rng = thread_rng(); + let mut rng = rng(); for size in [10, 1_0000, 10_000, 100_000] { let data = make_data(size); // test various limits including some random ones - for limit in [1, 3, 7, 17, 10000, rng.gen_range(1..size * 2)] { + for limit in [1, 3, 7, 17, 10000, rng.random_range(1..size * 2)] { // limit can be larger than the number of rows in the input run_limit_test(limit, &data).await; } @@ -97,13 +97,13 @@ impl SortedData { /// Create an i32 column of random values, with the specified number of /// rows, sorted the default fn new_i32(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); // have some repeats (approximately 1/3 of the values are the same) let max = size as i32 / 3; let data: Vec> = (0..size) .map(|_| { // no nulls for now - Some(rng.gen_range(0..max)) + Some(rng.random_range(0..max)) }) .collect(); @@ -118,17 +118,17 @@ impl SortedData { /// Create an f64 column of random values, with the specified number of /// rows, sorted the default fn new_f64(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); let mut data: Vec> = (0..size / 3) .map(|_| { // no nulls for now - Some(rng.gen_range(0.0..1.0f64)) + Some(rng.random_range(0.0..1.0f64)) }) .collect(); // have some repeats (approximately 1/3 of the values are the same) while data.len() < size { - data.push(data[rng.gen_range(0..data.len())]); + data.push(data[rng.random_range(0..data.len())]); } let batches = stagger_batch(f64_batch(data.iter().cloned())); @@ -142,7 +142,7 @@ impl SortedData { /// Create an string column of random values, with the specified number of /// rows, sorted the default fn new_str(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); let mut data: Vec> = (0..size / 3) .map(|_| { // no nulls for now @@ -152,7 +152,7 @@ impl SortedData { // have some repeats (approximately 1/3 of the values are the same) while data.len() < size { - data.push(data[rng.gen_range(0..data.len())].clone()); + data.push(data[rng.random_range(0..data.len())].clone()); } let batches = 
stagger_batch(string_batch(data.iter())); @@ -166,7 +166,7 @@ impl SortedData { /// Create two columns of random values (int64, string), with the specified number of /// rows, sorted the default fn new_i64str(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); // 100 distinct values let strings: Vec> = (0..100) @@ -180,8 +180,8 @@ impl SortedData { let data = (0..size) .map(|_| { ( - Some(rng.gen_range(0..10)), - strings[rng.gen_range(0..strings.len())].clone(), + Some(rng.random_range(0..10)), + strings[rng.random_range(0..strings.len())].clone(), ) }) .collect::>(); @@ -340,8 +340,8 @@ async fn run_limit_test(fetch: usize, data: &SortedData) { /// Return random ASCII String with len fn get_random_string(len: usize) -> String { - thread_rng() - .sample_iter(rand::distributions::Alphanumeric) + rng() + .sample_iter(rand::distr::Alphanumeric) .take(len) .map(char::from) .collect() diff --git a/datafusion/core/tests/fuzz_cases/pruning.rs b/datafusion/core/tests/fuzz_cases/pruning.rs index 11dd961a54ee..eef4e8d8856f 100644 --- a/datafusion/core/tests/fuzz_cases/pruning.rs +++ b/datafusion/core/tests/fuzz_cases/pruning.rs @@ -90,42 +90,42 @@ async fn test_utf8_not_like() { #[tokio::test] async fn test_utf8_like_prefix() { - Utf8Test::new(|value| col("a").like(lit(format!("%{}", value)))) + Utf8Test::new(|value| col("a").like(lit(format!("%{value}")))) .run() .await; } #[tokio::test] async fn test_utf8_like_suffix() { - Utf8Test::new(|value| col("a").like(lit(format!("{}%", value)))) + Utf8Test::new(|value| col("a").like(lit(format!("{value}%")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_prefix() { - Utf8Test::new(|value| col("a").not_like(lit(format!("%{}", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("%{value}")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_ecsape() { - Utf8Test::new(|value| col("a").not_like(lit(format!("\\%{}%", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("\\%{value}%")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_suffix() { - Utf8Test::new(|value| col("a").not_like(lit(format!("{}%", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("{value}%")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_suffix_one() { - Utf8Test::new(|value| col("a").not_like(lit(format!("{}_", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("{value}_")))) .run() .await; } @@ -226,7 +226,7 @@ impl Utf8Test { return (*files).clone(); } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let values = Self::values(); let mut row_groups = vec![]; @@ -345,7 +345,7 @@ async fn write_parquet_file( /// The string values for [Utf8Test::values] static VALUES: LazyLock> = LazyLock::new(|| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let characters = [ "z", diff --git a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs index 9a62a6397d82..4eac1482ad3f 100644 --- a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs +++ b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs @@ -17,13 +17,15 @@ use std::sync::Arc; -use arrow::array::{ArrayRef, RecordBatch}; +use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, RecordBatch}; use arrow::datatypes::{ - BooleanType, DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - IntervalDayTimeType, 
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, - Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, - Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal128Type, + Decimal256Type, DurationMicrosecondType, DurationMillisecondType, + DurationNanosecondType, DurationSecondType, Field, Float32Type, Float64Type, + Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Schema, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; @@ -32,7 +34,7 @@ use arrow_schema::{ DECIMAL256_MAX_SCALE, }; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; -use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng}; +use rand::{rng, rngs::StdRng, Rng, SeedableRng}; use test_utils::array_gen::{ BinaryArrayGenerator, BooleanArrayGenerator, DecimalArrayGenerator, PrimitiveArrayGenerator, StringArrayGenerator, @@ -85,16 +87,33 @@ pub fn get_supported_types_columns(rng_seed: u64) -> Vec { "interval_month_day_nano", DataType::Interval(IntervalUnit::MonthDayNano), ), + // Internal error: AggregationFuzzer task error: JoinError::Panic(Id(29108), "called `Option::unwrap()` on a `None` value", ...). + // ColumnDescr::new( + // "duration_seconds", + // DataType::Duration(TimeUnit::Second), + // ), + ColumnDescr::new( + "duration_milliseconds", + DataType::Duration(TimeUnit::Millisecond), + ), + ColumnDescr::new( + "duration_microsecond", + DataType::Duration(TimeUnit::Microsecond), + ), + ColumnDescr::new( + "duration_nanosecond", + DataType::Duration(TimeUnit::Nanosecond), + ), ColumnDescr::new("decimal128", { - let precision: u8 = rng.gen_range(1..=DECIMAL128_MAX_PRECISION); - let scale: i8 = rng.gen_range( + let precision: u8 = rng.random_range(1..=DECIMAL128_MAX_PRECISION); + let scale: i8 = rng.random_range( i8::MIN..=std::cmp::min(precision as i8, DECIMAL128_MAX_SCALE), ); DataType::Decimal128(precision, scale) }), ColumnDescr::new("decimal256", { - let precision: u8 = rng.gen_range(1..=DECIMAL256_MAX_PRECISION); - let scale: i8 = rng.gen_range( + let precision: u8 = rng.random_range(1..=DECIMAL256_MAX_PRECISION); + let scale: i8 = rng.random_range( i8::MIN..=std::cmp::min(precision as i8, DECIMAL256_MAX_SCALE), ); DataType::Decimal256(precision, scale) @@ -108,6 +127,11 @@ pub fn get_supported_types_columns(rng_seed: u64) -> Vec { ColumnDescr::new("binary", DataType::Binary), ColumnDescr::new("large_binary", DataType::LargeBinary), ColumnDescr::new("binaryview", DataType::BinaryView), + ColumnDescr::new( + "dictionary_utf8_low", + DataType::Dictionary(Box::new(DataType::UInt64), Box::new(DataType::Utf8)), + ) + .with_max_num_distinct(10), ] } @@ -161,22 +185,19 @@ pub struct RecordBatchGenerator { /// If a seed is provided when constructing the generator, it will be used to /// create `rng` and the pseudo-randomly generated batches will be deterministic. - /// Otherwise, `rng` will be initialized using `thread_rng()` and the batches + /// Otherwise, `rng` will be initialized using `rng()` and the batches /// generated will be different each time. rng: StdRng, } macro_rules! 
generate_decimal_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT: expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $PRECISION: ident, $SCALE: ident, $ARROW_TYPE: ident) => {{ - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT: expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $PRECISION: ident, $SCALE: ident, $ARROW_TYPE: ident) => {{ let mut generator = DecimalArrayGenerator { precision: $PRECISION, scale: $SCALE, num_decimals: $NUM_ROWS, num_distinct_decimals: $MAX_NUM_DISTINCT, - null_pct, + null_pct: $NULL_PCT, rng: $ARRAY_GEN_RNG, }; @@ -186,17 +207,13 @@ macro_rules! generate_decimal_array { // Generating `BooleanArray` due to it being a special type in Arrow (bit-packed) macro_rules! generate_boolean_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ - // Select a null percentage from the candidate percentages - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ let num_distinct_booleans = if $MAX_NUM_DISTINCT >= 2 { 2 } else { 1 }; let mut generator = BooleanArrayGenerator { num_booleans: $NUM_ROWS, num_distinct_booleans, - null_pct, + null_pct: $NULL_PCT, rng: $ARRAY_GEN_RNG, }; @@ -205,14 +222,11 @@ macro_rules! generate_boolean_array { } macro_rules! generate_primitive_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => {{ - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => {{ let mut generator = PrimitiveArrayGenerator { num_primitives: $NUM_ROWS, num_distinct_primitives: $MAX_NUM_DISTINCT, - null_pct, + null_pct: $NULL_PCT, rng: $ARRAY_GEN_RNG, }; @@ -220,6 +234,28 @@ macro_rules! generate_primitive_array { }}; } +macro_rules! generate_dict { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident, $VALUES: ident) => {{ + debug_assert_eq!($VALUES.len(), $MAX_NUM_DISTINCT); + let keys: PrimitiveArray<$ARROW_TYPE> = (0..$NUM_ROWS) + .map(|_| { + if $BATCH_GEN_RNG.random::() < $NULL_PCT { + None + } else if $MAX_NUM_DISTINCT > 1 { + let range = 0..($MAX_NUM_DISTINCT + as <$ARROW_TYPE as ArrowPrimitiveType>::Native); + Some($ARRAY_GEN_RNG.random_range(range)) + } else { + Some(0) + } + }) + .collect(); + + let dict = DictionaryArray::new(keys, $VALUES); + Arc::new(dict) as ArrayRef + }}; +} + impl RecordBatchGenerator { /// Create a new `RecordBatchGenerator` with a random seed. The generated /// batches will be different each time. 
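Two self-contained sketches (not part of the patch) to make the behaviour described above concrete. First, the seeding contract: two `StdRng` instances built from the same seed draw identical values, which is why a seeded `RecordBatchGenerator` can reproduce the same batches (rand 0.9 API, matching the migration in this diff):

```rust
use rand::{rngs::StdRng, Rng, SeedableRng};

fn main() {
    let mut a = StdRng::seed_from_u64(42);
    let mut b = StdRng::seed_from_u64(42);
    let xs: Vec<u32> = (0..5).map(|_| a.random_range(0..100)).collect();
    let ys: Vec<u32> = (0..5).map(|_| b.random_range(0..100)).collect();
    assert_eq!(xs, ys); // same seed, same sequence
}
```

Second, the shape of data the new `generate_dict!` macro builds: a short values array holding only the distinct entries, referenced by a longer keys array, with nulls encoded at the key level (an arrow `DictionaryArray`, as used for the `dictionary_utf8_low` column added above):

```rust
use std::sync::Arc;
use arrow::array::{ArrayRef, DictionaryArray, StringArray, UInt64Array};

fn main() {
    // 3 distinct values reused by 6 keys; a None key encodes a null row.
    let values: ArrayRef = Arc::new(StringArray::from(vec!["low", "mid", "high"]));
    let keys = UInt64Array::from(vec![Some(0), Some(2), None, Some(1), Some(2), Some(0)]);
    let dict = DictionaryArray::new(keys, values);
    assert_eq!(dict.len(), 6);
    assert_eq!(dict.values().len(), 3);
}
```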
@@ -235,7 +271,7 @@ impl RecordBatchGenerator { max_rows_num, columns, candidate_null_pcts, - rng: StdRng::from_rng(thread_rng()).unwrap(), + rng: StdRng::from_rng(&mut rng()), } } @@ -247,9 +283,9 @@ impl RecordBatchGenerator { } pub fn generate(&mut self) -> Result { - let num_rows = self.rng.gen_range(self.min_rows_num..=self.max_rows_num); - let array_gen_rng = StdRng::from_seed(self.rng.gen()); - let mut batch_gen_rng = StdRng::from_seed(self.rng.gen()); + let num_rows = self.rng.random_range(self.min_rows_num..=self.max_rows_num); + let array_gen_rng = StdRng::from_seed(self.rng.random()); + let mut batch_gen_rng = StdRng::from_seed(self.rng.random()); let columns = self.columns.clone(); // Build arrays @@ -281,9 +317,28 @@ impl RecordBatchGenerator { num_rows: usize, batch_gen_rng: &mut StdRng, array_gen_rng: StdRng, + ) -> ArrayRef { + let null_pct_idx = batch_gen_rng.random_range(0..self.candidate_null_pcts.len()); + let null_pct = self.candidate_null_pcts[null_pct_idx]; + + Self::generate_array_of_type_inner( + col, + num_rows, + batch_gen_rng, + array_gen_rng, + null_pct, + ) + } + + fn generate_array_of_type_inner( + col: &ColumnDescr, + num_rows: usize, + batch_gen_rng: &mut StdRng, + array_gen_rng: StdRng, + null_pct: f64, ) -> ArrayRef { let num_distinct = if num_rows > 1 { - batch_gen_rng.gen_range(1..num_rows) + batch_gen_rng.random_range(1..num_rows) } else { num_rows }; @@ -299,6 +354,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int8Type @@ -309,6 +365,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int16Type @@ -319,6 +376,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int32Type @@ -329,6 +387,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int64Type @@ -339,6 +398,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt8Type @@ -349,6 +409,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt16Type @@ -359,6 +420,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt32Type @@ -369,6 +431,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt64Type @@ -379,6 +442,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Float32Type @@ -389,6 +453,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Float64Type @@ -399,6 +464,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Date32Type @@ -409,6 +475,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Date64Type @@ -419,6 +486,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Time32SecondType @@ -429,6 +497,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Time32MillisecondType @@ -439,6 +508,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Time64MicrosecondType @@ -449,6 +519,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, 
batch_gen_rng, array_gen_rng, Time64NanosecondType @@ -459,6 +530,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, IntervalYearMonthType @@ -469,6 +541,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, IntervalDayTimeType @@ -479,16 +552,62 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, IntervalMonthDayNanoType ) } + DataType::Duration(TimeUnit::Second) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationSecondType + ) + } + DataType::Duration(TimeUnit::Millisecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationMillisecondType + ) + } + DataType::Duration(TimeUnit::Microsecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationMicrosecondType + ) + } + DataType::Duration(TimeUnit::Nanosecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationNanosecondType + ) + } DataType::Timestamp(TimeUnit::Second, None) => { generate_primitive_array!( self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampSecondType @@ -499,6 +618,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampMillisecondType @@ -509,6 +629,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampMicrosecondType @@ -519,16 +640,14 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampNanosecondType ) } DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { - let null_pct_idx = - batch_gen_rng.gen_range(0..self.candidate_null_pcts.len()); - let null_pct = self.candidate_null_pcts[null_pct_idx]; - let max_len = batch_gen_rng.gen_range(1..50); + let max_len = batch_gen_rng.random_range(1..50); let mut generator = StringArrayGenerator { max_len, @@ -546,10 +665,7 @@ impl RecordBatchGenerator { } } DataType::Binary | DataType::LargeBinary | DataType::BinaryView => { - let null_pct_idx = - batch_gen_rng.gen_range(0..self.candidate_null_pcts.len()); - let null_pct = self.candidate_null_pcts[null_pct_idx]; - let max_len = batch_gen_rng.gen_range(1..100); + let max_len = batch_gen_rng.random_range(1..100); let mut generator = BinaryArrayGenerator { max_len, @@ -571,6 +687,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, precision, @@ -583,6 +700,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, precision, @@ -595,11 +713,43 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, BooleanType } } + DataType::Dictionary(ref key_type, ref value_type) + if key_type.is_dictionary_key_type() => + { + // We generate just num_distinct values because they will be reused by different keys + let mut array_gen_rng = array_gen_rng; + + let values = Self::generate_array_of_type_inner( + &ColumnDescr::new("values", *value_type.clone()), + num_distinct, + batch_gen_rng, + array_gen_rng.clone(), + // Once https://github.com/apache/datafusion/issues/16228 is fixed + // we can 
also generate nulls in values + 0.0, // null values are generated on the key level + ); + + match key_type.as_ref() { + // new key types can be added here + DataType::UInt64 => generate_dict!( + self, + num_rows, + num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + UInt64Type, + values + ), + _ => panic!("Invalid dictionary keys type: {key_type}"), + } + } _ => { panic!("Unsupported data generator type: {}", col.column_type) } @@ -636,8 +786,8 @@ mod tests { let batch1 = gen1.generate().unwrap(); let batch2 = gen2.generate().unwrap(); - let batch1_formatted = format!("{:?}", batch1); - let batch2_formatted = format!("{:?}", batch2); + let batch1_formatted = format!("{batch1:?}"); + let batch2_formatted = format!("{batch2:?}"); assert_eq!(batch1_formatted, batch2_formatted); } diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index 0b0f0aa2f105..703b8715821a 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -248,7 +248,7 @@ impl SortTest { let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap(); let sort = Arc::new(SortExec::new(sort_ordering, exec)); - let session_config = SessionConfig::new(); + let session_config = SessionConfig::new().with_repartition_file_scans(false); let session_ctx = if let Some(pool_size) = self.pool_size { // Make sure there is enough space for the initial spill // reservation @@ -298,20 +298,20 @@ impl SortTest { /// Return randomly sized record batches in a field named 'x' of type `Int32` /// with randomized i32 content fn make_staggered_i32_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let max_batch = 1024; let mut batches = vec![]; let mut remaining = len; while remaining != 0 { - let to_read = rng.gen_range(0..=remaining.min(max_batch)); + let to_read = rng.random_range(0..=remaining.min(max_batch)); remaining -= to_read; batches.push( RecordBatch::try_from_iter(vec![( "x", Arc::new(Int32Array::from_iter_values( - (0..to_read).map(|_| rng.gen()), + (0..to_read).map(|_| rng.random()), )) as ArrayRef, )]) .unwrap(), @@ -323,20 +323,20 @@ fn make_staggered_i32_batches(len: usize) -> Vec { /// Return randomly sized record batches in a field named 'x' of type `Utf8` /// with randomized content fn make_staggered_utf8_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let max_batch = 1024; let mut batches = vec![]; let mut remaining = len; while remaining != 0 { - let to_read = rng.gen_range(0..=remaining.min(max_batch)); + let to_read = rng.random_range(0..=remaining.min(max_batch)); remaining -= to_read; batches.push( RecordBatch::try_from_iter(vec![( "x", Arc::new(StringArray::from_iter_values( - (0..to_read).map(|_| format!("test_string_{}", rng.gen::())), + (0..to_read).map(|_| format!("test_string_{}", rng.random::())), )) as ArrayRef, )]) .unwrap(), @@ -349,13 +349,13 @@ fn make_staggered_utf8_batches(len: usize) -> Vec { /// with randomized i32 content and a field named 'y' of type `Utf8` /// with randomized content fn make_staggered_i32_utf8_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let max_batch = 1024; let mut batches = vec![]; let mut remaining = len; while remaining != 0 { - let to_read = rng.gen_range(0..=remaining.min(max_batch)); + let to_read = rng.random_range(0..=remaining.min(max_batch)); remaining -= to_read; batches.push( @@ -363,13 +363,14 @@ fn 
make_staggered_i32_utf8_batches(len: usize) -> Vec { ( "x", Arc::new(Int32Array::from_iter_values( - (0..to_read).map(|_| rng.gen()), + (0..to_read).map(|_| rng.random()), )) as ArrayRef, ), ( "y", Arc::new(StringArray::from_iter_values( - (0..to_read).map(|_| format!("test_string_{}", rng.gen::())), + (0..to_read) + .map(|_| format!("test_string_{}", rng.random::())), )) as ArrayRef, ), ]) diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 06b93d41af36..cf6867758edc 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -94,7 +94,7 @@ mod sp_repartition_fuzz_tests { }; while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + let n_sort_expr = rng.random_range(0..remaining_exprs.len() + 1); remaining_exprs.shuffle(&mut rng); let ordering = remaining_exprs @@ -144,7 +144,7 @@ mod sp_repartition_fuzz_tests { // Utility closure to generate random array let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as u64) + .map(|_| rng.random_range(0..max_val) as u64) .collect(); Arc::new(UInt64Array::from_iter_values(values)) }; @@ -261,7 +261,7 @@ mod sp_repartition_fuzz_tests { let res = concat_batches(&res[0].schema(), &res)?; for ordering in eq_properties.oeq_class().iter() { - let err_msg = format!("error in eq properties: {:?}", eq_properties); + let err_msg = format!("error in eq properties: {eq_properties:?}"); let sort_columns = ordering .iter() .map(|sort_expr| sort_expr.evaluate_to_sort_column(&res)) @@ -273,7 +273,7 @@ mod sp_repartition_fuzz_tests { let sorted_columns = lexsort(&sort_columns, None)?; // Make sure after merging ordering is still valid. 
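As an aside on the "sorting should be a no-op" assertions used by these fuzz tests: the snippet below is a minimal standalone sketch (not part of the patch) of the same idea using arrow's `lexsort`, assuming a single already-ordered Int32 column; the real tests build their sort columns from the table's equivalence properties.

```rust
use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array};
use arrow::compute::{lexsort, SortColumn, SortOptions};
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 2, 5, 9]));
    let sort_column = SortColumn {
        values: Arc::clone(&values),
        options: Some(SortOptions { descending: false, nulls_first: true }),
    };
    // If the data already satisfies the ordering, sorting returns it unchanged.
    let sorted = lexsort(&[sort_column], None)?;
    assert_eq!(sorted[0].as_ref(), values.as_ref());
    Ok(())
}
```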
- assert_eq!(orig_columns.len(), sorted_columns.len(), "{}", err_msg); + assert_eq!(orig_columns.len(), sorted_columns.len(), "{err_msg}"); assert!( izip!(orig_columns.into_iter(), sorted_columns.into_iter()) .all(|(lhs, rhs)| { lhs == rhs }), @@ -447,9 +447,9 @@ mod sp_repartition_fuzz_tests { let mut input123: Vec<(i64, i64, i64)> = vec![(0, 0, 0); len]; input123.iter_mut().for_each(|v| { *v = ( - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, ) }); input123.sort(); @@ -471,7 +471,7 @@ mod sp_repartition_fuzz_tests { let mut batches = vec![]; if STREAM { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..50); + let batch_size = rng.random_range(0..50); if remainder.num_rows() < batch_size { break; } @@ -481,7 +481,7 @@ mod sp_repartition_fuzz_tests { } } else { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + let batch_size = rng.random_range(0..remainder.num_rows() + 1); batches.push(remainder.slice(0, batch_size)); remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); diff --git a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs index 1319d4817326..d2d3a5e0c22f 100644 --- a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs @@ -25,18 +25,16 @@ use arrow_schema::SchemaRef; use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{instant::Instant, Result}; +use datafusion_execution::disk_manager::DiskManagerBuilder; use datafusion_execution::memory_pool::{ human_readable_size, MemoryPool, UnboundedMemoryPool, }; use datafusion_expr::display_schema; use datafusion_physical_plan::spill::get_record_batch_memory_size; -use rand::seq::SliceRandom; use std::time::Duration; -use datafusion_execution::{ - disk_manager::DiskManagerConfig, memory_pool::FairSpillPool, - runtime_env::RuntimeEnvBuilder, -}; +use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder}; +use rand::prelude::IndexedRandom; use rand::Rng; use rand::{rngs::StdRng, SeedableRng}; @@ -199,25 +197,24 @@ impl SortQueryFuzzer { // Execute until either`max_rounds` or `time_limit` is reached let max_rounds = self.max_rounds.unwrap_or(usize::MAX); for round in 0..max_rounds { - let init_seed = self.runner_rng.gen(); + let init_seed = self.runner_rng.random(); for query_i in 0..self.queries_per_round { - let query_seed = self.runner_rng.gen(); + let query_seed = self.runner_rng.random(); let mut expected_results: Option> = None; // use first config's result as the expected result for config_i in 0..self.config_variations_per_query { if self.should_stop_due_to_time_limit(start_time, round, query_i) { return Ok(()); } - let config_seed = self.runner_rng.gen(); + let config_seed = self.runner_rng.random(); println!( - "[SortQueryFuzzer] Round {}, Query {} (Config {})", - round, query_i, config_i + "[SortQueryFuzzer] Round {round}, Query {query_i} (Config {config_i})" ); println!(" Seeds:"); - println!(" init_seed = {}", init_seed); - println!(" query_seed = {}", query_seed); - println!(" config_seed = {}", config_seed); + println!(" init_seed = {init_seed}"); + println!(" query_seed = {query_seed}"); + println!(" config_seed = {config_seed}"); let results 
= self .test_gen @@ -300,7 +297,7 @@ impl SortFuzzerTestGenerator { let mut rng = StdRng::seed_from_u64(rng_seed); let min_ncol = min(candidate_columns.len(), 5); let max_ncol = min(candidate_columns.len(), 10); - let amount = rng.gen_range(min_ncol..=max_ncol); + let amount = rng.random_range(min_ncol..=max_ncol); let selected_columns = candidate_columns .choose_multiple(&mut rng, amount) .cloned() @@ -327,7 +324,7 @@ impl SortFuzzerTestGenerator { /// memory table should be generated with more partitions, due to https://github.com/apache/datafusion/issues/15088 fn init_partitioned_staggered_batches(&mut self, rng_seed: u64) { let mut rng = StdRng::seed_from_u64(rng_seed); - let num_partitions = rng.gen_range(1..=self.max_partitions); + let num_partitions = rng.random_range(1..=self.max_partitions); let max_batch_size = self.num_rows / num_partitions / 50; let target_partition_size = self.num_rows / num_partitions; @@ -344,7 +341,7 @@ impl SortFuzzerTestGenerator { // Generate a random batch of size between 1 and max_batch_size // Let edge case (1-row batch) more common - let (min_nrow, max_nrow) = if rng.gen_bool(0.1) { + let (min_nrow, max_nrow) = if rng.random_bool(0.1) { (1, 3) } else { (1, max_batch_size) @@ -355,7 +352,7 @@ impl SortFuzzerTestGenerator { max_nrow, self.selected_columns.clone(), ) - .with_seed(rng.gen()); + .with_seed(rng.random()); let record_batch = record_batch_generator.generate().unwrap(); num_rows += record_batch.num_rows(); @@ -373,9 +370,9 @@ impl SortFuzzerTestGenerator { } // After all partitions are created, optionally make one partition have 0/1 batch - if num_partitions > 2 && rng.gen_bool(0.1) { - let partition_index = rng.gen_range(0..num_partitions); - if rng.gen_bool(0.5) { + if num_partitions > 2 && rng.random_bool(0.1) { + let partition_index = rng.random_range(0..num_partitions); + if rng.random_bool(0.5) { // 0 batch partitions[partition_index] = Vec::new(); } else { @@ -424,7 +421,7 @@ impl SortFuzzerTestGenerator { pub fn generate_random_query(&self, rng_seed: u64) -> (String, Option) { let mut rng = StdRng::seed_from_u64(rng_seed); - let num_columns = rng.gen_range(1..=3).min(self.selected_columns.len()); + let num_columns = rng.random_range(1..=3).min(self.selected_columns.len()); let selected_columns: Vec<_> = self .selected_columns .choose_multiple(&mut rng, num_columns) @@ -433,37 +430,37 @@ impl SortFuzzerTestGenerator { let mut order_by_clauses = Vec::new(); for col in selected_columns { let mut clause = col.name.clone(); - if rng.gen_bool(0.5) { - let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" }; - clause.push_str(&format!(" {}", order)); + if rng.random_bool(0.5) { + let order = if rng.random_bool(0.5) { "ASC" } else { "DESC" }; + clause.push_str(&format!(" {order}")); } - if rng.gen_bool(0.5) { - let nulls = if rng.gen_bool(0.5) { + if rng.random_bool(0.5) { + let nulls = if rng.random_bool(0.5) { "NULLS FIRST" } else { "NULLS LAST" }; - clause.push_str(&format!(" {}", nulls)); + clause.push_str(&format!(" {nulls}")); } order_by_clauses.push(clause); } let dataset_size = self.dataset_state.as_ref().unwrap().dataset_size; - let limit = if rng.gen_bool(0.2) { + let limit = if rng.random_bool(0.2) { // Prefer edge cases for k like 1, dataset_size, etc. 
- Some(if rng.gen_bool(0.5) { + Some(if rng.random_bool(0.5) { let edge_cases = [1, 2, 3, dataset_size - 1, dataset_size, dataset_size + 1]; *edge_cases.choose(&mut rng).unwrap() } else { - rng.gen_range(1..=dataset_size) + rng.random_range(1..=dataset_size) }) } else { None }; - let limit_clause = limit.map_or(String::new(), |l| format!(" LIMIT {}", l)); + let limit_clause = limit.map_or(String::new(), |l| format!(" LIMIT {l}")); let query = format!( "SELECT * FROM {} ORDER BY {}{}", @@ -487,12 +484,12 @@ impl SortFuzzerTestGenerator { // 30% to 200% of the dataset size (if `with_memory_limit` is false, config // will use the default unbounded pool to override it later) - let memory_limit = rng.gen_range( + let memory_limit = rng.random_range( (dataset_size as f64 * 0.5) as usize..=(dataset_size as f64 * 2.0) as usize, ); // 10% to 20% of the per-partition memory limit size let per_partition_mem_limit = memory_limit / num_partitions; - let sort_spill_reservation_bytes = rng.gen_range( + let sort_spill_reservation_bytes = rng.random_range( (per_partition_mem_limit as f64 * 0.2) as usize ..=(per_partition_mem_limit as f64 * 0.3) as usize, ); @@ -505,7 +502,7 @@ impl SortFuzzerTestGenerator { 0 } else { let dataset_size = self.dataset_state.as_ref().unwrap().dataset_size; - rng.gen_range(0..=dataset_size * 2_usize) + rng.random_range(0..=dataset_size * 2_usize) }; // Set up strings for printing @@ -522,13 +519,10 @@ impl SortFuzzerTestGenerator { println!(" Config: "); println!(" Dataset size: {}", human_readable_size(dataset_size)); - println!(" Number of partitions: {}", num_partitions); + println!(" Number of partitions: {num_partitions}"); println!(" Batch size: {}", init_state.approx_batch_num_rows / 2); - println!(" Memory limit: {}", memory_limit_str); - println!( - " Per partition memory limit: {}", - per_partition_limit_str - ); + println!(" Memory limit: {memory_limit_str}"); + println!(" Per partition memory limit: {per_partition_limit_str}"); println!( " Sort spill reservation bytes: {}", human_readable_size(sort_spill_reservation_bytes) @@ -552,7 +546,7 @@ impl SortFuzzerTestGenerator { let runtime = RuntimeEnvBuilder::new() .with_memory_pool(memory_pool) - .with_disk_manager(DiskManagerConfig::NewOs) + .with_disk_manager_builder(DiskManagerBuilder::default()) .build_arc()?; let ctx = SessionContext::new_with_config_rt(config, runtime); @@ -575,7 +569,7 @@ impl SortFuzzerTestGenerator { self.init_partitioned_staggered_batches(dataset_seed); let (query_str, limit) = self.generate_random_query(query_seed); println!(" Query:"); - println!(" {}", query_str); + println!(" {query_str}"); // ==== Execute the query ==== diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 6b166dd32782..5bd2e457b42a 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -35,7 +35,7 @@ use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::HashMap; use datafusion_common::{Result, ScalarValue}; use datafusion_common_runtime::SpawnedTask; -use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; +use datafusion_expr::type_coercion::functions::fields_with_aggregate_udf; use datafusion_expr::{ WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; @@ -51,7 +51,7 @@ use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use 
datafusion_physical_expr_common::sort_expr::LexOrdering; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use test_utils::add_empty_batches; @@ -398,8 +398,8 @@ fn get_random_function( WindowFunctionDefinition::WindowUDF(lead_udwf()), vec![ arg.clone(), - lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), - lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..10)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..1000)))), ], ), ); @@ -409,8 +409,8 @@ fn get_random_function( WindowFunctionDefinition::WindowUDF(lag_udwf()), vec![ arg.clone(), - lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), - lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..10)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..1000)))), ], ), ); @@ -435,12 +435,12 @@ fn get_random_function( WindowFunctionDefinition::WindowUDF(nth_value_udwf()), vec![ arg.clone(), - lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..10)))), ], ), ); - let rand_fn_idx = rng.gen_range(0..window_fn_map.len()); + let rand_fn_idx = rng.random_range(0..window_fn_map.len()); let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, args) = window_fn_map.values().collect::>()[rand_fn_idx]; let mut args = args.clone(); @@ -448,9 +448,9 @@ fn get_random_function( if !args.is_empty() { // Do type coercion first argument let a = args[0].clone(); - let dt = a.data_type(schema.as_ref()).unwrap(); - let coerced = data_types_with_aggregate_udf(&[dt], udf).unwrap(); - args[0] = cast(a, schema, coerced[0].clone()).unwrap(); + let dt = a.return_field(schema.as_ref()).unwrap(); + let coerced = fields_with_aggregate_udf(&[dt], udf).unwrap(); + args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap(); } } @@ -463,12 +463,12 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { is_preceding: bool, } let first_bound = Utils { - val: rng.gen_range(0..10), - is_preceding: rng.gen_range(0..2) == 0, + val: rng.random_range(0..10), + is_preceding: rng.random_range(0..2) == 0, }; let second_bound = Utils { - val: rng.gen_range(0..10), - is_preceding: rng.gen_range(0..2) == 0, + val: rng.random_range(0..10), + is_preceding: rng.random_range(0..2) == 0, }; let (start_bound, end_bound) = if first_bound.is_preceding == second_bound.is_preceding { @@ -485,7 +485,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { (second_bound, first_bound) }; // 0 means Range, 1 means Rows, 2 means GROUPS - let rand_num = rng.gen_range(0..3); + let rand_num = rng.random_range(0..3); let units = if rand_num < 1 { WindowFrameUnits::Range } else if rand_num < 2 { @@ -517,7 +517,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { }; let mut window_frame = WindowFrame::new_bounds(units, start_bound, end_bound); // with 10% use unbounded preceding in tests - if rng.gen_range(0..10) == 0 { + if rng.random_range(0..10) == 0 { window_frame.start_bound = WindowFrameBound::Preceding(ScalarValue::Int32(None)); } @@ -545,7 +545,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { }; let mut window_frame = WindowFrame::new_bounds(units, start_bound, end_bound); // with 10% use unbounded preceding in tests - if rng.gen_range(0..10) == 0 { + if rng.random_range(0..10) == 0 { window_frame.start_bound = 
WindowFrameBound::Preceding(ScalarValue::UInt64(None)); } @@ -569,7 +569,7 @@ fn convert_bound_to_current_row_if_applicable( match bound { WindowFrameBound::Preceding(value) | WindowFrameBound::Following(value) => { if let Ok(zero) = ScalarValue::new_zero(&value.data_type()) { - if value == &zero && rng.gen_range(0..2) == 0 { + if value == &zero && rng.random_range(0..2) == 0 { *bound = WindowFrameBound::CurrentRow; } } @@ -728,7 +728,7 @@ async fn run_window_test( for (line1, line2) in usual_formatted_sorted.iter().zip(running_formatted_sorted) { - println!("{:?} --- {:?}", line1, line2); + println!("{line1:?} --- {line2:?}"); } unreachable!(); } @@ -758,9 +758,9 @@ pub(crate) fn make_staggered_batches( let mut input5: Vec = vec!["".to_string(); len]; input123.iter_mut().for_each(|v| { *v = ( - rng.gen_range(0..n_distinct) as i32, - rng.gen_range(0..n_distinct) as i32, - rng.gen_range(0..n_distinct) as i32, + rng.random_range(0..n_distinct) as i32, + rng.random_range(0..n_distinct) as i32, + rng.random_range(0..n_distinct) as i32, ) }); input123.sort(); @@ -788,7 +788,7 @@ pub(crate) fn make_staggered_batches( let mut batches = vec![]; if STREAM { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..50); + let batch_size = rng.random_range(0..50); if remainder.num_rows() < batch_size { batches.push(remainder); break; @@ -798,7 +798,7 @@ pub(crate) fn make_staggered_batches( } } else { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + let batch_size = rng.random_range(0..remainder.num_rows() + 1); batches.push(remainder.slice(0, batch_size)); remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); } diff --git a/datafusion/core/tests/integration_tests/schema_adapter_integration_tests.rs b/datafusion/core/tests/integration_tests/schema_adapter_integration_tests.rs new file mode 100644 index 000000000000..833af04680db --- /dev/null +++ b/datafusion/core/tests/integration_tests/schema_adapter_integration_tests.rs @@ -0,0 +1,445 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Integration test for schema adapter factory functionality + +use std::any::Any; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion::datasource::object_store::ObjectStoreUrl; +use datafusion::datasource::physical_plan::arrow_file::ArrowSource; +use datafusion::prelude::*; +use datafusion_common::Result; +use datafusion_datasource::file::FileSource; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; +use datafusion_datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory}; +use datafusion_datasource::source::DataSourceExec; +use datafusion_datasource::PartitionedFile; +use std::sync::Arc; +use tempfile::TempDir; + +#[cfg(feature = "parquet")] +use datafusion_datasource_parquet::ParquetSource; +#[cfg(feature = "parquet")] +use parquet::arrow::ArrowWriter; +#[cfg(feature = "parquet")] +use parquet::file::properties::WriterProperties; + +#[cfg(feature = "csv")] +use datafusion_datasource_csv::CsvSource; + +/// A schema adapter factory that transforms column names to uppercase +#[derive(Debug)] +struct UppercaseAdapterFactory {} + +impl SchemaAdapterFactory for UppercaseAdapterFactory { + fn create(&self, schema: &Schema) -> Result<Box<dyn SchemaAdapter>> { + Ok(Box::new(UppercaseAdapter { + input_schema: Arc::new(schema.clone()), + })) + } +} + +/// Schema adapter that transforms column names to uppercase +#[derive(Debug)] +struct UppercaseAdapter { + input_schema: SchemaRef, +} + +impl SchemaAdapter for UppercaseAdapter { + fn adapt(&self, record_batch: RecordBatch) -> Result<RecordBatch> { + // In a real adapter, we might transform the data too + // For this test, we're just passing through the batch + Ok(record_batch) + } + + fn output_schema(&self) -> SchemaRef { + let fields = self + .input_schema + .fields() + .iter() + .map(|f| { + Field::new( + f.name().to_uppercase().as_str(), + f.data_type().clone(), + f.is_nullable(), + ) + }) + .collect(); + + Arc::new(Schema::new(fields)) + } +} + +#[cfg(feature = "parquet")] +#[tokio::test] +async fn test_parquet_integration_with_schema_adapter() -> Result<()> { + // Create a temporary directory for our test file + let tmp_dir = TempDir::new()?; + let file_path = tmp_dir.path().join("test.parquet"); + let file_path_str = file_path.to_str().unwrap(); + + // Create test data + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), + Arc::new(arrow::array::StringArray::from(vec!["a", "b", "c"])), + ], + )?; + + // Write test parquet file + let file = std::fs::File::create(file_path_str)?; + let props = WriterProperties::builder().build(); + let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(props))?; + writer.write(&batch)?; + writer.close()?; + + // Create a session context + let ctx = SessionContext::new(); + + // Create a ParquetSource with the adapter factory + let source = ParquetSource::default() + .with_schema_adapter_factory(Arc::new(UppercaseAdapterFactory {})); + + // Create a scan config + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse(&format!("file://{}", file_path_str))?, + schema.clone(), + ) + .with_source(source) + .build(); + + // Create a data source executor + let exec = DataSourceExec::from_data_source(config); + + // Collect results + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx)?; + let batches =
datafusion::physical_plan::common::collect(stream).await?; + + // There should be one batch + assert_eq!(batches.len(), 1); + + // Verify the schema has uppercase column names + let result_schema = batches[0].schema(); + assert_eq!(result_schema.field(0).name(), "ID"); + assert_eq!(result_schema.field(1).name(), "NAME"); + + Ok(()) +} + +#[tokio::test] +async fn test_multi_source_schema_adapter_reuse() -> Result<()> { + // This test verifies that the same schema adapter factory can be reused + // across different file source types. This is important for ensuring that: + // 1. The schema adapter factory interface works uniformly across all source types + // 2. The factory can be shared and cloned efficiently using Arc + // 3. Various data source implementations correctly implement the schema adapter factory pattern + + // Create a test factory + let factory = Arc::new(UppercaseAdapterFactory {}); + + // Apply the same adapter to different source types + let arrow_source = + ArrowSource::default().with_schema_adapter_factory(factory.clone()); + + #[cfg(feature = "parquet")] + let parquet_source = + ParquetSource::default().with_schema_adapter_factory(factory.clone()); + + #[cfg(feature = "csv")] + let csv_source = CsvSource::default().with_schema_adapter_factory(factory.clone()); + + // Verify adapters were properly set + assert!(arrow_source.schema_adapter_factory().is_some()); + + #[cfg(feature = "parquet")] + assert!(parquet_source.schema_adapter_factory().is_some()); + + #[cfg(feature = "csv")] + assert!(csv_source.schema_adapter_factory().is_some()); + + Ok(()) +} + +// Helper function to test From<T> for Arc<dyn FileSource> implementations +fn test_from_impl<T: Into<Arc<dyn FileSource>> + Default>(expected_file_type: &str) { + let source = T::default(); + let file_source: Arc<dyn FileSource> = source.into(); + assert_eq!(file_source.file_type(), expected_file_type); +} + +#[test] +fn test_from_implementations() { + // Test From implementation for various sources + test_from_impl::<ArrowSource>("arrow"); + + #[cfg(feature = "parquet")] + test_from_impl::<ParquetSource>("parquet"); + + #[cfg(feature = "csv")] + test_from_impl::<CsvSource>("csv"); + + #[cfg(feature = "json")] + test_from_impl::<JsonSource>("json"); +} + +/// A simple test schema adapter factory that doesn't modify the schema +#[derive(Debug)] +struct TestSchemaAdapterFactory {} + +impl SchemaAdapterFactory for TestSchemaAdapterFactory { + fn create(&self, schema: &Schema) -> Result<Box<dyn SchemaAdapter>> { + Ok(Box::new(TestSchemaAdapter { + input_schema: Arc::new(schema.clone()), + })) + } +} + +/// A test schema adapter that passes through data unmodified +#[derive(Debug)] +struct TestSchemaAdapter { + input_schema: SchemaRef, +} + +impl SchemaAdapter for TestSchemaAdapter { + fn adapt(&self, record_batch: RecordBatch) -> Result<RecordBatch> { + // Just pass through the batch unmodified + Ok(record_batch) + } + + fn output_schema(&self) -> SchemaRef { + self.input_schema.clone() + } +} + +#[cfg(feature = "parquet")] +#[test] +fn test_schema_adapter_preservation() { + // Create a test schema + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Create source with schema adapter factory + let source = ParquetSource::default(); + let factory = Arc::new(TestSchemaAdapterFactory {}); + let file_source = source.with_schema_adapter_factory(factory); + + // Create a FileScanConfig with the source + let config_builder = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), schema.clone()) + .with_source(file_source.clone()) + // Add a file to make it valid
.with_file(PartitionedFile::new("test.parquet", 100));
+
+    let config = config_builder.build();
+
+    // Verify the schema adapter factory is present in the file source
+    assert!(config.source().schema_adapter_factory().is_some());
+}
+
+
+/// A test source for testing schema adapters
+#[derive(Debug, Clone)]
+struct TestSource {
+    schema_adapter_factory: Option<Arc<dyn SchemaAdapterFactory>>,
+}
+
+impl TestSource {
+    fn new() -> Self {
+        Self {
+            schema_adapter_factory: None,
+        }
+    }
+}
+
+impl FileSource for TestSource {
+    fn file_type(&self) -> &str {
+        "test"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn create_file_opener(
+        &self,
+        _store: Arc<dyn ObjectStore>,
+        _conf: &FileScanConfig,
+        _index: usize,
+    ) -> Arc<dyn FileOpener> {
+        unimplemented!("Not needed for this test")
+    }
+
+    fn with_batch_size(&self, _batch_size: usize) -> Arc<dyn FileSource> {
+        Arc::new(self.clone())
+    }
+
+    fn with_schema(&self, _schema: SchemaRef) -> Arc<dyn FileSource> {
+        Arc::new(self.clone())
+    }
+
+    fn with_projection(&self, _projection: &FileScanConfig) -> Arc<dyn FileSource> {
+        Arc::new(self.clone())
+    }
+
+    fn with_statistics(&self, _statistics: Statistics) -> Arc<dyn FileSource> {
+        Arc::new(self.clone())
+    }
+
+    fn metrics(&self) -> &ExecutionPlanMetricsSet {
+        unimplemented!("Not needed for this test")
+    }
+
+    fn statistics(&self) -> Result<Statistics> {
+        Ok(Statistics::default())
+    }
+
+    fn with_schema_adapter_factory(
+        &self,
+        schema_adapter_factory: Arc<dyn SchemaAdapterFactory>,
+    ) -> Result<Arc<dyn FileSource>> {
+        Ok(Arc::new(Self {
+            schema_adapter_factory: Some(schema_adapter_factory),
+        }))
+    }
+
+    fn schema_adapter_factory(&self) -> Option<Arc<dyn SchemaAdapterFactory>> {
+        self.schema_adapter_factory.clone()
+    }
+}
+
+/// A test schema adapter factory
+#[derive(Debug)]
+struct TestSchemaAdapterFactory {}
+
+impl SchemaAdapterFactory for TestSchemaAdapterFactory {
+    fn create(
+        &self,
+        projected_table_schema: SchemaRef,
+        _table_schema: SchemaRef,
+    ) -> Box<dyn SchemaAdapter> {
+        Box::new(TestSchemaAdapter {
+            table_schema: projected_table_schema,
+        })
+    }
+}
+
+/// A test schema adapter implementation
+#[derive(Debug)]
+struct TestSchemaAdapter {
+    table_schema: SchemaRef,
+}
+
+impl SchemaAdapter for TestSchemaAdapter {
+    fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
+        let field = self.table_schema.field(index);
+        file_schema.fields.find(field.name()).map(|(i, _)| i)
+    }
+
+    fn map_schema(
+        &self,
+        file_schema: &Schema,
+    ) -> Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
+        let mut projection = Vec::with_capacity(file_schema.fields().len());
+        for (file_idx, file_field) in file_schema.fields().iter().enumerate() {
+            if self.table_schema.fields().find(file_field.name()).is_some() {
+                projection.push(file_idx);
+            }
+        }
+
+        Ok((Arc::new(TestSchemaMapping {}), projection))
+    }
+}
+
+/// A test schema mapper implementation
+#[derive(Debug)]
+struct TestSchemaMapping {}
+
+impl SchemaMapper for TestSchemaMapping {
+    fn map_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
+        // For testing, just return the original batch
+        Ok(batch)
+    }
+
+    fn map_column_statistics(
+        &self,
+        stats: &[ColumnStatistics],
+    ) -> Result<Vec<ColumnStatistics>> {
+        // For testing, just return the input statistics
+        Ok(stats.to_vec())
+    }
+}
+
+#[test]
+fn test_schema_adapter() {
+    // This test verifies the functionality of the SchemaAdapter and SchemaAdapterFactory
+    // components used in DataFusion's file sources.
+    //
+    // The test specifically checks:
+    // 1. Creating and attaching a schema adapter factory to a file source
+    // 2. Creating a schema adapter using the factory
+    // 3. The schema adapter's ability to map column indices between a table schema and a file schema
+    // 4. 
The schema adapter's ability to create a projection that selects only the columns + // from the file schema that are present in the table schema + // + // Schema adapters are used when the schema of data in files doesn't exactly match + // the schema expected by the query engine, allowing for field mapping and data transformation. + + // Create a test schema + let table_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Create a file schema + let file_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("extra", DataType::Int64, true), + ]); + + // Create a TestSource + let source = TestSource::new(); + assert!(source.schema_adapter_factory().is_none()); + + // Add a schema adapter factory + let factory = Arc::new(TestSchemaAdapterFactory {}); + let source_with_adapter = source.with_schema_adapter_factory(factory).unwrap(); + assert!(source_with_adapter.schema_adapter_factory().is_some()); + + // Create a schema adapter + let adapter_factory = source_with_adapter.schema_adapter_factory().unwrap(); + let adapter = + adapter_factory.create(Arc::clone(&table_schema), Arc::clone(&table_schema)); + + // Test mapping column index + assert_eq!(adapter.map_column_index(0, &file_schema), Some(0)); + assert_eq!(adapter.map_column_index(1, &file_schema), Some(1)); + + // Test creating schema mapper + let (_mapper, projection) = adapter.map_schema(&file_schema).unwrap(); + assert_eq!(projection, vec![0, 1]); +} diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 01342d1604fc..7695cc0969d8 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -31,7 +31,6 @@ use datafusion::assert_batches_eq; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::datasource::{MemTable, TableProvider}; -use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; @@ -41,11 +40,12 @@ use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_catalog::streaming::StreamingTable; use datafusion_catalog::Session; use datafusion_common::{assert_contains, Result}; +use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion_execution::memory_pool::{ FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, }; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_execution::{DiskManager, TaskContext}; +use datafusion_execution::TaskContext; use datafusion_expr::{Expr, TableType}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::join_selection::JoinSelection; @@ -84,7 +84,7 @@ async fn group_by_none() { TestCase::new() .with_query("select median(request_bytes) from t") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: AggregateStream" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n AggregateStream" ]) .with_memory_limit(2_000) .run() @@ -96,7 +96,7 @@ async fn group_by_row_hash() { TestCase::new() .with_query("select count(*) from t GROUP BY response_bytes") 
.with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: GroupedHashAggregateStream" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n GroupedHashAggregateStream" ]) .with_memory_limit(2_000) .run() @@ -109,7 +109,7 @@ async fn group_by_hash() { // group by dict column .with_query("select count(*) from t GROUP BY service, host, pod, container") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: GroupedHashAggregateStream" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n GroupedHashAggregateStream" ]) .with_memory_limit(1_000) .run() @@ -122,7 +122,7 @@ async fn join_by_key_multiple_partitions() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n HashJoinInput", ]) .with_memory_limit(1_000) .with_config(config) @@ -136,7 +136,7 @@ async fn join_by_key_single_partition() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n HashJoinInput", ]) .with_memory_limit(1_000) .with_config(config) @@ -149,7 +149,7 @@ async fn join_by_expression() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service != t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]", ]) .with_memory_limit(1_000) .run() @@ -161,7 +161,7 @@ async fn cross_join() { TestCase::new() .with_query("select t1.*, t2.* from t t1 CROSS JOIN t t2") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n CrossJoinExec", ]) .with_memory_limit(1_000) .run() @@ -204,7 +204,7 @@ async fn sort_merge_join_spill() { ) .with_memory_limit(1_000) .with_config(config) - .with_disk_manager_config(DiskManagerConfig::NewOs) + .with_disk_manager_builder(DiskManagerBuilder::default()) .with_scenario(Scenario::AccessLogStreaming) .run() .await @@ -217,7 +217,7 @@ async fn symmetric_hash_join() { "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", ) .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: SymmetricHashJoinStream", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n SymmetricHashJoinStream", ]) .with_memory_limit(1_000) .with_scenario(Scenario::AccessLogStreaming) @@ -235,7 +235,7 @@ async fn sort_preserving_merge() { // so only a merge is needed .with_query("select * from t ORDER BY a 
ASC NULLS LAST, b ASC NULLS LAST LIMIT 10") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: SortPreservingMergeExec", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n SortPreservingMergeExec", ]) // provide insufficient memory to merge .with_memory_limit(partition_size / 2) @@ -288,7 +288,7 @@ async fn sort_spill_reservation() { .with_memory_limit(mem_limit) // use a single partition so only a sort is needed .with_scenario(scenario) - .with_disk_manager_config(DiskManagerConfig::NewOs) + .with_disk_manager_builder(DiskManagerBuilder::default()) .with_expected_plan( // It is important that this plan only has a SortExec, not // also merge, so we can ensure the sort could finish @@ -315,7 +315,7 @@ async fn sort_spill_reservation() { test.clone() .with_expected_errors(vec![ "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:", - "bytes for ExternalSorterMerge", + "B for ExternalSorterMerge", ]) .with_config(config) .run() @@ -344,7 +344,7 @@ async fn oom_recursive_cte() { SELECT * FROM nodes;", ) .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: RecursiveQuery", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n RecursiveQuery", ]) .with_memory_limit(2_000) .run() @@ -354,7 +354,7 @@ async fn oom_recursive_cte() { #[tokio::test] async fn oom_parquet_sink() { let dir = tempfile::tempdir().unwrap(); - let path = dir.into_path().join("test.parquet"); + let path = dir.path().join("test.parquet"); let _ = File::create(path.clone()).await.unwrap(); TestCase::new() @@ -378,7 +378,7 @@ async fn oom_parquet_sink() { #[tokio::test] async fn oom_with_tracked_consumer_pool() { let dir = tempfile::tempdir().unwrap(); - let path = dir.into_path().join("test.parquet"); + let path = dir.path().join("test.parquet"); let _ = File::create(path.clone()).await.unwrap(); TestCase::new() @@ -396,7 +396,7 @@ async fn oom_with_tracked_consumer_pool() { .with_expected_errors(vec![ "Failed to allocate additional", "for ParquetSink(ArrowColumnWriter)", - "Additional allocation failed with top memory consumers (across reservations) as: ParquetSink(ArrowColumnWriter)" + "Additional allocation failed with top memory consumers (across reservations) as:\n ParquetSink(ArrowColumnWriter)" ]) .with_memory_pool(Arc::new( TrackConsumersPool::new( @@ -408,6 +408,19 @@ async fn oom_with_tracked_consumer_pool() { .await } +#[tokio::test] +async fn oom_grouped_hash_aggregate() { + TestCase::new() + .with_query("SELECT COUNT(*), SUM(request_bytes) FROM t GROUP BY host") + .with_expected_errors(vec![ + "Failed to allocate additional", + "GroupedHashAggregateStream[0] (count(1), sum(t.request_bytes))", + ]) + .with_memory_limit(1_000) + .run() + .await +} + /// For regression case: if spilled `StringViewArray`'s buffer will be referenced by /// other batches which are also need to be spilled, then the spill writer will /// repeatedly write out the same buffer, and after reading back, each batch's size @@ -417,7 +430,7 @@ async fn oom_with_tracked_consumer_pool() { /// If there is memory explosion for spilled record batch, this test will fail. 
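// The hunk below also migrates this test to the rand 0.9 API, where `rand::thread_rng()`
// becomes `rand::rng()` and `Rng::gen_range()` becomes `Rng::random_range()`; a minimal
// sketch of the new call pattern (assuming rand 0.9 and `use rand::Rng;` are in scope):
//     let mut rng = rand::rng();
//     let random_byte = rng.random_range(0..=u8::MAX);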
#[tokio::test] async fn test_stringview_external_sort() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let array_length = 1000; let num_batches = 200; // Batches contain two columns: random 100-byte string, and random i32 @@ -427,7 +440,7 @@ async fn test_stringview_external_sort() { let strings: Vec = (0..array_length) .map(|_| { (0..100) - .map(|_| rng.gen_range(0..=u8::MAX) as char) + .map(|_| rng.random_range(0..=u8::MAX) as char) .collect() }) .collect(); @@ -435,8 +448,9 @@ async fn test_stringview_external_sort() { let string_array = StringViewArray::from(strings); let array_ref: ArrayRef = Arc::new(string_array); - let random_numbers: Vec = - (0..array_length).map(|_| rng.gen_range(0..=1000)).collect(); + let random_numbers: Vec = (0..array_length) + .map(|_| rng.random_range(0..=1000)) + .collect(); let int_array = Int32Array::from(random_numbers); let int_array_ref: ArrayRef = Arc::new(int_array); @@ -458,7 +472,9 @@ async fn test_stringview_external_sort() { .with_memory_pool(Arc::new(FairSpillPool::new(60 * 1024 * 1024))); let runtime = builder.build_arc().unwrap(); - let config = SessionConfig::new().with_sort_spill_reservation_bytes(40 * 1024 * 1024); + let config = SessionConfig::new() + .with_sort_spill_reservation_bytes(40 * 1024 * 1024) + .with_repartition_file_scans(false); let ctx = SessionContext::new_with_config_rt(config, runtime); ctx.register_table("t", Arc::new(table)).unwrap(); @@ -534,11 +550,10 @@ async fn setup_context( disk_limit: u64, memory_pool_limit: usize, ) -> Result { - let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; - - let disk_manager = Arc::try_unwrap(disk_manager) - .expect("DiskManager should be a single instance") - .with_max_temp_directory_size(disk_limit)?; + let disk_manager = DiskManagerBuilder::default() + .with_mode(DiskManagerMode::OsTmpDirectory) + .with_max_temp_directory_size(disk_limit) + .build()?; let runtime = RuntimeEnvBuilder::new() .with_memory_pool(Arc::new(FairSpillPool::new(memory_pool_limit))) @@ -603,7 +618,7 @@ async fn test_disk_spill_limit_not_reached() -> Result<()> { let spill_count = plan.metrics().unwrap().spill_count().unwrap(); let spilled_bytes = plan.metrics().unwrap().spilled_bytes().unwrap(); - println!("spill count {}, spill bytes {}", spill_count, spilled_bytes); + println!("spill count {spill_count}, spill bytes {spilled_bytes}"); assert!(spill_count > 0); assert!((spilled_bytes as u64) < disk_spill_limit); @@ -627,7 +642,7 @@ struct TestCase { scenario: Scenario, /// How should the disk manager (that allows spilling) be /// configured? Defaults to `Disabled` - disk_manager_config: DiskManagerConfig, + disk_manager_builder: DiskManagerBuilder, /// Expected explain plan, if non-empty expected_plan: Vec, /// Is the plan expected to pass? Defaults to false @@ -643,7 +658,8 @@ impl TestCase { config: SessionConfig::new(), memory_pool: None, scenario: Scenario::AccessLog, - disk_manager_config: DiskManagerConfig::Disabled, + disk_manager_builder: DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Disabled), expected_plan: vec![], expected_success: false, } @@ -700,11 +716,11 @@ impl TestCase { /// Specify if the disk manager should be enabled. 
If true, /// operators that support it can spill - pub fn with_disk_manager_config( + pub fn with_disk_manager_builder( mut self, - disk_manager_config: DiskManagerConfig, + disk_manager_builder: DiskManagerBuilder, ) -> Self { - self.disk_manager_config = disk_manager_config; + self.disk_manager_builder = disk_manager_builder; self } @@ -723,7 +739,7 @@ impl TestCase { memory_pool, config, scenario, - disk_manager_config, + disk_manager_builder, expected_plan, expected_success, } = self; @@ -732,7 +748,7 @@ impl TestCase { let mut builder = RuntimeEnvBuilder::new() // disk manager setting controls the spilling - .with_disk_manager(disk_manager_config) + .with_disk_manager_builder(disk_manager_builder) .with_memory_limit(memory_limit, MEMORY_FRACTION); if let Some(pool) = memory_pool { diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 585540bd5875..2daed4fe36bb 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -116,7 +116,7 @@ fn concat_ws_literals() -> Result<()> { fn quick_test(sql: &str, expected_plan: &str) { let plan = test_sql(sql).unwrap(); - assert_eq!(expected_plan, format!("{}", plan)); + assert_eq!(expected_plan, format!("{plan}")); } fn test_sql(sql: &str) -> Result { @@ -342,8 +342,7 @@ where let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, - "{} simplified to {}, but expected {}", - expr, output, expected + "{expr} simplified to {output}, but expected {expected}" ); } } @@ -352,8 +351,7 @@ fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { let output = expr.clone().rewrite(rewriter).data().unwrap(); assert_eq!( &output, expr, - "{} was simplified to {}, but expected it to be unchanged", - expr, output + "{expr} was simplified to {output}, but expected it to be unchanged" ); } } diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 7e98ebed6c9a..a60beaf665e5 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -28,6 +28,7 @@ use datafusion::execution::context::SessionState; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::SessionContext; use datafusion_common::stats::Precision; +use datafusion_common::DFSchema; use datafusion_execution::cache::cache_manager::CacheManagerConfig; use datafusion_execution::cache::cache_unit::{ DefaultFileStatisticsCache, DefaultListFilesCache, @@ -37,6 +38,10 @@ use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::{col, lit, Expr}; use datafusion::datasource::physical_plan::FileScanConfig; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::ExecutionPlan; use tempfile::tempdir; #[tokio::test] @@ -45,18 +50,53 @@ async fn check_stats_precision_with_filter_pushdown() { let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); let table_path = ListingTableUrl::parse(filename).unwrap(); - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); + let opt = + ListingOptions::new(Arc::new(ParquetFormat::default())).with_collect_stat(true); let table = get_listing_table(&table_path, None, &opt).await; + let (_, _, state) = get_cache_runtime_state(); + let mut options = state.config().options().clone(); + 
options.execution.parquet.pushdown_filters = true; + // Scan without filter, stats are exact let exec = table.scan(&state, None, &[], None).await.unwrap(); - assert_eq!(exec.statistics().unwrap().num_rows, Precision::Exact(8)); + assert_eq!( + exec.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8), + "Stats without filter should be exact" + ); - // Scan with filter pushdown, stats are inexact - let filter = Expr::gt(col("id"), lit(1)); + // This is a filter that cannot be evaluated by the table provider scanning + // (it is not a partition filter). Therefore; it will be pushed down to the + // source operator after the appropriate optimizer pass. + let filter_expr = Expr::gt(col("id"), lit(1)); + let exec_with_filter = table + .scan(&state, None, &[filter_expr.clone()], None) + .await + .unwrap(); + + let ctx = SessionContext::new(); + let df_schema = DFSchema::try_from(table.schema()).unwrap(); + let physical_filter = ctx.create_physical_expr(filter_expr, &df_schema).unwrap(); - let exec = table.scan(&state, None, &[filter], None).await.unwrap(); - assert_eq!(exec.statistics().unwrap().num_rows, Precision::Inexact(8)); + let filtered_exec = + Arc::new(FilterExec::try_new(physical_filter, exec_with_filter).unwrap()) + as Arc; + + let optimized_exec = FilterPushdown::new() + .optimize(filtered_exec, &options) + .unwrap(); + + assert!( + optimized_exec.as_any().is::(), + "Sanity check that the pushdown did what we expected" + ); + // Scan with filter pushdown, stats are inexact + assert_eq!( + optimized_exec.partition_statistics(None).unwrap().num_rows, + Precision::Inexact(8), + "Stats after filter pushdown should be inexact" + ); } #[tokio::test] @@ -70,7 +110,8 @@ async fn load_table_stats_with_session_level_cache() { // Create a separate DefaultFileStatisticsCache let (cache2, _, state2) = get_cache_runtime_state(); - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); + let opt = + ListingOptions::new(Arc::new(ParquetFormat::default())).with_collect_stat(true); let table1 = get_listing_table(&table_path, Some(cache1), &opt).await; let table2 = get_listing_table(&table_path, Some(cache2), &opt).await; @@ -79,9 +120,12 @@ async fn load_table_stats_with_session_level_cache() { assert_eq!(get_static_cache_size(&state1), 0); let exec1 = table1.scan(&state1, None, &[], None).await.unwrap(); - assert_eq!(exec1.statistics().unwrap().num_rows, Precision::Exact(8)); assert_eq!( - exec1.statistics().unwrap().total_byte_size, + exec1.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec1.partition_statistics(None).unwrap().total_byte_size, // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 Precision::Exact(671), ); @@ -91,9 +135,12 @@ async fn load_table_stats_with_session_level_cache() { //check session 1 cache result not show in session 2 assert_eq!(get_static_cache_size(&state2), 0); let exec2 = table2.scan(&state2, None, &[], None).await.unwrap(); - assert_eq!(exec2.statistics().unwrap().num_rows, Precision::Exact(8)); assert_eq!( - exec2.statistics().unwrap().total_byte_size, + exec2.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec2.partition_statistics(None).unwrap().total_byte_size, // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 Precision::Exact(671), ); @@ -103,9 +150,12 @@ async fn load_table_stats_with_session_level_cache() { //check session 1 cache result not show in session 2 
assert_eq!(get_static_cache_size(&state1), 1); let exec3 = table1.scan(&state1, None, &[], None).await.unwrap(); - assert_eq!(exec3.statistics().unwrap().num_rows, Precision::Exact(8)); assert_eq!( - exec3.statistics().unwrap().total_byte_size, + exec3.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec3.partition_statistics(None).unwrap().total_byte_size, // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 Precision::Exact(671), ); @@ -117,23 +167,15 @@ async fn load_table_stats_with_session_level_cache() { async fn list_files_with_session_level_cache() { let p_name = "alltypes_plain.parquet"; let testdata = datafusion::test_util::parquet_test_data(); - let filename = format!("{}/{}", testdata, p_name); + let filename = format!("{testdata}/{p_name}"); - let temp_path1 = tempdir() - .unwrap() - .into_path() - .into_os_string() - .into_string() - .unwrap(); - let temp_filename1 = format!("{}/{}", temp_path1, p_name); + let temp_dir1 = tempdir().unwrap(); + let temp_path1 = temp_dir1.path().to_str().unwrap(); + let temp_filename1 = format!("{temp_path1}/{p_name}"); - let temp_path2 = tempdir() - .unwrap() - .into_path() - .into_os_string() - .into_string() - .unwrap(); - let temp_filename2 = format!("{}/{}", temp_path2, p_name); + let temp_dir2 = tempdir().unwrap(); + let temp_path2 = temp_dir2.path().to_str().unwrap(); + let temp_filename2 = format!("{temp_path2}/{p_name}"); fs::copy(filename.clone(), temp_filename1).expect("panic"); fs::copy(filename, temp_filename2).expect("panic"); diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 02fb59740493..b8d570916c7c 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -32,50 +32,45 @@ use arrow::compute::concat_batches; use arrow::record_batch::RecordBatch; use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::MetricsSet; -use datafusion::prelude::{col, lit, lit_timestamp_nano, Expr, SessionContext}; +use datafusion::prelude::{ + col, lit, lit_timestamp_nano, Expr, ParquetReadOptions, SessionContext, +}; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; -use datafusion_common::instant::Instant; use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; use itertools::Itertools; use parquet::file::properties::WriterProperties; use tempfile::TempDir; -use test_utils::AccessLogGenerator; /// how many rows of generated data to write to our parquet file (arbitrary) const NUM_ROWS: usize = 4096; -fn generate_file(tempdir: &TempDir, props: WriterProperties) -> TestParquetFile { - // Tune down the generator for smaller files - let generator = AccessLogGenerator::new() - .with_row_limit(NUM_ROWS) - .with_pods_per_host(1..4) - .with_containers_per_pod(1..2) - .with_entries_per_container(128..256); - - let file = tempdir.path().join("data.parquet"); - - let start = Instant::now(); - println!("Writing test data to {file:?}"); - let test_parquet_file = TestParquetFile::try_new(file, props, generator).unwrap(); - println!( - "Completed generating test data in {:?}", - Instant::now() - start - ); - test_parquet_file +async fn read_parquet_test_data>(path: T) -> Vec { + let ctx: SessionContext = SessionContext::new(); + ctx.read_parquet(path.into(), ParquetReadOptions::default()) + .await + .unwrap() + .collect() + .await + .unwrap() } #[tokio::test] async fn single_file() { - // Only create the 
parquet file once as it is fairly large + let batches = + read_parquet_test_data("tests/data/filter_pushdown/single_file.gz.parquet").await; - let tempdir = TempDir::new_in(Path::new(".")).unwrap(); - // Set row group size smaller so can test with fewer rows + // Set the row group size smaller so can test with fewer rows let props = WriterProperties::builder() .set_max_row_group_size(1024) .build(); - let test_parquet_file = generate_file(&tempdir, props); + // Only create the parquet file once as it is fairly large + let tempdir = TempDir::new_in(Path::new(".")).unwrap(); + + let test_parquet_file = + TestParquetFile::try_new(tempdir.path().join("data.parquet"), props, batches) + .unwrap(); let case = TestCase::new(&test_parquet_file) .with_name("selective") // request_method = 'GET' @@ -224,16 +219,25 @@ async fn single_file() { } #[tokio::test] +#[allow(dead_code)] async fn single_file_small_data_pages() { + let batches = read_parquet_test_data( + "tests/data/filter_pushdown/single_file_small_pages.gz.parquet", + ) + .await; + let tempdir = TempDir::new_in(Path::new(".")).unwrap(); - // Set low row count limit to improve page filtering + // Set a low row count limit to improve page filtering let props = WriterProperties::builder() .set_max_row_group_size(2048) .set_data_page_row_count_limit(512) .set_write_batch_size(512) .build(); - let test_parquet_file = generate_file(&tempdir, props); + + let test_parquet_file = + TestParquetFile::try_new(tempdir.path().join("data.parquet"), props, batches) + .unwrap(); // The statistics on the 'pod' column are as follows: // diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 9898f6204e88..4034800c30cb 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -19,19 +19,25 @@ use std::fmt::Debug; use std::ops::Deref; use std::sync::Arc; -use crate::physical_optimizer::test_utils::parquet_exec_with_sort; use crate::physical_optimizer::test_utils::{ check_integrity, coalesce_partitions_exec, repartition_exec, schema, sort_merge_join_exec, sort_preserving_merge_exec, }; +use crate::physical_optimizer::test_utils::{ + parquet_exec_with_sort, parquet_exec_with_stats, +}; +use arrow::array::{RecordBatch, UInt64Array, UInt8Array}; use arrow::compute::SortOptions; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion::config::ConfigOptions; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{CsvSource, ParquetSource}; use datafusion::datasource::source::DataSourceExec; +use datafusion::datasource::MemTable; +use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; @@ -52,6 +58,7 @@ use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::ExecutionPlan; use datafusion_physical_plan::expressions::col; use datafusion_physical_plan::filter::FilterExec; @@ -169,7 +176,7 @@ impl ExecutionPlan for 
SortRequiredExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) } } @@ -524,8 +531,7 @@ impl TestConfig { assert_eq!( &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" ); Ok(optimized) @@ -3471,3 +3477,147 @@ fn optimize_away_unnecessary_repartition2() -> Result<()> { Ok(()) } + +/// Ensures that `DataSourceExec` has been repartitioned into `target_partitions` file groups +#[tokio::test] +async fn test_distribute_sort_parquet() -> Result<()> { + let test_config: TestConfig = + TestConfig::default().with_prefer_repartition_file_scans(1000); + assert!( + test_config.config.optimizer.repartition_file_scans, + "should enable scans to be repartitioned" + ); + + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]); + let physical_plan = sort_exec(sort_key, parquet_exec_with_stats(10000 * 8192), false); + + // prior to optimization, this is the starting plan + let starting = &[ + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + plans_matches_expected!(starting, physical_plan.clone()); + + // what the enforce distribution run does. + let expected = &[ + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected, physical_plan.clone(), &[Run::Distribution])?; + + // what the sort parallelization (in enforce sorting), does after the enforce distribution changes + let expected = &[ + "SortPreservingMergeExec: [c@2 ASC]", + " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", + " DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected, physical_plan, &[Run::Distribution, Run::Sorting])?; + Ok(()) +} + +/// Ensures that `DataSourceExec` has been repartitioned into `target_partitions` memtable groups +#[tokio::test] +async fn test_distribute_sort_memtable() -> Result<()> { + let test_config: TestConfig = + TestConfig::default().with_prefer_repartition_file_scans(1000); + assert!( + test_config.config.optimizer.repartition_file_scans, + "should enable scans to be repartitioned" + ); + + let mem_table = create_memtable()?; + let session_config = SessionConfig::new() + .with_repartition_file_min_size(1000) + .with_target_partitions(3); + let ctx = SessionContext::new_with_config(session_config); + ctx.register_table("users", Arc::new(mem_table))?; + + let dataframe = ctx.sql("SELECT * FROM users order by id;").await?; + let physical_plan = dataframe.create_physical_plan().await?; + + // this is the final, optimized plan + let expected = &[ + "SortPreservingMergeExec: [id@0 ASC NULLS LAST]", + " SortExec: expr=[id@0 ASC 
NULLS LAST], preserve_partitioning=[true]", + " DataSourceExec: partitions=3, partition_sizes=[34, 33, 33]", + ]; + plans_matches_expected!(expected, physical_plan); + + Ok(()) +} + +/// Create a [`MemTable`] with 100 batches of 8192 rows each, in 1 partition +fn create_memtable() -> Result { + let mut batches = Vec::with_capacity(100); + for _ in 0..100 { + batches.push(create_record_batch()?); + } + let partitions = vec![batches]; + MemTable::try_new(get_schema(), partitions) +} + +fn create_record_batch() -> Result { + let id_array = UInt8Array::from(vec![1; 8192]); + let account_array = UInt64Array::from(vec![9000; 8192]); + + Ok(RecordBatch::try_new( + get_schema(), + vec![Arc::new(id_array), Arc::new(account_array)], + ) + .unwrap()) +} + +fn get_schema() -> SchemaRef { + SchemaRef::new(Schema::new(vec![ + Field::new("id", DataType::UInt8, false), + Field::new("bank_account", DataType::UInt64, true), + ])) +} +#[test] +fn test_replace_order_preserving_variants_with_fetch() -> Result<()> { + // Create a base plan + let parquet_exec = parquet_exec(); + + let sort_expr = PhysicalSortExpr { + expr: Arc::new(Column::new("id", 0)), + options: SortOptions::default(), + }; + + let ordering = LexOrdering::new(vec![sort_expr]); + + // Create a SortPreservingMergeExec with fetch=5 + let spm_exec = Arc::new( + SortPreservingMergeExec::new(ordering, parquet_exec.clone()).with_fetch(Some(5)), + ); + + // Create distribution context + let dist_context = DistributionContext::new( + spm_exec, + true, + vec![DistributionContext::new(parquet_exec, false, vec![])], + ); + + // Apply the function + let result = replace_order_preserving_variants(dist_context)?; + + // Verify the plan was transformed to CoalescePartitionsExec + result + .plan + .as_any() + .downcast_ref::() + .expect("Expected CoalescePartitionsExec"); + + // Verify fetch was preserved + assert_eq!( + result.plan.fetch(), + Some(5), + "Fetch value was not preserved after transformation" + ); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index d4b84a52f401..f7668c8aab11 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -3440,3 +3440,37 @@ fn test_handles_multiple_orthogonal_sorts() -> Result<()> { Ok(()) } + +#[test] +fn test_parallelize_sort_preserves_fetch() -> Result<()> { + // Create a schema + let schema = create_test_schema3()?; + let parquet_exec = parquet_exec(&schema); + let coalesced = Arc::new(CoalescePartitionsExec::new(parquet_exec.clone())); + let top_coalesced = + Arc::new(CoalescePartitionsExec::new(coalesced.clone()).with_fetch(Some(10))); + + let requirements = PlanWithCorrespondingCoalescePartitions::new( + top_coalesced.clone(), + true, + vec![PlanWithCorrespondingCoalescePartitions::new( + coalesced, + true, + vec![PlanWithCorrespondingCoalescePartitions::new( + parquet_exec, + false, + vec![], + )], + )], + ); + + let res = parallelize_sorts(requirements)?; + + // Verify fetch was preserved + assert_eq!( + res.data.plan.fetch(), + Some(10), + "Fetch value was not preserved after transformation" + ); + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs new file mode 100644 index 000000000000..a28933d97bcd --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -0,0 +1,378 @@ +// Licensed to the 
Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::{Arc, LazyLock}; + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::{ + logical_expr::Operator, + physical_plan::{ + expressions::{BinaryExpr, Column, Literal}, + PhysicalExpr, + }, + scalar::ScalarValue, +}; +use datafusion_common::config::ConfigOptions; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::{aggregate::AggregateExprBuilder, Partitioning}; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; +use datafusion_physical_plan::{ + aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, + coalesce_batches::CoalesceBatchesExec, + filter::FilterExec, + repartition::RepartitionExec, +}; + +use util::{OptimizationTest, TestNode, TestScanBuilder}; + +mod util; + +#[test] +fn test_pushdown_into_scan() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown{}, true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +/// Show that we can use config options to determine how to do pushdown. 
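// (The two snapshots below capture the gating behavior: with
// `execution.parquet.pushdown_filters` left at its default of false the FilterExec stays
// above the scan, and with it set to true the same predicate is absorbed into the
// DataSourceExec.)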
+#[test] +fn test_pushdown_into_scan_with_config_options() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()) as _; + + let mut cfg = ConfigOptions::default(); + insta::assert_snapshot!( + OptimizationTest::new( + Arc::clone(&plan), + FilterPushdown {}, + false + ), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + " + ); + + cfg.execution.parquet.pushdown_filters = true; + insta::assert_snapshot!( + OptimizationTest::new( + plan, + FilterPushdown {}, + true + ), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_filter_collapse() { + // filter should be pushed down into the parquet scan with two filters + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate1 = col_lit_predicate("a", "foo", &schema()); + let filter1 = Arc::new(FilterExec::try_new(predicate1, scan).unwrap()); + let predicate2 = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate2, filter1).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown{}, true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + " + ); +} + +#[test] +fn test_filter_with_projection() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let projection = vec![1, 0]; + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new( + FilterExec::try_new(predicate, Arc::clone(&scan)) + .unwrap() + .with_projection(Some(projection)) + .unwrap(), + ); + + // expect the predicate to be pushed down into the DataSource but the FilterExec to be converted to ProjectionExec + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown{}, true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo, projection=[b@1, a@0] + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - ProjectionExec: expr=[b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + ", + ); + + // add a test where the filter is on a column that isn't included in the output + let projection = vec![1]; + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new( + FilterExec::try_new(predicate, scan) + .unwrap() + .with_projection(Some(projection)) + .unwrap(), + ); + insta::assert_snapshot!( + OptimizationTest::new(plan, 
FilterPushdown{},true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo, projection=[b@1] + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - ProjectionExec: expr=[b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_push_down_through_transparent_nodes() { + // expect the predicate to be pushed down into the DataSource + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 1)); + let predicate = col_lit_predicate("a", "foo", &schema()); + let filter = Arc::new(FilterExec::try_new(predicate, coalesce).unwrap()); + let repartition = Arc::new( + RepartitionExec::try_new(filter, Partitioning::RoundRobinBatch(1)).unwrap(), + ); + let predicate = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, repartition).unwrap()); + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown{},true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 + - FilterExec: a@0 = foo + - CoalesceBatchesExec: target_batch_size=1 + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 + - CoalesceBatchesExec: target_batch_size=1 + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + " + ); +} + +#[test] +fn test_no_pushdown_through_aggregates() { + // There are 2 important points here: + // 1. The outer filter **is not** pushed down at all because we haven't implemented pushdown support + // yet for AggregateExec. + // 2. The inner filter **is** pushed down into the DataSource. 
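// (For orientation: `b@1 = bar` only references a GROUP BY key, so pushing it below the
// aggregate would arguably be sound once AggregateExec gains pushdown support; the test
// pins today's behavior, where the rule stops at the aggregate rather than pushing
// through it.)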
+ let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 10)); + + let filter = Arc::new( + FilterExec::try_new(col_lit_predicate("a", "foo", &schema()), coalesce).unwrap(), + ); + + let aggregate_expr = + vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + let group_by = PhysicalGroupBy::new_single(vec![ + (col("a", &schema()).unwrap(), "a".to_string()), + (col("b", &schema()).unwrap(), "b".to_string()), + ]); + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr.clone(), + vec![None], + filter, + schema(), + ) + .unwrap(), + ); + + let coalesce = Arc::new(CoalesceBatchesExec::new(aggregate, 100)); + + let predicate = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, coalesce).unwrap()); + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown{}, true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - CoalesceBatchesExec: target_batch_size=100 + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([0]) + - FilterExec: a@0 = foo + - CoalesceBatchesExec: target_batch_size=10 + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: b@1 = bar + - CoalesceBatchesExec: target_batch_size=100 + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt] + - CoalesceBatchesExec: target_batch_size=10 + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +/// Test various combinations of handling of child pushdown results +/// in an ExectionPlan in combination with support/not support in a DataSource. +#[test] +fn test_node_handles_child_pushdown_result() { + // If we set `with_support(true)` + `inject_filter = true` then the filter is pushed down to the DataSource + // and no FilterExec is created. + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown{}, true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + ", + ); + + // If we set `with_support(false)` + `inject_filter = true` then the filter is not pushed down to the DataSource + // and a FilterExec is created. 
+ let scan = TestScanBuilder::new(schema()).with_support(false).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown{}, true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - TestInsertExec { inject_filter: false } + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + ", + ); + + // If we set `with_support(false)` + `inject_filter = false` then the filter is not pushed down to the DataSource + // and no FilterExec is created. + let scan = TestScanBuilder::new(schema()).with_support(false).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(false, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown{}, true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: false } + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - TestInsertExec { inject_filter: false } + - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + ", + ); +} + +/// Schema: +/// a: String +/// b: String +/// c: f64 +static TEST_SCHEMA: LazyLock = LazyLock::new(|| { + let fields = vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ]; + Arc::new(Schema::new(fields)) +}); + +fn schema() -> SchemaRef { + Arc::clone(&TEST_SCHEMA) +} + +/// Returns a predicate that is a binary expression col = lit +fn col_lit_predicate( + column_name: &str, + scalar_value: impl Into, + schema: &Schema, +) -> Arc { + let scalar_value = scalar_value.into(); + Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema(column_name, schema).unwrap()), + Operator::Eq, + Arc::new(Literal::new(scalar_value)), + )) +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs new file mode 100644 index 000000000000..87fa70c07a69 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs @@ -0,0 +1,541 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
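// What follows is the shared harness for the pushdown tests above: `TestOpener` and
// `TestSource`/`TestScanBuilder` build a DataSourceExec over in-memory batches with
// configurable pushdown support, `OptimizationTest` renders a plan before and after an
// optimizer rule for insta snapshots, `TestStream` replays record batches, and `TestNode`
// is a stub ExecutionPlan used to exercise the filter-pushdown hooks.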
+ +use arrow::datatypes::SchemaRef; +use arrow::error::ArrowError; +use arrow::{array::RecordBatch, compute::concat_batches}; +use datafusion::{datasource::object_store::ObjectStoreUrl, physical_plan::PhysicalExpr}; +use datafusion_common::{config::ConfigOptions, internal_err, Result, Statistics}; +use datafusion_datasource::{ + file::FileSource, file_meta::FileMeta, file_scan_config::FileScanConfig, + file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture, + file_stream::FileOpener, schema_adapter::DefaultSchemaAdapterFactory, + schema_adapter::SchemaAdapterFactory, source::DataSourceExec, PartitionedFile, +}; +use datafusion_physical_expr::conjunction; +use datafusion_physical_expr_common::physical_expr::fmt_sql; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::{ + displayable, + filter::FilterExec, + filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPropagation, + PredicateSupport, PredicateSupports, + }, + metrics::ExecutionPlanMetricsSet, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, +}; +use futures::stream::BoxStream; +use futures::{FutureExt, Stream}; +use object_store::ObjectStore; +use std::{ + any::Any, + fmt::{Display, Formatter}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +pub struct TestOpener { + batches: Vec, + batch_size: Option, + schema: Option, + projection: Option>, +} + +impl FileOpener for TestOpener { + fn open(&self, _file_meta: FileMeta) -> Result { + let mut batches = self.batches.clone(); + if let Some(batch_size) = self.batch_size { + let batch = concat_batches(&batches[0].schema(), &batches)?; + let mut new_batches = Vec::new(); + for i in (0..batch.num_rows()).step_by(batch_size) { + let end = std::cmp::min(i + batch_size, batch.num_rows()); + let batch = batch.slice(i, end - i); + new_batches.push(batch); + } + batches = new_batches.into_iter().collect(); + } + if let Some(schema) = &self.schema { + let factory = DefaultSchemaAdapterFactory::from_schema(Arc::clone(schema)); + let (mapper, projection) = factory.map_schema(&batches[0].schema()).unwrap(); + let mut new_batches = Vec::new(); + for batch in batches { + let batch = batch.project(&projection).unwrap(); + let batch = mapper.map_batch(batch).unwrap(); + new_batches.push(batch); + } + batches = new_batches; + } + if let Some(projection) = &self.projection { + batches = batches + .into_iter() + .map(|batch| batch.project(projection).unwrap()) + .collect(); + } + + let stream = TestStream::new(batches); + + Ok((async { + let stream: BoxStream<'static, Result> = + Box::pin(stream); + Ok(stream) + }) + .boxed()) + } +} + +/// A placeholder data source that accepts filter pushdown +#[derive(Clone, Default)] +pub struct TestSource { + support: bool, + predicate: Option>, + statistics: Option, + batch_size: Option, + batches: Vec, + schema: Option, + metrics: ExecutionPlanMetricsSet, + projection: Option>, + schema_adapter_factory: Option>, +} + +impl TestSource { + fn new(support: bool, batches: Vec) -> Self { + Self { + support, + metrics: ExecutionPlanMetricsSet::new(), + batches, + ..Default::default() + } + } +} + +impl FileSource for TestSource { + fn create_file_opener( + &self, + _object_store: Arc, + _base_config: &FileScanConfig, + _partition: usize, + ) -> Arc { + Arc::new(TestOpener { + batches: self.batches.clone(), + batch_size: self.batch_size, + schema: self.schema.clone(), + projection: self.projection.clone(), + }) + } + + fn as_any(&self) -> &dyn Any { + todo!("should not be called") 
+    }
+
+    fn with_batch_size(&self, batch_size: usize) -> Arc<dyn FileSource> {
+        Arc::new(TestSource {
+            batch_size: Some(batch_size),
+            ..self.clone()
+        })
+    }
+
+    fn with_schema(&self, schema: SchemaRef) -> Arc<dyn FileSource> {
+        Arc::new(TestSource {
+            schema: Some(schema),
+            ..self.clone()
+        })
+    }
+
+    fn with_projection(&self, config: &FileScanConfig) -> Arc<dyn FileSource> {
+        Arc::new(TestSource {
+            projection: config.projection.clone(),
+            ..self.clone()
+        })
+    }
+
+    fn with_statistics(&self, statistics: Statistics) -> Arc<dyn FileSource> {
+        Arc::new(TestSource {
+            statistics: Some(statistics),
+            ..self.clone()
+        })
+    }
+
+    fn metrics(&self) -> &ExecutionPlanMetricsSet {
+        &self.metrics
+    }
+
+    fn statistics(&self) -> Result<Statistics> {
+        Ok(self
+            .statistics
+            .as_ref()
+            .expect("statistics not set")
+            .clone())
+    }
+
+    fn file_type(&self) -> &str {
+        "test"
+    }
+
+    fn fmt_extra(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                let support = format!(", pushdown_supported={}", self.support);
+
+                let predicate_string = self
+                    .predicate
+                    .as_ref()
+                    .map(|p| format!(", predicate={p}"))
+                    .unwrap_or_default();
+
+                write!(f, "{support}{predicate_string}")
+            }
+            DisplayFormatType::TreeRender => {
+                if let Some(predicate) = &self.predicate {
+                    writeln!(f, "pushdown_supported={}", fmt_sql(predicate.as_ref()))?;
+                    writeln!(f, "predicate={}", fmt_sql(predicate.as_ref()))?;
+                }
+                Ok(())
+            }
+        }
+    }
+
+    fn try_pushdown_filters(
+        &self,
+        mut filters: Vec<Arc<dyn PhysicalExpr>>,
+        config: &ConfigOptions,
+    ) -> Result<FilterPushdownPropagation<Arc<dyn FileSource>>> {
+        if self.support && config.execution.parquet.pushdown_filters {
+            if let Some(internal) = self.predicate.as_ref() {
+                filters.push(Arc::clone(internal));
+            }
+            let new_node = Arc::new(TestSource {
+                predicate: Some(conjunction(filters.clone())),
+                ..self.clone()
+            });
+            Ok(FilterPushdownPropagation {
+                filters: PredicateSupports::all_supported(filters),
+                updated_node: Some(new_node),
+            })
+        } else {
+            Ok(FilterPushdownPropagation::unsupported(filters))
+        }
+    }
+
+    fn with_schema_adapter_factory(
+        &self,
+        schema_adapter_factory: Arc<dyn SchemaAdapterFactory>,
+    ) -> Result<Arc<dyn FileSource>> {
+        Ok(Arc::new(Self {
+            schema_adapter_factory: Some(schema_adapter_factory),
+            ..self.clone()
+        }))
+    }
+
+    fn schema_adapter_factory(&self) -> Option<Arc<dyn SchemaAdapterFactory>> {
+        self.schema_adapter_factory.clone()
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TestScanBuilder {
+    support: bool,
+    batches: Vec<RecordBatch>,
+    schema: SchemaRef,
+}
+
+impl TestScanBuilder {
+    pub fn new(schema: SchemaRef) -> Self {
+        Self {
+            support: false,
+            batches: vec![],
+            schema,
+        }
+    }
+
+    pub fn with_support(mut self, support: bool) -> Self {
+        self.support = support;
+        self
+    }
+
+    pub fn build(self) -> Arc<DataSourceExec> {
+        let source = Arc::new(TestSource::new(self.support, self.batches));
+        let base_config = FileScanConfigBuilder::new(
+            ObjectStoreUrl::parse("test://").unwrap(),
+            Arc::clone(&self.schema),
+            source,
+        )
+        .with_file(PartitionedFile::new("test.parquet", 123))
+        .build();
+        DataSourceExec::from_data_source(base_config)
+    }
+}
+
+/// Index into the data that has been returned so far
+#[derive(Debug, Default, Clone)]
+pub struct BatchIndex {
+    inner: Arc<Mutex<usize>>,
+}
+
+impl BatchIndex {
+    /// Return the current index
+    pub fn value(&self) -> usize {
+        let inner = self.inner.lock().unwrap();
+        *inner
+    }
+
+    // increment the current index by one
+    pub fn incr(&self) {
+        let mut inner = self.inner.lock().unwrap();
+        *inner += 1;
+    }
+}
+
+/// Iterator over batches
+#[derive(Debug, Default)]
+pub struct TestStream {
+    /// Vector of record batches
+    data: Vec<RecordBatch>,
+    /// Index into the data that has been returned so far
+    index: BatchIndex,
+}
+
+impl TestStream {
+    /// Create an iterator for a vector of record batches. Assumes at
+    /// least one entry in data (for the schema)
+    pub fn new(data: Vec<RecordBatch>) -> Self {
+        // check that there is at least one entry in data and that all batches have the same schema
+        assert!(!data.is_empty(), "data must not be empty");
+        assert!(
+            data.iter().all(|batch| batch.schema() == data[0].schema()),
+            "all batches must have the same schema"
+        );
+        Self {
+            data,
+            ..Default::default()
+        }
+    }
+}
+
+impl Stream for TestStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let next_batch = self.index.value();
+
+        Poll::Ready(if next_batch < self.data.len() {
+            let next_batch = self.index.value();
+            self.index.incr();
+            Some(Ok(self.data[next_batch].clone()))
+        } else {
+            None
+        })
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.data.len(), Some(self.data.len()))
+    }
+}
+
+/// A harness for testing physical optimizers.
+///
+/// You can use this to test the output of a physical optimizer rule using insta snapshots
+#[derive(Debug)]
+pub struct OptimizationTest {
+    input: Vec<String>,
+    output: Result<Vec<String>, String>,
+}
+
+impl OptimizationTest {
+    pub fn new<O>(
+        input_plan: Arc<dyn ExecutionPlan>,
+        opt: O,
+        allow_pushdown_filters: bool,
+    ) -> Self
+    where
+        O: PhysicalOptimizerRule,
+    {
+        let mut parquet_pushdown_config = ConfigOptions::default();
+        parquet_pushdown_config.execution.parquet.pushdown_filters =
+            allow_pushdown_filters;
+
+        let input = format_execution_plan(&input_plan);
+        let input_schema = input_plan.schema();
+
+        let output_result = opt.optimize(input_plan, &parquet_pushdown_config);
+        let output = output_result
+            .and_then(|plan| {
+                if opt.schema_check() && (plan.schema() != input_schema) {
+                    internal_err!(
+                        "Schema mismatch:\n\nBefore:\n{:?}\n\nAfter:\n{:?}",
+                        input_schema,
+                        plan.schema()
+                    )
+                } else {
+                    Ok(plan)
+                }
+            })
+            .map(|plan| format_execution_plan(&plan))
+            .map_err(|e| e.to_string());
+
+        Self { input, output }
+    }
+}
+
+impl Display for OptimizationTest {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        writeln!(f, "OptimizationTest:")?;
+        writeln!(f, "  input:")?;
+        for line in &self.input {
+            writeln!(f, "    - {line}")?;
+        }
+        writeln!(f, "  output:")?;
+        match &self.output {
+            Ok(output) => {
+                writeln!(f, "    Ok:")?;
+                for line in output {
+                    writeln!(f, "      - {line}")?;
+                }
+            }
+            Err(err) => {
+                writeln!(f, "    Err: {err}")?;
+            }
+        }
+        Ok(())
+    }
+}
+
+pub fn format_execution_plan(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> {
+    format_lines(&displayable(plan.as_ref()).indent(false).to_string())
+}
+
+fn format_lines(s: &str) -> Vec<String> {
+    s.trim().split('\n').map(|s| s.to_string()).collect()
+}
+
+#[derive(Debug)]
+pub(crate) struct TestNode {
+    inject_filter: bool,
+    input: Arc<dyn ExecutionPlan>,
+    predicate: Arc<dyn PhysicalExpr>,
+}
+
+impl TestNode {
+    pub fn new(
+        inject_filter: bool,
+        input: Arc<dyn ExecutionPlan>,
+        predicate: Arc<dyn PhysicalExpr>,
+    ) -> Self {
+        Self {
+            inject_filter,
+            input,
+            predicate,
+        }
+    }
+}
+
+impl DisplayAs for TestNode {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "TestInsertExec {{ inject_filter: {} }}",
+            self.inject_filter
+        )
+    }
+}
+
+impl ExecutionPlan for TestNode {
+    fn name(&self) -> &str {
+        "TestInsertExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        self.input.properties()
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.input]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        assert!(children.len() == 1);
+        Ok(Arc::new(TestNode::new(
+            self.inject_filter,
+            children[0].clone(),
+            self.predicate.clone(),
+        )))
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        unimplemented!("TestInsertExec is a stub for testing.")
+    }
+
+    fn gather_filters_for_pushdown(
+        &self,
+        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
+        _config: &ConfigOptions,
+    ) -> Result<FilterDescription> {
+        Ok(FilterDescription::new_with_child_count(1)
+            .all_parent_filters_supported(parent_filters)
+            .with_self_filter(Arc::clone(&self.predicate)))
+    }
+
+    fn handle_child_pushdown_result(
+        &self,
+        child_pushdown_result: ChildPushdownResult,
+        _config: &ConfigOptions,
+    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
+        if self.inject_filter {
+            // Add a FilterExec if our own filter was not handled by the child
+
+            // We have 1 child
+            assert_eq!(child_pushdown_result.self_filters.len(), 1);
+            let self_pushdown_result = child_pushdown_result.self_filters[0].clone();
+            // And pushed down 1 filter
+            assert_eq!(self_pushdown_result.len(), 1);
+            let self_pushdown_result = self_pushdown_result.into_inner();
+
+            match &self_pushdown_result[0] {
+                PredicateSupport::Unsupported(filter) => {
+                    // We have a filter to push down
+                    let new_child =
+                        FilterExec::try_new(Arc::clone(filter), Arc::clone(&self.input))?;
+                    let new_self =
+                        TestNode::new(false, Arc::new(new_child), self.predicate.clone());
+                    let mut res =
+                        FilterPushdownPropagation::transparent(child_pushdown_result);
+                    res.updated_node = Some(Arc::new(new_self) as Arc<dyn ExecutionPlan>);
+                    Ok(res)
+                }
+                PredicateSupport::Supported(_) => {
+                    let res =
+                        FilterPushdownPropagation::transparent(child_pushdown_result);
+                    Ok(res)
+                }
+            }
+        } else {
+            let res = FilterPushdownPropagation::transparent(child_pushdown_result);
+            Ok(res)
+        }
+    }
+}
diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs
index d3b6ec700bee..d8c0c142f7fb 100644
--- a/datafusion/core/tests/physical_optimizer/join_selection.rs
+++ b/datafusion/core/tests/physical_optimizer/join_selection.rs
@@ -251,11 +251,19 @@ async fn test_join_with_swap() {
         .expect("The type of the plan should not be changed");
 
     assert_eq!(
-        swapped_join.left().statistics().unwrap().total_byte_size,
+        swapped_join
+            .left()
+            .partition_statistics(None)
+            .unwrap()
+            .total_byte_size,
         Precision::Inexact(8192)
     );
     assert_eq!(
-        swapped_join.right().statistics().unwrap().total_byte_size,
+        swapped_join
+            .right()
+            .partition_statistics(None)
+            .unwrap()
+            .total_byte_size,
         Precision::Inexact(2097152)
     );
 }
@@ -291,11 +299,19 @@ async fn test_left_join_no_swap() {
         .expect("The type of the plan should not be changed");
 
     assert_eq!(
-        swapped_join.left().statistics().unwrap().total_byte_size,
+        swapped_join
+            .left()
+            .partition_statistics(None)
+            .unwrap()
+            .total_byte_size,
         Precision::Inexact(8192)
     );
     assert_eq!(
-        swapped_join.right().statistics().unwrap().total_byte_size,
+        swapped_join
+            .right()
+            .partition_statistics(None)
+            .unwrap()
+            .total_byte_size,
         Precision::Inexact(2097152)
     );
 }
@@ -336,11 +352,19 @@ async fn test_join_with_swap_semi() {
     assert_eq!(swapped_join.schema().fields().len(), 1);
 
     assert_eq!(
-        swapped_join.left().statistics().unwrap().total_byte_size,
+        swapped_join
+            .left()
+            .partition_statistics(None)
+            .unwrap()
+            .total_byte_size,
         Precision::Inexact(8192)
    );
    assert_eq!(
-        swapped_join.right().statistics().unwrap().total_byte_size,
+        swapped_join
+            .right()
+            .partition_statistics(None)
+            .unwrap()
+            .total_byte_size,
         Precision::Inexact(2097152)
     );
assert_eq!(original_schema, swapped_join.schema()); @@ -455,11 +479,19 @@ async fn test_join_no_swap() { .expect("The type of the plan should not be changed"); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -524,11 +556,19 @@ async fn test_nl_join_with_swap(join_type: JoinType) { ); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -589,11 +629,19 @@ async fn test_nl_join_with_swap_no_proj(join_type: JoinType) { ); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -1067,6 +1115,14 @@ impl ExecutionPlan for StatisticsExec { fn statistics(&self) -> Result { Ok(self.stats.clone()) } + + fn partition_statistics(&self, partition: Option) -> Result { + Ok(if partition.is_some() { + Statistics::new_unknown(&self.schema) + } else { + self.stats.clone() + }) + } } #[test] diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index 7d5d07715eeb..98e7b87ad215 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -21,9 +21,11 @@ mod aggregate_statistics; mod combine_partial_final_agg; mod enforce_distribution; mod enforce_sorting; +mod filter_pushdown; mod join_selection; mod limit_pushdown; mod limited_distinct_aggregation; +mod partition_statistics; mod projection_pushdown; mod replace_with_order_preserving_variants; mod sanity_checker; diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs new file mode 100644 index 000000000000..62f04f2fe740 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -0,0 +1,744 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
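The tests below exercise the per-partition statistics API added by this change: `partition_statistics(None)` returns statistics for the plan as a whole, while `partition_statistics(Some(idx))` returns statistics for a single output partition. A minimal sketch of the access pattern the tests repeat, assuming the imports already used in this module; the helper name `collect_partition_stats` is illustrative only and not part of the patch:

use std::sync::Arc;

use datafusion_common::{Result, Statistics};
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};

/// Illustrative helper (not part of the patch): gather one `Statistics`
/// value per output partition of a plan, mirroring the pattern used by
/// the tests in this module.
fn collect_partition_stats(plan: &Arc<dyn ExecutionPlan>) -> Result<Vec<Statistics>> {
    (0..plan.output_partitioning().partition_count())
        .map(|idx| plan.partition_statistics(Some(idx)))
        .collect()
}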
+ +#[cfg(test)] +mod test { + use arrow::array::{Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema, SortOptions}; + use datafusion::datasource::listing::ListingTable; + use datafusion::prelude::SessionContext; + use datafusion_catalog::TableProvider; + use datafusion_common::stats::Precision; + use datafusion_common::Result; + use datafusion_common::{ColumnStatistics, ScalarValue, Statistics}; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::TaskContext; + use datafusion_expr_common::operator::Operator; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::{binary, col, lit, Column}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; + use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; + use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; + use datafusion_physical_plan::empty::EmptyExec; + use datafusion_physical_plan::filter::FilterExec; + use datafusion_physical_plan::joins::CrossJoinExec; + use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; + use datafusion_physical_plan::projection::ProjectionExec; + use datafusion_physical_plan::sorts::sort::SortExec; + use datafusion_physical_plan::union::UnionExec; + use datafusion_physical_plan::{ + execute_stream_partitioned, get_plan_string, ExecutionPlan, + ExecutionPlanProperties, + }; + use futures::TryStreamExt; + use std::sync::Arc; + + /// Creates a test table with statistics from the test data directory. 
+ /// + /// This function: + /// - Creates an external table from './tests/data/test_statistics_per_partition' + /// - If we set the `target_partition` to 2, the data contains 2 partitions, each with 2 rows + /// - Each partition has an "id" column (INT) with the following values: + /// - First partition: [3, 4] + /// - Second partition: [1, 2] + /// - Each row is 110 bytes in size + /// + /// @param target_partition Optional parameter to set the target partitions + /// @return ExecutionPlan representing the scan of the table with statistics + async fn create_scan_exec_with_statistics( + create_table_sql: Option<&str>, + target_partition: Option, + ) -> Arc { + let mut session_config = SessionConfig::new().with_collect_statistics(true); + if let Some(partition) = target_partition { + session_config = session_config.with_target_partitions(partition); + } + let ctx = SessionContext::new_with_config(session_config); + // Create table with partition + let create_table_sql = create_table_sql.unwrap_or( + "CREATE EXTERNAL TABLE t1 (id INT NOT NULL, date DATE) \ + STORED AS PARQUET LOCATION './tests/data/test_statistics_per_partition'\ + PARTITIONED BY (date) \ + WITH ORDER (id ASC);", + ); + // Get table name from `create_table_sql` + let table_name = create_table_sql + .split_whitespace() + .nth(3) + .unwrap_or("t1") + .to_string(); + ctx.sql(create_table_sql) + .await + .unwrap() + .collect() + .await + .unwrap(); + let table = ctx.table_provider(table_name.as_str()).await.unwrap(); + let listing_table = table + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + listing_table + .scan(&ctx.state(), None, &[], None) + .await + .unwrap() + } + + /// Helper function to create expected statistics for a partition with Int32 column + fn create_partition_statistics( + num_rows: usize, + total_byte_size: usize, + min_value: i32, + max_value: i32, + include_date_column: bool, + ) -> Statistics { + let mut column_stats = vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(max_value))), + min_value: Precision::Exact(ScalarValue::Int32(Some(min_value))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }]; + + if include_date_column { + column_stats.push(ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + } + + Statistics { + num_rows: Precision::Exact(num_rows), + total_byte_size: Precision::Exact(total_byte_size), + column_statistics: column_stats, + } + } + + #[derive(PartialEq, Eq, Debug)] + enum ExpectedStatistics { + Empty, // row_count == 0 + NonEmpty(i32, i32, usize), // (min_id, max_id, row_count) + } + + /// Helper function to validate that statistics from statistics_by_partition match the actual data + async fn validate_statistics_with_data( + plan: Arc, + expected_stats: Vec, + id_column_index: usize, + ) -> Result<()> { + let ctx = TaskContext::default(); + let partitions = execute_stream_partitioned(plan, Arc::new(ctx))?; + + let mut actual_stats = Vec::new(); + for partition_stream in partitions.into_iter() { + let result: Vec = partition_stream.try_collect().await?; + + let mut min_id = i32::MAX; + let mut max_id = i32::MIN; + let mut row_count = 0; + + for batch in result { + if batch.num_columns() > id_column_index { + let id_array = batch + .column(id_column_index) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + let id_value = 
id_array.value(i); + min_id = min_id.min(id_value); + max_id = max_id.max(id_value); + row_count += 1; + } + } + } + + if row_count == 0 { + actual_stats.push(ExpectedStatistics::Empty); + } else { + actual_stats + .push(ExpectedStatistics::NonEmpty(min_id, max_id, row_count)); + } + } + + // Compare actual data with expected statistics + assert_eq!( + actual_stats.len(), + expected_stats.len(), + "Number of partitions with data doesn't match expected" + ); + for i in 0..actual_stats.len() { + assert_eq!( + actual_stats[i], expected_stats[i], + "Partition {i} data doesn't match statistics" + ); + } + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_data_source() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let statistics = (0..scan.output_partitioning().partition_count()) + .map(|idx| scan.partition_statistics(Some(idx))) + .collect::>>()?; + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + // Check the statistics of each partition + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), // (min_id, max_id, row_count) for first partition + ExpectedStatistics::NonEmpty(1, 2, 2), // (min_id, max_id, row_count) for second partition + ]; + validate_statistics_with_data(scan, expected_stats, 0).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_projection() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + // Add projection execution plan + let exprs: Vec<(Arc, String)> = + vec![(Arc::new(Column::new("id", 0)), "id".to_string())]; + let projection: Arc = + Arc::new(ProjectionExec::try_new(exprs, scan)?); + let statistics = (0..projection.output_partitioning().partition_count()) + .map(|idx| projection.partition_statistics(Some(idx))) + .collect::>>()?; + let expected_statistic_partition_1 = + create_partition_statistics(2, 8, 3, 4, false); + let expected_statistic_partition_2 = + create_partition_statistics(2, 8, 1, 2, false); + // Check the statistics of each partition + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(projection, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_sort() -> Result<()> { + let scan_1 = create_scan_exec_with_statistics(None, Some(1)).await; + // Add sort execution plan + let sort = SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("id", 0)), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]), + scan_1, + ); + let sort_exec: Arc = Arc::new(sort.clone()); + let statistics = (0..sort_exec.output_partitioning().partition_count()) + .map(|idx| sort_exec.partition_statistics(Some(idx))) + .collect::>>()?; + let expected_statistic_partition = + create_partition_statistics(4, 220, 1, 4, true); + assert_eq!(statistics.len(), 1); + assert_eq!(statistics[0], 
expected_statistic_partition); + // Check the statistics_by_partition with real results + let expected_stats = vec![ExpectedStatistics::NonEmpty(1, 4, 4)]; + validate_statistics_with_data(sort_exec.clone(), expected_stats, 0).await?; + + // Sort with preserve_partitioning + let scan_2 = create_scan_exec_with_statistics(None, Some(2)).await; + // Add sort execution plan + let sort_exec: Arc = Arc::new( + SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("id", 0)), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]), + scan_2, + ) + .with_preserve_partitioning(true), + ); + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + let statistics = (0..sort_exec.output_partitioning().partition_count()) + .map(|idx| sort_exec.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(sort_exec, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_filter() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let predicate = binary( + Arc::new(Column::new("id", 0)), + Operator::Lt, + lit(1i32), + &schema, + )?; + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, scan)?); + let full_statistics = filter.partition_statistics(None)?; + let expected_full_statistic = Statistics { + num_rows: Precision::Inexact(0), + total_byte_size: Precision::Inexact(0), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Null), + sum_value: Precision::Exact(ScalarValue::Null), + distinct_count: Precision::Exact(0), + }, + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Null), + sum_value: Precision::Exact(ScalarValue::Null), + distinct_count: Precision::Exact(0), + }, + ], + }; + assert_eq!(full_statistics, expected_full_statistic); + + let statistics = (0..filter.output_partitioning().partition_count()) + .map(|idx| filter.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_full_statistic); + assert_eq!(statistics[1], expected_full_statistic); + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_union() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let union_exec: Arc = + Arc::new(UnionExec::new(vec![scan.clone(), scan])); + let statistics = (0..union_exec.output_partitioning().partition_count()) + .map(|idx| union_exec.partition_statistics(Some(idx))) + .collect::>>()?; + // Check that we have 4 partitions (2 from each scan) + assert_eq!(statistics.len(), 4); + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + // Verify first partition 
(from first scan) + assert_eq!(statistics[0], expected_statistic_partition_1); + // Verify second partition (from first scan) + assert_eq!(statistics[1], expected_statistic_partition_2); + // Verify third partition (from second scan - same as first partition) + assert_eq!(statistics[2], expected_statistic_partition_1); + // Verify fourth partition (from second scan - same as second partition) + assert_eq!(statistics[3], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(union_exec, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_cross_join() -> Result<()> { + let left_scan = create_scan_exec_with_statistics(None, Some(1)).await; + let right_create_table_sql = "CREATE EXTERNAL TABLE t2 (id INT NOT NULL) \ + STORED AS PARQUET LOCATION './tests/data/test_statistics_per_partition'\ + WITH ORDER (id ASC);"; + let right_scan = + create_scan_exec_with_statistics(Some(right_create_table_sql), Some(2)).await; + let cross_join: Arc = + Arc::new(CrossJoinExec::new(left_scan, right_scan)); + let statistics = (0..cross_join.output_partitioning().partition_count()) + .map(|idx| cross_join.partition_statistics(Some(idx))) + .collect::>>()?; + // Check that we have 2 partitions + assert_eq!(statistics.len(), 2); + let mut expected_statistic_partition_1 = + create_partition_statistics(8, 48400, 1, 4, true); + expected_statistic_partition_1 + .column_statistics + .push(ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + let mut expected_statistic_partition_2 = + create_partition_statistics(8, 48400, 1, 4, true); + expected_statistic_partition_2 + .column_statistics + .push(ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(2))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(1, 4, 8), + ExpectedStatistics::NonEmpty(1, 4, 8), + ]; + validate_statistics_with_data(cross_join, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_coalesce_batches() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + dbg!(scan.partition_statistics(Some(0))?); + let coalesce_batches: Arc = + Arc::new(CoalesceBatchesExec::new(scan, 2)); + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + let statistics = (0..coalesce_batches.output_partitioning().partition_count()) + .map(|idx| coalesce_batches.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the 
statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(coalesce_batches, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_coalesce_partitions() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let coalesce_partitions: Arc = + Arc::new(CoalescePartitionsExec::new(scan)); + let expected_statistic_partition = + create_partition_statistics(4, 220, 1, 4, true); + let statistics = (0..coalesce_partitions.output_partitioning().partition_count()) + .map(|idx| coalesce_partitions.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 1); + assert_eq!(statistics[0], expected_statistic_partition); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ExpectedStatistics::NonEmpty(1, 4, 4)]; + validate_statistics_with_data(coalesce_partitions, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_local_limit() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let local_limit: Arc = + Arc::new(LocalLimitExec::new(scan.clone(), 1)); + let statistics = (0..local_limit.output_partitioning().partition_count()) + .map(|idx| local_limit.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + let schema = scan.schema(); + let mut expected_statistic_partition = Statistics::new_unknown(&schema); + expected_statistic_partition.num_rows = Precision::Exact(1); + assert_eq!(statistics[0], expected_statistic_partition); + assert_eq!(statistics[1], expected_statistic_partition); + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_global_limit_partitions() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + // Skip 2 rows + let global_limit: Arc = + Arc::new(GlobalLimitExec::new(scan.clone(), 0, Some(2))); + let statistics = (0..global_limit.output_partitioning().partition_count()) + .map(|idx| global_limit.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 1); + let expected_statistic_partition = + create_partition_statistics(2, 110, 3, 4, true); + assert_eq!(statistics[0], expected_statistic_partition); + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_agg() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + + let scan_schema = scan.schema(); + + // select id, 1+id, count(*) from t group by id, 1+id + let group_by = PhysicalGroupBy::new_single(vec![ + (col("id", &scan_schema)?, "id".to_string()), + ( + binary( + lit(1), + Operator::Plus, + col("id", &scan_schema)?, + &scan_schema, + )?, + "expr".to_string(), + ), + ]); + + let aggr_expr = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) + .schema(Arc::clone(&scan_schema)) + .alias(String::from("COUNT(c)")) + .build() + .map(Arc::new)?]; + + let aggregate_exec_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + Arc::clone(&scan), + scan_schema.clone(), + )?) 
as _; + + let mut plan_string = get_plan_string(&aggregate_exec_partial); + let _ = plan_string.swap_remove(1); + let expected_plan = vec![ + "AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]", + //" DataSourceExec: file_groups={2 groups: [[.../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet, .../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet], [.../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet, .../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet]]}, projection=[id, date], file_type=parquet + ]; + assert_eq!(plan_string, expected_plan); + + let p0_statistics = aggregate_exec_partial.partition_statistics(Some(0))?; + + let expected_p0_statistics = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }, + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + + assert_eq!(&p0_statistics, &expected_p0_statistics); + + let expected_p1_statistics = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int32(Some(2))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }, + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + + let p1_statistics = aggregate_exec_partial.partition_statistics(Some(1))?; + assert_eq!(&p1_statistics, &expected_p1_statistics); + + validate_statistics_with_data( + aggregate_exec_partial.clone(), + vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ], + 0, + ) + .await?; + + let agg_final = Arc::new(AggregateExec::try_new( + AggregateMode::FinalPartitioned, + group_by.clone(), + aggr_expr.clone(), + vec![None], + aggregate_exec_partial.clone(), + aggregate_exec_partial.schema(), + )?); + + let p0_statistics = agg_final.partition_statistics(Some(0))?; + assert_eq!(&p0_statistics, &expected_p0_statistics); + + let p1_statistics = agg_final.partition_statistics(Some(1))?; + assert_eq!(&p1_statistics, &expected_p1_statistics); + + validate_statistics_with_data( + agg_final.clone(), + vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ], + 0, + ) + .await?; + + // select id, 1+id, count(*) from empty_table group by id, 1+id + let empty_table = + Arc::new(EmptyExec::new(scan_schema.clone()).with_partitions(2)); + + let agg_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + empty_table.clone(), + scan_schema.clone(), + )?) 
as _; + + let agg_plan = get_plan_string(&agg_partial).remove(0); + assert_eq!("AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]",agg_plan); + + let empty_stat = Statistics { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + + assert_eq!(&empty_stat, &agg_partial.partition_statistics(Some(0))?); + assert_eq!(&empty_stat, &agg_partial.partition_statistics(Some(1))?); + validate_statistics_with_data( + agg_partial.clone(), + vec![ExpectedStatistics::Empty, ExpectedStatistics::Empty], + 0, + ) + .await?; + + let agg_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + empty_table.clone(), + scan_schema.clone(), + )?); + + let agg_final = Arc::new(AggregateExec::try_new( + AggregateMode::FinalPartitioned, + group_by.clone(), + aggr_expr.clone(), + vec![None], + agg_partial.clone(), + agg_partial.schema(), + )?); + + assert_eq!(&empty_stat, &agg_final.partition_statistics(Some(0))?); + assert_eq!(&empty_stat, &agg_final.partition_statistics(Some(1))?); + + validate_statistics_with_data( + agg_final, + vec![ExpectedStatistics::Empty, ExpectedStatistics::Empty], + 0, + ) + .await?; + + // select count(*) from empty_table + let agg_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + aggr_expr.clone(), + vec![None], + empty_table.clone(), + scan_schema.clone(), + )?); + + let coalesce = Arc::new(CoalescePartitionsExec::new(agg_partial.clone())); + + let agg_final = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + aggr_expr.clone(), + vec![None], + coalesce.clone(), + coalesce.schema(), + )?); + + let expect_stat = Statistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics::new_unknown()], + }; + + assert_eq!(&expect_stat, &agg_final.partition_statistics(Some(0))?); + + // Verify that the aggregate final result has exactly one partition with one row + let mut partitions = execute_stream_partitioned( + agg_final.clone(), + Arc::new(TaskContext::default()), + )?; + assert_eq!(1, partitions.len()); + let result: Vec = partitions.remove(0).try_collect().await?; + assert_eq!(1, result[0].num_rows()); + + Ok(()) + } +} diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index 911d2c0cee05..7c00d323a8e6 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -128,7 +128,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -193,7 +193,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 3))), @@ -261,7 +261,7 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -326,7 +326,7 @@ fn 
test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b_new", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d_new", 3))), diff --git a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs index 58eb866c590c..71b9757604ec 100644 --- a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs @@ -18,7 +18,8 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - check_integrity, sort_preserving_merge_exec, stream_exec_ordered_with_projection, + check_integrity, create_test_schema3, sort_preserving_merge_exec, + stream_exec_ordered_with_projection, }; use datafusion::prelude::SessionContext; @@ -40,13 +41,16 @@ use datafusion_physical_plan::{ }; use datafusion::datasource::source::DataSourceExec; use datafusion_common::tree_node::{TransformedResult, TreeNode}; -use datafusion_common::Result; +use datafusion_common::{assert_contains, Result}; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{self, col, Column}; use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{replace_with_order_preserving_variants, OrderPreservationContext}; +use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{plan_with_order_breaking_variants, plan_with_order_preserving_variants, replace_with_order_preserving_variants, OrderPreservationContext}; use datafusion_common::config::ConfigOptions; +use crate::physical_optimizer::enforce_sorting::parquet_exec_sorted; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use object_store::memory::InMemory; use object_store::ObjectStore; use rstest::rstest; @@ -1259,3 +1263,77 @@ fn memory_exec_sorted( )) }) } + +#[test] +fn test_plan_with_order_preserving_variants_preserves_fetch() -> Result<()> { + // Create a schema + let schema = create_test_schema3()?; + let parquet_sort_exprs = vec![crate::physical_optimizer::test_utils::sort_expr( + "a", &schema, + )]; + let parquet_exec = parquet_exec_sorted(&schema, parquet_sort_exprs); + let coalesced = + Arc::new(CoalescePartitionsExec::new(parquet_exec.clone()).with_fetch(Some(10))); + + // Test sort's fetch is greater than coalesce fetch, return error because it's not reasonable + let requirements = OrderPreservationContext::new( + coalesced.clone(), + false, + vec![OrderPreservationContext::new( + parquet_exec.clone(), + false, + vec![], + )], + ); + let res = plan_with_order_preserving_variants(requirements, false, true, Some(15)); + assert_contains!(res.unwrap_err().to_string(), "CoalescePartitionsExec fetch [10] should be greater than or equal to SortExec fetch [15]"); + + // Test sort is without fetch, expected to get the fetch value from the coalesced + let requirements = OrderPreservationContext::new( + coalesced.clone(), + false, + vec![OrderPreservationContext::new( + parquet_exec.clone(), + false, + vec![], + )], + ); + let res = plan_with_order_preserving_variants(requirements, false, true, None)?; + assert_eq!(res.plan.fetch(), Some(10),); + + // Test sort's fetch is less than coalesces fetch, expected to get the fetch value 
from the sort + let requirements = OrderPreservationContext::new( + coalesced, + false, + vec![OrderPreservationContext::new(parquet_exec, false, vec![])], + ); + let res = plan_with_order_preserving_variants(requirements, false, true, Some(5))?; + assert_eq!(res.plan.fetch(), Some(5),); + Ok(()) +} + +#[test] +fn test_plan_with_order_breaking_variants_preserves_fetch() -> Result<()> { + let schema = create_test_schema3()?; + let parquet_sort_exprs = vec![crate::physical_optimizer::test_utils::sort_expr( + "a", &schema, + )]; + let parquet_exec = parquet_exec_sorted(&schema, parquet_sort_exprs.clone()); + let spm = SortPreservingMergeExec::new( + LexOrdering::new(parquet_sort_exprs), + parquet_exec.clone(), + ) + .with_fetch(Some(10)); + let requirements = OrderPreservationContext::new( + Arc::new(spm), + true, + vec![OrderPreservationContext::new( + parquet_exec.clone(), + true, + vec![], + )], + ); + let res = plan_with_order_breaking_variants(requirements)?; + assert_eq!(res.plan.fetch(), Some(10)); + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 4587f99989d3..955486a31030 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -30,9 +30,10 @@ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::source::DataSourceExec; use datafusion_common::config::ConfigOptions; +use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; -use datafusion_common::{JoinType, Result}; +use datafusion_common::{ColumnStatistics, JoinType, Result, Statistics}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; @@ -96,6 +97,48 @@ pub(crate) fn parquet_exec_with_sort( DataSourceExec::from_data_source(config) } +fn int64_stats() -> ColumnStatistics { + ColumnStatistics { + null_count: Precision::Absent, + sum_value: Precision::Absent, + max_value: Precision::Exact(1_000_000.into()), + min_value: Precision::Exact(0.into()), + distinct_count: Precision::Absent, + } +} + +fn column_stats() -> Vec { + vec![ + int64_stats(), // a + int64_stats(), // b + int64_stats(), // c + ColumnStatistics::default(), + ColumnStatistics::default(), + ] +} + +/// Create parquet datasource exec using schema from [`schema`]. 
+pub(crate) fn parquet_exec_with_stats(file_size: u64) -> Arc { + let mut statistics = Statistics::new_unknown(&schema()); + statistics.num_rows = Precision::Inexact(10000); + statistics.column_statistics = column_stats(); + + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test:///").unwrap(), + schema(), + Arc::new(ParquetSource::new(Default::default())), + ) + .with_file(PartitionedFile::new("x".to_string(), file_size)) + .with_statistics(statistics) + .build(); + + assert_eq!( + config.file_source.statistics().unwrap().num_rows, + Precision::Inexact(10000) + ); + DataSourceExec::from_data_source(config) +} + pub fn schema() -> SchemaRef { Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, true), @@ -557,8 +600,7 @@ pub fn assert_plan_matches_expected( assert_eq!( &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" ); Ok(()) diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index e8ef34c2afe7..70e94227cfad 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -176,9 +176,9 @@ async fn csv_explain_plans() { // Verify schema let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", - " TableScan: aggregate_test_100 [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", + " Projection: aggregate_test_100.c1 [c1:Utf8View]", + " Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]", + " TableScan: aggregate_test_100 [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -222,11 +222,11 @@ async fn csv_explain_plans() { " {", " graph[label=\"Detailed LogicalPlan\"]", " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", + " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8View]\"]", " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", + " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]\"]", " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", + " 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, 
c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]\"]", " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", " }", "}", @@ -250,9 +250,9 @@ async fn csv_explain_plans() { // Verify schema let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8, c2:Int8]", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8, c2:Int8]", + " Projection: aggregate_test_100.c1 [c1:Utf8View]", + " Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8View, c2:Int8]", + " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8View, c2:Int8]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -296,11 +296,11 @@ async fn csv_explain_plans() { " {", " graph[label=\"Detailed LogicalPlan\"]", " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", + " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8View]\"]", " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8, c2:Int8]\"]", + " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8View, c2:Int8]\"]", " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8, c2:Int8]\"]", + " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8View, c2:Int8]\"]", " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", " }", "}", @@ -398,9 +398,9 @@ async fn csv_explain_verbose_plans() { // Verify schema let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", - " TableScan: aggregate_test_100 [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", + " Projection: aggregate_test_100.c1 [c1:Utf8View]", + " Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]", + " TableScan: aggregate_test_100 [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]", ]; let formatted = dataframe.logical_plan().display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -444,11 +444,11 @@ async fn csv_explain_verbose_plans() { " {", " graph[label=\"Detailed LogicalPlan\"]", " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", + " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8View]\"]", " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: 
aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", + " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]\"]", " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", + " 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]\"]", " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", " }", "}", @@ -472,9 +472,9 @@ async fn csv_explain_verbose_plans() { // Verify schema let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8, c2:Int8]", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8, c2:Int8]", + " Projection: aggregate_test_100.c1 [c1:Utf8View]", + " Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8View, c2:Int8]", + " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8View, c2:Int8]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -518,11 +518,11 @@ async fn csv_explain_verbose_plans() { " {", " graph[label=\"Detailed LogicalPlan\"]", " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", + " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8View]\"]", " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8, c2:Int8]\"]", + " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8View, c2:Int8]\"]", " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8, c2:Int8]\"]", + " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8View, c2:Int8]\"]", " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", " }", "}", @@ -561,7 +561,9 @@ async fn csv_explain_verbose_plans() { async fn explain_analyze_runs_optimizers(#[values("*", "1")] count_expr: &str) { // repro for https://github.com/apache/datafusion/issues/917 // where EXPLAIN ANALYZE was not correctly running optimizer - let ctx = SessionContext::new(); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_collect_statistics(true), + ); register_alltypes_parquet(&ctx).await; // This happens as an optimization pass where count(*)/count(1) can be @@ -782,7 +784,7 @@ async fn explain_logical_plan_only() { vec!["logical_plan", "Projection: count(Int64(1)) AS count(*)\ \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ \n SubqueryAlias: t\ - \n Projection: \ + \n Projection:\ \n Values: (Utf8(\"a\"), 
Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"]]; assert_eq!(expected, actual); } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 579049692e7d..2a5597b9fb7e 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -63,6 +63,7 @@ pub mod create_drop; pub mod explain_analyze; pub mod joins; mod path_partition; +mod runtime_config; pub mod select; mod sql_api; diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index 20326c5fa84a..3572bc70d4bc 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -26,8 +26,6 @@ use std::sync::Arc; use arrow::datatypes::DataType; use datafusion::assert_batches_sorted_eq; use datafusion::datasource::listing::ListingTableUrl; -use datafusion::datasource::physical_plan::ParquetSource; -use datafusion::datasource::source::DataSourceExec; use datafusion::{ datasource::{ file_format::{csv::CsvFormat, parquet::ParquetFormat}, @@ -41,12 +39,9 @@ use datafusion::{ use datafusion_catalog::TableProvider; use datafusion_common::stats::Precision; use datafusion_common::test_util::batches_to_sort_string; -use datafusion_common::ScalarValue; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_datasource::metadata::MetadataColumn; use datafusion_execution::config::SessionConfig; -use datafusion_expr::{col, lit, Expr, Operator}; -use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; use async_trait::async_trait; use bytes::Bytes; @@ -59,6 +54,11 @@ use object_store::{ }; use object_store::{Attributes, MultipartUpload, PutMultipartOpts, PutPayload}; use url::Url; +use datafusion_common::ScalarValue; +use datafusion_datasource::source::DataSourceExec; +use datafusion_datasource_parquet::source::ParquetSource; +use datafusion_expr::{col, lit, Expr, Operator}; +use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; #[tokio::test] async fn parquet_partition_pruning_filter() -> Result<()> { @@ -488,7 +488,9 @@ async fn parquet_multiple_nonstring_partitions() -> Result<()> { #[tokio::test] async fn parquet_statistics() -> Result<()> { - let ctx = SessionContext::new(); + let mut config = SessionConfig::new(); + config.options_mut().execution.collect_statistics = true; + let ctx = SessionContext::new_with_config(config); register_partitioned_alltypes_parquet( &ctx, @@ -515,7 +517,7 @@ async fn parquet_statistics() -> Result<()> { let schema = physical_plan.schema(); assert_eq!(schema.fields().len(), 4); - let stat_cols = physical_plan.statistics()?.column_statistics; + let stat_cols = physical_plan.partition_statistics(None)?.column_statistics; assert_eq!(stat_cols.len(), 4); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(3)); @@ -530,7 +532,7 @@ async fn parquet_statistics() -> Result<()> { let schema = physical_plan.schema(); assert_eq!(schema.fields().len(), 2); - let stat_cols = physical_plan.statistics()?.column_statistics; + let stat_cols = physical_plan.partition_statistics(None)?.column_statistics; assert_eq!(stat_cols.len(), 2); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(1)); @@ -596,7 +598,7 @@ async fn test_metadata_columns() -> Result<()> { ctx.register_table("t", table).unwrap(); let result = ctx - .sql("SELECT id, size, location, last_modified FROM t WHERE size > 1500 ORDER BY id LIMIT 10") + 
.sql("SELECT id, size, location, last_modified FROM t WHERE size > 1500 ORDER BY id LIMIT 12") .await? .collect() .await?; @@ -614,9 +616,12 @@ async fn test_metadata_columns() -> Result<()> { "| 2 | 1851 | year=2021/month=09/day=09/file.parquet | 1970-01-01T00:00:00Z |", "| 2 | 1851 | year=2021/month=10/day=09/file.parquet | 1970-01-01T00:00:00Z |", "| 2 | 1851 | year=2021/month=10/day=28/file.parquet | 1970-01-01T00:00:00Z |", + "| 3 | 1851 | year=2021/month=09/day=09/file.parquet | 1970-01-01T00:00:00Z |", + "| 3 | 1851 | year=2021/month=10/day=09/file.parquet | 1970-01-01T00:00:00Z |", "| 3 | 1851 | year=2021/month=10/day=28/file.parquet | 1970-01-01T00:00:00Z |", "+----+------+----------------------------------------+----------------------+", ]; + assert_batches_sorted_eq!(expected, &result); Ok(()) @@ -780,7 +785,8 @@ async fn create_partitioned_alltypes_parquet_table( .map(|x| (x.0.to_owned(), x.1.clone())) .collect::>(), ) - .with_metadata_cols(metadata_cols.to_vec()); + .with_metadata_cols(metadata_cols.to_vec()) + .with_session_config_options(&ctx.copied_config()); let table_path = ListingTableUrl::parse(table_path).unwrap(); let store_path = diff --git a/datafusion/core/tests/sql/runtime_config.rs b/datafusion/core/tests/sql/runtime_config.rs new file mode 100644 index 000000000000..18e07bb61ed9 --- /dev/null +++ b/datafusion/core/tests/sql/runtime_config.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tests for runtime configuration SQL interface + +use std::sync::Arc; + +use datafusion::execution::context::SessionContext; +use datafusion::execution::context::TaskContext; +use datafusion_physical_plan::common::collect; + +#[tokio::test] +async fn test_memory_limit_with_spill() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '1M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + ctx.sql("SET datafusion.execution.sort_spill_reservation_bytes = 0") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,10000000) as t1(v1) order by v1;"; + let df = ctx.sql(query).await.unwrap(); + + let plan = df.create_physical_plan().await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx.state())); + let stream = plan.execute(0, task_ctx).unwrap(); + + let _results = collect(stream).await; + let metrics = plan.metrics().unwrap(); + let spill_count = metrics.spill_count().unwrap(); + assert!(spill_count > 0, "Expected spills but none occurred"); +} + +#[tokio::test] +async fn test_no_spill_with_adequate_memory() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '10M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + ctx.sql("SET datafusion.execution.sort_spill_reservation_bytes = 0") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,100000) as t1(v1) order by v1;"; + let df = ctx.sql(query).await.unwrap(); + + let plan = df.create_physical_plan().await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx.state())); + let stream = plan.execute(0, task_ctx).unwrap(); + + let _results = collect(stream).await; + let metrics = plan.metrics().unwrap(); + let spill_count = metrics.spill_count().unwrap(); + assert_eq!(spill_count, 0, "Expected no spills but some occurred"); +} + +#[tokio::test] +async fn test_multiple_configs() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '100M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + ctx.sql("SET datafusion.execution.batch_size = '2048'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,100000) as t1(v1) order by v1;"; + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!(result.is_ok(), "Should not fail due to memory limit"); + + let state = ctx.state(); + let batch_size = state.config().options().execution.batch_size; + assert_eq!(batch_size, 2048); +} + +#[tokio::test] +async fn test_memory_limit_enforcement() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '1M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,100000) as t1(v1) order by v1;"; + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!(result.is_err(), "Should fail due to memory limit"); + + ctx.sql("SET datafusion.runtime.memory_limit = '100M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!(result.is_ok(), "Should not fail due to memory limit"); +} + +#[tokio::test] +async fn test_invalid_memory_limit() { + let ctx = SessionContext::new(); + + let result = ctx + .sql("SET datafusion.runtime.memory_limit = '100X'") + .await; + + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!(error_message.contains("Unsupported unit 
'X'")); +} + +#[tokio::test] +async fn test_unknown_runtime_config() { + let ctx = SessionContext::new(); + + let result = ctx + .sql("SET datafusion.runtime.unknown_config = 'value'") + .await; + + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!(error_message.contains("Unknown runtime configuration")); +} diff --git a/datafusion/core/tests/tpc-ds/49.sql b/datafusion/core/tests/tpc-ds/49.sql index 090e9746c0d8..219877719f22 100644 --- a/datafusion/core/tests/tpc-ds/49.sql +++ b/datafusion/core/tests/tpc-ds/49.sql @@ -110,7 +110,7 @@ select channel, item, return_ratio, return_rank, currency_rank from where sr.sr_return_amt > 10000 and sts.ss_net_profit > 1 - and sts.ss_net_paid > 0 + and sts.ss_net_paid > 0 and sts.ss_quantity > 0 and ss_sold_date_sk = d_date_sk and d_year = 2000 diff --git a/datafusion/core/tests/tracing/mod.rs b/datafusion/core/tests/tracing/mod.rs index 787dd9f4f3cb..df8a28c021d1 100644 --- a/datafusion/core/tests/tracing/mod.rs +++ b/datafusion/core/tests/tracing/mod.rs @@ -55,9 +55,9 @@ async fn test_tracer_injection() { let untraced_result = SpawnedTask::spawn(run_query()).join().await; if let Err(e) = untraced_result { // Check if the error message contains the expected error. - assert!(e.is_panic(), "Expected a panic, but got: {:?}", e); + assert!(e.is_panic(), "Expected a panic, but got: {e:?}"); assert_contains!(e.to_string(), "Task ID not found in spawn graph"); - info!("Caught expected panic: {}", e); + info!("Caught expected panic: {e}"); } else { panic!("Expected the task to panic, but it completed successfully"); }; @@ -94,7 +94,7 @@ async fn run_query() { ctx.register_object_store(&url, traceable_store.clone()); // Register a listing table from the test data directory. - let table_path = format!("test://{}/", test_data); + let table_path = format!("test://{test_data}/"); ctx.register_listing_table("alltypes", &table_path, listing_options, None, None) .await .expect("Failed to register table"); diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs index 1fc6d14c5b22..07d289cab06c 100644 --- a/datafusion/core/tests/user_defined/expr_planner.rs +++ b/datafusion/core/tests/user_defined/expr_planner.rs @@ -56,7 +56,7 @@ impl ExprPlanner for MyCustomPlanner { } BinaryOperator::Question => { Ok(PlannerResult::Planned(Expr::Alias(Alias::new( - Expr::Literal(ScalarValue::Boolean(Some(true))), + Expr::Literal(ScalarValue::Boolean(Some(true)), None), None::<&str>, format!("{} ? {}", expr.left, expr.right), )))) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 5cbb05f290a7..aa5a72c0fb45 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -18,6 +18,8 @@ //! This module contains end to end demonstrations of creating //! 
user defined aggregate functions +use std::any::Any; +use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::mem::{size_of, size_of_val}; use std::sync::{ @@ -26,10 +28,11 @@ use std::sync::{ }; use arrow::array::{ - types::UInt64Type, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray, + record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray, + StringArray, StructArray, UInt64Array, }; use arrow::datatypes::{Fields, Schema}; - +use arrow_schema::FieldRef; use datafusion::common::test_util::batches_to_string; use datafusion::dataframe::DataFrame; use datafusion::datasource::MemTable; @@ -48,11 +51,12 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; -use datafusion_common::assert_contains; +use datafusion_common::{assert_contains, exec_datafusion_err}; use datafusion_common::{cast::as_primitive_array, exec_err}; +use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, - LogicalPlanBuilder, SimpleAggregateUDF, + col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr, + GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -569,7 +573,7 @@ impl TimeSum { // Returns the same type as its input let return_type = timestamp_type.clone(); - let state_fields = vec![Field::new("sum", timestamp_type, true)]; + let state_fields = vec![Field::new("sum", timestamp_type, true).into()]; let volatility = Volatility::Immutable; @@ -669,7 +673,7 @@ impl FirstSelector { let state_fields = state_type .into_iter() .enumerate() - .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .map(|(i, t)| Field::new(format!("{i}"), t, true).into()) .collect::>(); // Possible input signatures @@ -781,7 +785,7 @@ struct TestGroupsAccumulator { } impl AggregateUDFImpl for TestGroupsAccumulator { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -890,3 +894,264 @@ impl GroupsAccumulator for TestGroupsAccumulator { size_of::() } } + +#[derive(Debug)] +struct MetadataBasedAggregateUdf { + name: String, + signature: Signature, + metadata: HashMap, +} + +impl MetadataBasedAggregateUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. 
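+        // In these tests the two instances are constructed with metadata maps of 0 and
+        // 1 entries, so `metadata.len()` happens to yield distinct names; maps with the
+        // same number of entries would silently collide under this naming scheme.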
+ let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl AggregateUDFImpl for MetadataBasedAggregateUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!("this should never be called since return_field is implemented"); + } + + fn return_field(&self, _arg_fields: &[FieldRef]) -> Result { + Ok(Field::new(self.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone()) + .into()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let input_expr = acc_args + .exprs + .first() + .ok_or(exec_datafusion_err!("Expected one argument"))?; + let input_field = input_expr.return_field(acc_args.schema)?; + + let double_output = input_field + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + + Ok(Box::new(MetadataBasedAccumulator { + double_output, + curr_sum: 0, + })) + } +} + +#[derive(Debug)] +struct MetadataBasedAccumulator { + double_output: bool, + curr_sum: u64, +} + +impl Accumulator for MetadataBasedAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = values[0] + .as_any() + .downcast_ref::() + .ok_or(exec_datafusion_err!("Expected UInt64Array"))?; + + self.curr_sum = arr.iter().fold(self.curr_sum, |a, b| a + b.unwrap_or(0)); + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let v = match self.double_output { + true => self.curr_sum * 2, + false => self.curr_sum, + }; + + Ok(ScalarValue::from(v)) + } + + fn size(&self) -> usize { + 9 + } + + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::from(self.curr_sum)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +#[tokio::test] +async fn test_metadata_based_aggregate() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = + AggregateUDF::from(MetadataBasedAggregateUdf::new(HashMap::new())); + let with_output_meta_udf = AggregateUDF::from(MetadataBasedAggregateUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let df = df.aggregate( + vec![], + vec![ + no_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_with_out"), + ], + )?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); 
+ let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50]), + ("meta_with_in_no_out", UInt64, [100]), + ("meta_no_in_with_out", UInt64, [50]), + ("meta_with_in_with_out", UInt64, [100]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + +#[tokio::test] +async fn test_metadata_based_aggregate_as_window() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = Arc::new(AggregateUDF::from( + MetadataBasedAggregateUdf::new(HashMap::new()), + )); + let with_output_meta_udf = + Arc::new(AggregateUDF::from(MetadataBasedAggregateUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + ))); + + let df = df.select(vec![ + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(Arc::clone(&no_output_meta_udf)), + vec![col("no_metadata")], + )) + .alias("meta_no_in_no_out"), + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(no_output_meta_udf), + vec![col("with_metadata")], + )) + .alias("meta_with_in_no_out"), + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(Arc::clone(&with_output_meta_udf)), + vec![col("no_metadata")], + )) + .alias("meta_no_in_with_out"), + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(with_output_meta_udf), + vec![col("with_metadata")], + )) + .alias("meta_with_in_with_out"), + ])?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]), + ("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100]) + )? 
+ .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index e46940e63154..4d3916c1760e 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -63,15 +63,14 @@ use std::hash::Hash; use std::task::{Context, Poll}; use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; +use arrow::array::{Array, ArrayRef, StringViewArray}; use arrow::{ - array::{Int64Array, StringArray}, - datatypes::SchemaRef, - record_batch::RecordBatch, + array::Int64Array, datatypes::SchemaRef, record_batch::RecordBatch, util::pretty::pretty_format_batches, }; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ - common::cast::{as_int64_array, as_string_array}, + common::cast::as_int64_array, common::{arrow_datafusion_err, internal_err, DFSchemaRef}, error::{DataFusionError, Result}, execution::{ @@ -100,6 +99,7 @@ use datafusion_optimizer::AnalyzerRule; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use async_trait::async_trait; +use datafusion_common::cast::as_string_view_array; use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches @@ -796,22 +796,26 @@ fn accumulate_batch( k: &usize, ) -> BTreeMap { let num_rows = input_batch.num_rows(); + // Assuming the input columns are - // column[0]: customer_id / UTF8 + // column[0]: customer_id UTF8View // column[1]: revenue: Int64 - let customer_id = - as_string_array(input_batch.column(0)).expect("Column 0 is not customer_id"); + let customer_id_column = input_batch.column(0); let revenue = as_int64_array(input_batch.column(1)).unwrap(); for row in 0..num_rows { - add_row( - &mut top_values, - customer_id.value(row), - revenue.value(row), - k, - ); + let customer_id = match customer_id_column.data_type() { + arrow::datatypes::DataType::Utf8View => { + let array = as_string_view_array(customer_id_column).unwrap(); + array.value(row) + } + _ => panic!("Unsupported customer_id type"), + }; + + add_row(&mut top_values, customer_id, revenue.value(row), k); } + top_values } @@ -843,11 +847,19 @@ impl Stream for TopKReader { self.state.iter().rev().unzip(); let customer: Vec<&str> = customer.iter().map(|&s| &**s).collect(); + + let customer_array: ArrayRef = match schema.field(0).data_type() { + arrow::datatypes::DataType::Utf8View => { + Arc::new(StringViewArray::from(customer)) + } + other => panic!("Unsupported customer_id output type: {other:?}"), + }; + Poll::Ready(Some( RecordBatch::try_new( schema, vec![ - Arc::new(StringArray::from(customer)), + Arc::new(customer_array), Arc::new(Int64Array::from(revenue)), ], ) @@ -900,11 +912,12 @@ impl MyAnalyzerRule { .map(|e| { e.transform(|e| { Ok(match e { - Expr::Literal(ScalarValue::Int64(i)) => { + Expr::Literal(ScalarValue::Int64(i), _) => { // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( - i.map(|i| i as u64), - ))) + Transformed::yes(Expr::Literal( + ScalarValue::UInt64(i.map(|i| i as u64)), + None, + )) } _ => Transformed::no(e), }) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 264bd6b66a60..3e8fafc7a636 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ 
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -16,16 +16,19 @@ // under the License. use std::any::Any; +use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; -use arrow::array::as_string_array; +use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array}; use arrow::array::{ builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, }; use arrow::compute::kernels::numeric::add; use arrow::datatypes::{DataType, Field, Schema}; +use arrow_schema::extension::{Bool8, CanonicalExtensionType, ExtensionType}; +use arrow_schema::{ArrowError, FieldRef}; use datafusion::common::test_util::batches_to_string; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; @@ -35,13 +38,13 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::utils::take_function_args; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, not_impl_err, - plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue, + plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, - OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, + lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, + LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -57,7 +60,7 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> { let ctx = create_udf_context(); register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; - let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let actual = plan_and_collect(&ctx, sql).await?; insta::assert_snapshot!(batches_to_string(&actual), @r###" +------------------------------------------+ @@ -76,7 +79,7 @@ async fn csv_query_avg_sqrt() -> Result<()> { register_aggregate_csv(&ctx).await?; // Note it is a different column (c12) than above (c11) let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; - let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let actual = plan_and_collect(&ctx, sql).await?; insta::assert_snapshot!(batches_to_string(&actual), @r###" +------------------------------------------+ @@ -389,7 +392,7 @@ async fn udaf_as_window_func() -> Result<()> { WindowAggr: windowExpr=[[my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] TableScan: my_table"#; - let dataframe = context.sql(sql).await.unwrap(); + let dataframe = context.sql(sql).await?; assert_eq!(format!("{}", dataframe.logical_plan()), expected); Ok(()) } @@ -399,7 +402,7 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch)?; let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { @@ -443,7 +446,7 @@ async fn 
test_user_defined_functions_with_alias() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch)?; let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { @@ -803,7 +806,7 @@ impl ScalarUDFImpl for TakeUDF { &self.signature } fn return_type(&self, _args: &[DataType]) -> Result { - not_impl_err!("Not called because the return_type_from_args is implemented") + not_impl_err!("Not called because the return_field_from_args is implemented") } /// This function returns the type of the first or second argument based on @@ -811,9 +814,9 @@ impl ScalarUDFImpl for TakeUDF { /// /// 1. If the third argument is '0', return the type of the first argument /// 2. If the third argument is '1', return the type of the second argument - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if args.arg_types.len() != 3 { - return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len()); + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() != 3 { + return plan_err!("Expected 3 arguments, got {}.", args.arg_fields.len()); } let take_idx = if let Some(take_idx) = args.scalar_arguments.get(2) { @@ -838,9 +841,12 @@ impl ScalarUDFImpl for TakeUDF { ); }; - Ok(ReturnInfo::new_nullable( - args.arg_types[take_idx].to_owned(), - )) + Ok(Field::new( + self.name(), + args.arg_fields[take_idx].data_type().to_owned(), + true, + ) + .into()) } // The actual implementation @@ -1004,8 +1010,7 @@ impl ScalarFunctionWrapper { if let Some(value) = placeholder.strip_prefix('$') { Ok(value.parse().map(|v: usize| v - 1).map_err(|e| { DataFusionError::Execution(format!( - "Placeholder `{}` parsing error: {}!", - placeholder, e + "Placeholder `{placeholder}` parsing error: {e}!" )) })?) } else { @@ -1160,7 +1165,7 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<( match ctx.sql(sql).await { Ok(_) => {} Err(e) => { - panic!("Error creating function: {}", e); + panic!("Error creating function: {e}"); } } @@ -1367,3 +1372,403 @@ async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } + +#[derive(Debug)] +struct MetadataBasedUdf { + name: String, + signature: Signature, + metadata: HashMap, +} + +impl MetadataBasedUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. 
+ let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl ScalarUDFImpl for MetadataBasedUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + unimplemented!( + "this should never be called since return_field_from_args is implemented" + ); + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new(self.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone()) + .into()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert_eq!(args.arg_fields.len(), 1); + let should_double = args.arg_fields[0] + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + let mulitplier = if should_double { 2 } else { 1 }; + + match &args.args[0] { + ColumnarValue::Array(array) => { + let array_values: Vec<_> = array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.map(|x| x * mulitplier)) + .collect(); + let array_ref = Arc::new(UInt64Array::from(array_values)) as ArrayRef; + Ok(ColumnarValue::Array(array_ref)) + } + ColumnarValue::Scalar(value) => { + let ScalarValue::UInt64(value) = value else { + return exec_err!("incorrect data type"); + }; + + Ok(ColumnarValue::Scalar(ScalarValue::UInt64( + value.map(|v| v * mulitplier), + ))) + } + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name == other.name() + } +} + +#[tokio::test] +async fn test_metadata_based_udf() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let no_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new(HashMap::new())); + let with_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .project(vec![ + no_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_with_out"), + ])? 
+ .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_no_out", UInt64, [0, 10, 20, 30, 40]), + ("meta_no_in_with_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_with_out", UInt64, [0, 10, 20, 30, 40]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + ctx.deregister_table("t")?; + Ok(()) +} + +#[tokio::test] +async fn test_metadata_based_udf_with_literal() -> Result<()> { + let ctx = SessionContext::new(); + let input_metadata: HashMap = + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(); + let df = ctx.sql("select 0;").await?.select(vec![ + lit(5u64).alias_with_metadata("lit_with_doubling", Some(input_metadata.clone())), + lit(5u64).alias("lit_no_doubling"), + lit_with_metadata(5u64, Some(input_metadata)) + .alias("lit_with_double_no_alias_metadata"), + ])?; + + let output_metadata: HashMap = + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(); + let custom_udf = ScalarUDF::from(MetadataBasedUdf::new(output_metadata.clone())); + + let plan = LogicalPlanBuilder::from(df.into_optimized_plan()?) + .project(vec![ + custom_udf + .call(vec![col("lit_with_doubling")]) + .alias("doubled_output"), + custom_udf + .call(vec![col("lit_no_doubling")]) + .alias("not_doubled_output"), + custom_udf + .call(vec![col("lit_with_double_no_alias_metadata")]) + .alias("double_without_alias_metadata"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + let schema = Arc::new(Schema::new(vec![ + Field::new("doubled_output", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + Field::new("not_doubled_output", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + Field::new("double_without_alias_metadata", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + ])); + + let expected = RecordBatch::try_new( + schema, + vec![ + create_array!(UInt64, [10]), + create_array!(UInt64, [5]), + create_array!(UInt64, [10]), + ], + )?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + +/// This UDF is to test extension handling, both on the input and output +/// sides. For the input, we will handle the data differently if there is +/// the canonical extension type Bool8. For the output we will add a +/// user defined extension type. 
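+/// Concretely: when the input field carries the canonical `Bool8` extension type,
+/// the `Int8` values are printed as "true"/"false" (any non-zero value counts as
+/// true); otherwise they are printed as plain numbers. The output field is always
+/// tagged with the user-defined `MyUserExtentionType`.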
+#[derive(Debug)] +struct ExtensionBasedUdf { + name: String, + signature: Signature, +} + +impl Default for ExtensionBasedUdf { + fn default() -> Self { + Self { + name: "canonical_extension_udf".to_string(), + signature: Signature::exact(vec![DataType::Int8], Volatility::Immutable), + } + } +} +impl ScalarUDFImpl for ExtensionBasedUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new("canonical_extension_udf", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {}) + .into()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert_eq!(args.arg_fields.len(), 1); + let input_field = args.arg_fields[0].as_ref(); + + let output_as_bool = matches!( + CanonicalExtensionType::try_from(input_field), + Ok(CanonicalExtensionType::Bool8(_)) + ); + + // If we have the extension type set, we are outputting a boolean value. + // Otherwise we output a string representation of the numeric value. + fn print_value(v: Option, as_bool: bool) -> Option { + v.map(|x| match as_bool { + true => format!("{}", x != 0), + false => format!("{x}"), + }) + } + + match &args.args[0] { + ColumnarValue::Array(array) => { + let array_values: Vec<_> = array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| print_value(v, output_as_bool)) + .collect(); + let array_ref = Arc::new(StringArray::from(array_values)) as ArrayRef; + Ok(ColumnarValue::Array(array_ref)) + } + ColumnarValue::Scalar(value) => { + let ScalarValue::Int8(value) = value else { + return exec_err!("incorrect data type"); + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(print_value( + *value, + output_as_bool, + )))) + } + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name == other.name() + } +} + +struct MyUserExtentionType {} + +impl ExtensionType for MyUserExtentionType { + const NAME: &'static str = "my_user_extention_type"; + type Metadata = (); + + fn metadata(&self) -> &Self::Metadata { + &() + } + + fn serialize_metadata(&self) -> Option { + None + } + + fn deserialize_metadata( + _metadata: Option<&str>, + ) -> std::result::Result { + Ok(()) + } + + fn supports_data_type( + &self, + data_type: &DataType, + ) -> std::result::Result<(), ArrowError> { + if let DataType::Utf8 = data_type { + Ok(()) + } else { + Err(ArrowError::InvalidArgumentError( + "only utf8 supported".to_string(), + )) + } + } + + fn try_new( + _data_type: &DataType, + _metadata: Self::Metadata, + ) -> std::result::Result { + Ok(Self {}) + } +} + +#[tokio::test] +async fn test_extension_based_udf() -> Result<()> { + let data_array = Arc::new(Int8Array::from(vec![0, 0, 10, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_extension", DataType::Int8, true), + Field::new("with_extension", DataType::Int8, true).with_extension_type(Bool8), + ])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let extension_based_udf = ScalarUDF::from(ExtensionBasedUdf::default()); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) 
+ .project(vec![ + extension_based_udf + .call(vec![col("no_extension")]) + .alias("without_bool8_extension"), + extension_based_udf + .call(vec![col("with_extension")]) + .alias("with_bool8_extension"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + // To test for output extension handling, we set the expected values on the result + // To test for input extensions handling, we check the strings returned + let expected_schema = Schema::new(vec![ + Field::new("without_bool8_extension", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {}), + Field::new("with_bool8_extension", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {}), + ]); + + let expected = record_batch!( + ("without_bool8_extension", Utf8, ["0", "0", "10", "20"]), + ( + "with_bool8_extension", + Utf8, + ["false", "false", "true", "true"] + ) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + ctx.deregister_table("t")?; + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index e4aff0b00705..2c6611f382ce 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -205,7 +205,7 @@ impl TableFunctionImpl for SimpleCsvTableFunc { let mut filepath = String::new(); for expr in exprs { match expr { - Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + Expr::Literal(ScalarValue::Utf8(Some(ref path)), _) => { filepath.clone_from(path); } expr => new_exprs.push(expr.clone()), diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 7c56507acd45..bcd2c3945e39 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -18,11 +18,16 @@ //! This module contains end to end tests of creating //! 
user defined window functions -use arrow::array::{ArrayRef, AsArray, Int64Array, RecordBatch, StringArray}; +use arrow::array::{ + record_batch, Array, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray, + UInt64Array, +}; use arrow::datatypes::{DataType, Field, Schema}; +use arrow_schema::FieldRef; use datafusion::common::test_util::batches_to_string; use datafusion::common::{Result, ScalarValue}; use datafusion::prelude::SessionContext; +use datafusion_common::exec_datafusion_err; use datafusion_expr::{ PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl, }; @@ -34,6 +39,7 @@ use datafusion_physical_expr::{ expressions::{col, lit}, PhysicalExpr, }; +use std::collections::HashMap; use std::{ any::Any, ops::Range, @@ -559,8 +565,8 @@ impl OddCounter { &self.aliases } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Int64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Int64, true).into()) } } @@ -678,7 +684,7 @@ impl WindowUDFImpl for VariadicWindowUDF { unimplemented!("unnecessary for testing"); } - fn field(&self, _: WindowUDFFieldArgs) -> Result { + fn field(&self, _: WindowUDFFieldArgs) -> Result { unimplemented!("unnecessary for testing"); } } @@ -723,11 +729,11 @@ fn test_default_expressions() -> Result<()> { ]; for input_exprs in &test_cases { - let input_types = input_exprs + let input_fields = input_exprs .iter() - .map(|expr: &Arc| expr.data_type(&schema).unwrap()) + .map(|expr: &Arc| expr.return_field(&schema).unwrap()) .collect::>(); - let expr_args = ExpressionArgs::new(input_exprs, &input_types); + let expr_args = ExpressionArgs::new(input_exprs, &input_fields); let ret_exprs = udwf.expressions(expr_args); @@ -735,9 +741,7 @@ fn test_default_expressions() -> Result<()> { assert_eq!( input_exprs.len(), ret_exprs.len(), - "\nInput expressions: {:?}\nReturned expressions: {:?}", - input_exprs, - ret_exprs + "\nInput expressions: {input_exprs:?}\nReturned expressions: {ret_exprs:?}" ); // Compares each returned expression with original input expressions @@ -753,3 +757,149 @@ fn test_default_expressions() -> Result<()> { } Ok(()) } + +#[derive(Debug)] +struct MetadataBasedWindowUdf { + name: String, + signature: Signature, + metadata: HashMap, +} + +impl MetadataBasedWindowUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. 
+ let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl WindowUDFImpl for MetadataBasedWindowUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let input_field = partition_evaluator_args + .input_fields() + .first() + .ok_or(exec_datafusion_err!("Expected one argument"))?; + + let double_output = input_field + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + + Ok(Box::new(MetadataBasedPartitionEvaluator { double_output })) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone()) + .into()) + } +} + +#[derive(Debug)] +struct MetadataBasedPartitionEvaluator { + double_output: bool, +} + +impl PartitionEvaluator for MetadataBasedPartitionEvaluator { + fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result { + let values = values[0].as_any().downcast_ref::().unwrap(); + let sum = values.iter().fold(0_u64, |acc, v| acc + v.unwrap_or(0)); + + let result = if self.double_output { sum * 2 } else { sum }; + + Ok(Arc::new(UInt64Array::from_value(result, num_rows))) + } +} + +#[tokio::test] +async fn test_metadata_based_window_fn() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = WindowUDF::from(MetadataBasedWindowUdf::new(HashMap::new())); + let with_output_meta_udf = WindowUDF::from(MetadataBasedWindowUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let df = df.select(vec![ + no_output_meta_udf + .call(vec![datafusion_expr::col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![datafusion_expr::col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![datafusion_expr::col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![datafusion_expr::col("with_metadata")]) + .alias("meta_with_in_with_out"), + ])?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + 
.with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]), + ("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} diff --git a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs index 9a1b54b872ad..36553b36bc6c 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs @@ -21,7 +21,7 @@ use apache_avro::schema::RecordSchema; use apache_avro::{ schema::{Schema as AvroSchema, SchemaKind}, types::Value, - AvroResult, Error as AvroError, Reader as AvroReader, + Error as AvroError, Reader as AvroReader, }; use arrow::array::{ make_array, Array, ArrayBuilder, ArrayData, ArrayDataBuilder, ArrayRef, @@ -33,7 +33,7 @@ use arrow::buffer::{Buffer, MutableBuffer}; use arrow::datatypes::{ ArrowDictionaryKeyType, ArrowNumericType, ArrowPrimitiveType, DataType, Date32Type, Date64Type, Field, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, @@ -56,23 +56,17 @@ type RecordSlice<'a> = &'a [&'a Vec<(String, Value)>]; pub struct AvroArrowArrayReader<'a, R: Read> { reader: AvroReader<'a, R>, schema: SchemaRef, - projection: Option>, schema_lookup: BTreeMap, } impl AvroArrowArrayReader<'_, R> { - pub fn try_new( - reader: R, - schema: SchemaRef, - projection: Option>, - ) -> Result { + pub fn try_new(reader: R, schema: SchemaRef) -> Result { let reader = AvroReader::new(reader)?; let writer_schema = reader.writer_schema().clone(); let schema_lookup = Self::schema_lookup(writer_schema)?; Ok(Self { reader, schema, - projection, schema_lookup, }) } @@ -123,7 +117,7 @@ impl AvroArrowArrayReader<'_, R> { AvroSchema::Record(RecordSchema { fields, lookup, .. 
}) => { lookup.iter().for_each(|(field_name, pos)| { schema_lookup - .insert(format!("{}.{}", parent_field_name, field_name), *pos); + .insert(format!("{parent_field_name}.{field_name}"), *pos); }); for field in fields { @@ -137,7 +131,7 @@ impl AvroArrowArrayReader<'_, R> { } } AvroSchema::Array(schema) => { - let sub_parent_field_name = format!("{}.element", parent_field_name); + let sub_parent_field_name = format!("{parent_field_name}.element"); Self::child_schema_lookup( &sub_parent_field_name, &schema.items, @@ -175,20 +169,9 @@ impl AvroArrowArrayReader<'_, R> { }; let rows = rows.iter().collect::>>(); - let projection = self.projection.clone().unwrap_or_default(); - let arrays = - self.build_struct_array(&rows, "", self.schema.fields(), &projection); - let projected_fields = if projection.is_empty() { - self.schema.fields().clone() - } else { - projection - .iter() - .filter_map(|name| self.schema.column_with_name(name)) - .map(|(_, field)| field.clone()) - .collect() - }; - let projected_schema = Arc::new(Schema::new(projected_fields)); - Some(arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr))) + let arrays = self.build_struct_array(&rows, "", self.schema.fields()); + + Some(arrays.and_then(|arr| RecordBatch::try_new(Arc::clone(&self.schema), arr))) } fn build_boolean_array(&self, rows: RecordSlice, col_name: &str) -> ArrayRef { @@ -615,7 +598,7 @@ impl AvroArrowArrayReader<'_, R> { let sub_parent_field_name = format!("{}.{}", parent_field_name, list_field.name()); let arrays = - self.build_struct_array(&rows, &sub_parent_field_name, fields, &[])?; + self.build_struct_array(&rows, &sub_parent_field_name, fields)?; let data_type = DataType::Struct(fields.clone()); ArrayDataBuilder::new(data_type) .len(rows.len()) @@ -645,20 +628,14 @@ impl AvroArrowArrayReader<'_, R> { /// The function does not construct the StructArray as some callers would want the child arrays. /// /// *Note*: The function is recursive, and will read nested structs. - /// - /// If `projection` is not empty, then all values are returned. The first level of projection - /// occurs at the `RecordBatch` level. No further projection currently occurs, but would be - /// useful if plucking values from a struct, e.g. getting `a.b.c.e` from `a.b.c.{d, e}`. 
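+    ///
+    /// Projection is no longer handled here; the schema handed to
+    /// `AvroArrowArrayReader` is already projected up front by the outer `Reader::try_new`.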
fn build_struct_array( &self, rows: RecordSlice, parent_field_name: &str, struct_fields: &Fields, - projection: &[String], ) -> ArrowResult> { let arrays: ArrowResult> = struct_fields .iter() - .filter(|field| projection.is_empty() || projection.contains(field.name())) .map(|field| { let field_path = if parent_field_name.is_empty() { field.name().to_string() @@ -840,12 +817,8 @@ impl AvroArrowArrayReader<'_, R> { } }) .collect::>>(); - let arrays = self.build_struct_array( - &struct_rows, - &field_path, - fields, - &[], - )?; + let arrays = + self.build_struct_array(&struct_rows, &field_path, fields)?; // construct a struct array's data in order to set null buffer let data_type = DataType::Struct(fields.clone()); let data = ArrayDataBuilder::new(data_type) @@ -965,40 +938,31 @@ fn resolve_string(v: &Value) -> ArrowResult> { .map_err(|e| SchemaError(format!("expected resolvable string : {e:?}"))) } -fn resolve_u8(v: &Value) -> AvroResult { - let int = match v { - Value::Int(n) => Ok(Value::Int(*n)), - Value::Long(n) => Ok(Value::Int(*n as i32)), - other => Err(AvroError::GetU8(other.into())), - }?; - if let Value::Int(n) = int { - if n >= 0 && n <= From::from(u8::MAX) { - return Ok(n as u8); - } - } +fn resolve_u8(v: &Value) -> Option { + let v = match v { + Value::Union(_, inner) => inner.as_ref(), + _ => v, + }; - Err(AvroError::GetU8(int.into())) + match v { + Value::Int(n) => u8::try_from(*n).ok(), + Value::Long(n) => u8::try_from(*n).ok(), + _ => None, + } } fn resolve_bytes(v: &Value) -> Option> { - let v = if let Value::Union(_, b) = v { b } else { v }; + let v = match v { + Value::Union(_, inner) => inner.as_ref(), + _ => v, + }; + match v { - Value::Bytes(_) => Ok(v.clone()), - Value::String(s) => Ok(Value::Bytes(s.clone().into_bytes())), - Value::Array(items) => Ok(Value::Bytes( - items - .iter() - .map(resolve_u8) - .collect::, _>>() - .ok()?, - )), - other => Err(AvroError::GetBytes(other.into())), - } - .ok() - .and_then(|v| match v { - Value::Bytes(s) => Some(s), + Value::Bytes(bytes) => Some(bytes.clone()), + Value::String(s) => Some(s.as_bytes().to_vec()), + Value::Array(items) => items.iter().map(resolve_u8).collect::>>(), _ => None, - }) + } } fn resolve_fixed(v: &Value, size: usize) -> Option> { diff --git a/datafusion/datasource-avro/src/avro_to_arrow/reader.rs b/datafusion/datasource-avro/src/avro_to_arrow/reader.rs index bc7b50a9cdc3..7f5900605a06 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/reader.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/reader.rs @@ -16,7 +16,7 @@ // under the License. use super::arrow_array_reader::AvroArrowArrayReader; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Fields, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use datafusion_common::Result; @@ -133,19 +133,35 @@ impl Reader<'_, R> { /// /// If reading a `File`, you can customise the Reader, such as to enable schema /// inference, use `ReaderBuilder`. + /// + /// If projection is provided, it uses a schema with only the fields in the projection, respecting their order. + /// Only the first level of projection is handled. No further projection currently occurs, but would be + /// useful if plucking values from a struct, e.g. getting `a.b.c.e` from `a.b.c.{d, e}`. 
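+    ///
+    /// For example, with a file schema of `[a, b, c]`, a projection of `["c", "a"]`
+    /// yields batches with schema `[c, a]`; projected names that are not present in
+    /// the schema are skipped rather than raising an error.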
pub fn try_new( reader: R, schema: SchemaRef, batch_size: usize, projection: Option>, ) -> Result { + let projected_schema = projection.as_ref().filter(|p| !p.is_empty()).map_or_else( + || Arc::clone(&schema), + |proj| { + Arc::new(arrow::datatypes::Schema::new( + proj.iter() + .filter_map(|name| { + schema.column_with_name(name).map(|(_, f)| f.clone()) + }) + .collect::(), + )) + }, + ); + Ok(Self { array_reader: AvroArrowArrayReader::try_new( reader, - Arc::clone(&schema), - projection, + Arc::clone(&projected_schema), )?, - schema, + schema: projected_schema, batch_size, }) } @@ -179,10 +195,13 @@ mod tests { use arrow::datatypes::{DataType, Field}; use std::fs::File; - fn build_reader(name: &str) -> Reader { + fn build_reader(name: &str, projection: Option>) -> Reader { let testdata = datafusion_common::test_util::arrow_test_data(); let filename = format!("{testdata}/avro/{name}"); - let builder = ReaderBuilder::new().read_schema().with_batch_size(64); + let mut builder = ReaderBuilder::new().read_schema().with_batch_size(64); + if let Some(projection) = projection { + builder = builder.with_projection(projection); + } builder.build(File::open(filename).unwrap()).unwrap() } @@ -195,7 +214,7 @@ mod tests { #[test] fn test_avro_basic() { - let mut reader = build_reader("alltypes_dictionary.avro"); + let mut reader = build_reader("alltypes_dictionary.avro", None); let batch = reader.next().unwrap().unwrap(); assert_eq!(11, batch.num_columns()); @@ -281,4 +300,58 @@ mod tests { assert_eq!(1230768000000000, col.value(0)); assert_eq!(1230768060000000, col.value(1)); } + + #[test] + fn test_avro_with_projection() { + // Test projection to filter and reorder columns + let projection = Some(vec![ + "string_col".to_string(), + "double_col".to_string(), + "bool_col".to_string(), + ]); + let mut reader = build_reader("alltypes_dictionary.avro", projection); + let batch = reader.next().unwrap().unwrap(); + + // Only 3 columns should be present (not all 11) + assert_eq!(3, batch.num_columns()); + assert_eq!(2, batch.num_rows()); + + let schema = reader.schema(); + let batch_schema = batch.schema(); + assert_eq!(schema, batch_schema); + + // Verify columns are in the order specified in projection + // First column should be string_col (was at index 9 in original) + assert_eq!("string_col", schema.field(0).name()); + assert_eq!(&DataType::Binary, schema.field(0).data_type()); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!("0".as_bytes(), col.value(0)); + assert_eq!("1".as_bytes(), col.value(1)); + + // Second column should be double_col (was at index 7 in original) + assert_eq!("double_col", schema.field(1).name()); + assert_eq!(&DataType::Float64, schema.field(1).data_type()); + let col = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(0.0, col.value(0)); + assert_eq!(10.1, col.value(1)); + + // Third column should be bool_col (was at index 1 in original) + assert_eq!("bool_col", schema.field(2).name()); + assert_eq!(&DataType::Boolean, schema.field(2).data_type()); + let col = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(col.value(0)); + assert!(!col.value(1)); + } } diff --git a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs index 276056c24c01..f53d38e51d1f 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs @@ -22,7 +22,7 @@ use apache_avro::types::Value; use 
apache_avro::Schema as AvroSchema; use arrow::datatypes::{DataType, IntervalUnit, Schema, TimeUnit, UnionMode}; use arrow::datatypes::{Field, UnionFields}; -use datafusion_common::error::{DataFusionError, Result}; +use datafusion_common::error::Result; use std::collections::HashMap; use std::sync::Arc; @@ -107,9 +107,7 @@ fn schema_to_field_with_props( .data_type() .clone() } else { - return Err(DataFusionError::AvroError( - apache_avro::Error::GetUnionDuplicate, - )); + return Err(apache_avro::Error::GetUnionDuplicate.into()); } } else { let fields = sub_schemas diff --git a/datafusion/datasource-avro/src/file_format.rs b/datafusion/datasource-avro/src/file_format.rs index 4b50fee1d326..700b7058b477 100644 --- a/datafusion/datasource-avro/src/file_format.rs +++ b/datafusion/datasource-avro/src/file_format.rs @@ -37,12 +37,12 @@ use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_format::{FileFormat, FileFormatFactory}; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::source::DataSourceExec; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; use async_trait::async_trait; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; #[derive(Default)] /// Factory struct used to create [`AvroFormat`] diff --git a/datafusion/datasource-avro/src/source.rs b/datafusion/datasource-avro/src/source.rs index ce3722e7b11e..3254f48bab39 100644 --- a/datafusion/datasource-avro/src/source.rs +++ b/datafusion/datasource-avro/src/source.rs @@ -18,142 +18,22 @@ //! Execution plan for reading line-delimited Avro files use std::any::Any; -use std::fmt::Formatter; use std::sync::Arc; use crate::avro_to_arrow::Reader as AvroReader; -use datafusion_common::error::Result; - use arrow::datatypes::SchemaRef; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::error::Result; +use datafusion_common::Statistics; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_datasource::file_stream::FileOpener; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, -}; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use object_store::ObjectStore; -/// Execution plan for scanning Avro data source -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct AvroExec { - inner: DataSourceExec, - base_config: FileScanConfig, -} - -#[allow(unused, deprecated)] -impl AvroExec { - /// Create a new Avro reader execution plan provided base configurations - pub fn new(base_config: FileScanConfig) -> Self { - let ( - projected_schema, - projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = Self::compute_properties( - Arc::clone(&projected_schema), - 
&projected_output_ordering, - projected_constraints, - &base_config, - ); - let base_config = base_config.with_source(Arc::new(AvroSource::default())); - Self { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - } - } - - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints); - let n_partitions = file_scan_config.file_groups.len(); - - PlanProperties::new( - eq_properties, - Partitioning::UnknownPartitioning(n_partitions), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for AvroExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for AvroExec { - fn name(&self) -> &'static str { - "AvroExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - fn children(&self) -> Vec<&Arc> { - Vec::new() - } - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn metrics(&self) -> Option { - self.inner.metrics() - } - - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - /// AvroSource holds the extra configuration that is necessary for opening avro files #[derive(Clone, Default)] pub struct AvroSource { @@ -162,6 +42,7 @@ pub struct AvroSource { projection: Option>, metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, } impl AvroSource { @@ -244,6 +125,20 @@ impl FileSource for AvroSource { ) -> Result> { Ok(None) } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } mod private { diff --git a/datafusion/datasource-csv/src/file_format.rs b/datafusion/datasource-csv/src/file_format.rs index 76f3c50a70a7..1abcadc1d414 100644 --- a/datafusion/datasource-csv/src/file_format.rs +++ b/datafusion/datasource-csv/src/file_format.rs @@ -50,7 +50,6 @@ use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join; use datafusion_datasource::write::BatchSerializer; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; @@ -62,6 +61,7 @@ use futures::stream::BoxStream; use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; use 
object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; use regex::Regex; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; #[derive(Default)] /// Factory used to create [`CsvFormat`] @@ -414,11 +414,11 @@ impl FileFormat for CsvFormat { let has_header = self .options .has_header - .unwrap_or(state.config_options().catalog.has_header); + .unwrap_or_else(|| state.config_options().catalog.has_header); let newlines_in_values = self .options .newlines_in_values - .unwrap_or(state.config_options().catalog.newlines_in_values); + .unwrap_or_else(|| state.config_options().catalog.newlines_in_values); let conf_builder = FileScanConfigBuilder::from(conf) .with_file_compression_type(self.options.compression.into()) @@ -454,11 +454,11 @@ impl FileFormat for CsvFormat { let has_header = self .options() .has_header - .unwrap_or(state.config_options().catalog.has_header); + .unwrap_or_else(|| state.config_options().catalog.has_header); let newlines_in_values = self .options() .newlines_in_values - .unwrap_or(state.config_options().catalog.newlines_in_values); + .unwrap_or_else(|| state.config_options().catalog.newlines_in_values); let options = self .options() @@ -504,7 +504,7 @@ impl CsvFormat { && self .options .has_header - .unwrap_or(state.config_options().catalog.has_header), + .unwrap_or_else(|| state.config_options().catalog.has_header), ) .with_delimiter(self.options.delimiter) .with_quote(self.options.quote); diff --git a/datafusion/datasource-csv/src/source.rs b/datafusion/datasource-csv/src/source.rs index f5d45cd3fc88..3af1f2b345ba 100644 --- a/datafusion/datasource-csv/src/source.rs +++ b/datafusion/datasource-csv/src/source.rs @@ -17,6 +17,7 @@ //! Execution plan for reading CSV files +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use std::any::Any; use std::fmt; use std::io::{Read, Seek, SeekFrom}; @@ -28,379 +29,27 @@ use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_meta::FileMeta; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; use datafusion_datasource::{ - calculate_range, FileRange, ListingTableUrl, RangeCalculation, + as_file_source, calculate_range, FileRange, ListingTableUrl, RangeCalculation, }; use arrow::csv; use arrow::datatypes::SchemaRef; -use datafusion_common::config::ConfigOptions; -use datafusion_common::{Constraints, DataFusionError, Result, Statistics}; +use datafusion_common::{DataFusionError, Result, Statistics}; use datafusion_common_runtime::JoinSet; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, }; use crate::file_format::CsvDecoder; -use datafusion_datasource::file_groups::FileGroup; use futures::{StreamExt, TryStreamExt}; use 
object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; -/// Old Csv source, deprecated with DataSourceExec implementation and CsvSource -/// -/// See examples on `CsvSource` -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct CsvExec { - base_config: FileScanConfig, - inner: DataSourceExec, -} - -/// Builder for [`CsvExec`]. -/// -/// See example on [`CsvExec`]. -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use FileScanConfig instead")] -pub struct CsvExecBuilder { - file_scan_config: FileScanConfig, - file_compression_type: FileCompressionType, - // TODO: it seems like these format options could be reused across all the various CSV config - has_header: bool, - delimiter: u8, - quote: u8, - terminator: Option, - escape: Option, - comment: Option, - newlines_in_values: bool, -} - -#[allow(unused, deprecated)] -impl CsvExecBuilder { - /// Create a new builder to read the provided file scan configuration. - pub fn new(file_scan_config: FileScanConfig) -> Self { - Self { - file_scan_config, - // TODO: these defaults are duplicated from `CsvOptions` - should they be computed? - has_header: false, - delimiter: b',', - quote: b'"', - terminator: None, - escape: None, - comment: None, - newlines_in_values: false, - file_compression_type: FileCompressionType::UNCOMPRESSED, - } - } - - /// Set whether the first row defines the column names. - /// - /// The default value is `false`. - pub fn with_has_header(mut self, has_header: bool) -> Self { - self.has_header = has_header; - self - } - - /// Set the column delimeter. - /// - /// The default is `,`. - pub fn with_delimeter(mut self, delimiter: u8) -> Self { - self.delimiter = delimiter; - self - } - - /// Set the quote character. - /// - /// The default is `"`. - pub fn with_quote(mut self, quote: u8) -> Self { - self.quote = quote; - self - } - - /// Set the line terminator. If not set, the default is CRLF. - /// - /// The default is None. - pub fn with_terminator(mut self, terminator: Option) -> Self { - self.terminator = terminator; - self - } - - /// Set the escape character. - /// - /// The default is `None` (i.e. quotes cannot be escaped). - pub fn with_escape(mut self, escape: Option) -> Self { - self.escape = escape; - self - } - - /// Set the comment character. - /// - /// The default is `None` (i.e. comments are not supported). - pub fn with_comment(mut self, comment: Option) -> Self { - self.comment = comment; - self - } - - /// Set whether newlines in (quoted) values are supported. - /// - /// Parsing newlines in quoted values may be affected by execution behaviour such as - /// parallel file scanning. Setting this to `true` ensures that newlines in values are - /// parsed successfully, which may reduce performance. - /// - /// The default value is `false`. - pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self { - self.newlines_in_values = newlines_in_values; - self - } - - /// Set the file compression type. - /// - /// The default is [`FileCompressionType::UNCOMPRESSED`]. - pub fn with_file_compression_type( - mut self, - file_compression_type: FileCompressionType, - ) -> Self { - self.file_compression_type = file_compression_type; - self - } - - /// Build a [`CsvExec`]. 
- #[must_use] - pub fn build(self) -> CsvExec { - let Self { - file_scan_config: base_config, - file_compression_type, - has_header, - delimiter, - quote, - terminator, - escape, - comment, - newlines_in_values, - } = self; - - let ( - projected_schema, - projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = CsvExec::compute_properties( - projected_schema, - &projected_output_ordering, - projected_constraints, - &base_config, - ); - let csv = CsvSource::new(has_header, delimiter, quote) - .with_comment(comment) - .with_escape(escape) - .with_terminator(terminator); - let base_config = base_config - .with_newlines_in_values(newlines_in_values) - .with_file_compression_type(file_compression_type) - .with_source(Arc::new(csv)); - - CsvExec { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - } - } -} - -#[allow(unused, deprecated)] -impl CsvExec { - /// Create a new CSV reader execution plan provided base and specific configurations - #[allow(clippy::too_many_arguments)] - pub fn new( - base_config: FileScanConfig, - has_header: bool, - delimiter: u8, - quote: u8, - terminator: Option, - escape: Option, - comment: Option, - newlines_in_values: bool, - file_compression_type: FileCompressionType, - ) -> Self { - CsvExecBuilder::new(base_config) - .with_has_header(has_header) - .with_delimeter(delimiter) - .with_quote(quote) - .with_terminator(terminator) - .with_escape(escape) - .with_comment(comment) - .with_newlines_in_values(newlines_in_values) - .with_file_compression_type(file_compression_type) - .build() - } - - /// Return a [`CsvExecBuilder`]. - /// - /// See example on [`CsvExec`] and [`CsvExecBuilder`] for specifying CSV table options. - pub fn builder(file_scan_config: FileScanConfig) -> CsvExecBuilder { - CsvExecBuilder::new(file_scan_config) - } - - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn csv_source(&self) -> CsvSource { - let source = self.file_scan_config(); - source - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - /// true if the first line of each file is a header - pub fn has_header(&self) -> bool { - self.csv_source().has_header() - } - - /// Specifies whether newlines in (quoted) values are supported. - /// - /// Parsing newlines in quoted values may be affected by execution behaviour such as - /// parallel file scanning. Setting this to `true` ensures that newlines in values are - /// parsed successfully, which may reduce performance. - /// - /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. - pub fn newlines_in_values(&self) -> bool { - let source = self.file_scan_config(); - source.newlines_in_values() - } - - fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
- fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints); - - PlanProperties::new( - eq_properties, - Self::output_partitioning_helper(file_scan_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.base_config.file_groups = file_groups.clone(); - let mut file_source = self.file_scan_config(); - file_source = file_source.with_file_groups(file_groups); - self.inner = self.inner.with_data_source(Arc::new(file_source)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for CsvExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for CsvExec { - fn name(&self) -> &'static str { - "CsvExec" - } - - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - // this is a leaf node and has no children - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - /// Redistribute files across partitions according to their size - /// See comments on `FileGroupPartitioner` for more detail. - /// - /// Return `None` if can't get repartitioned (empty, compressed file, or `newlines_in_values` set). - fn repartitioned( - &self, - target_partitions: usize, - config: &ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn metrics(&self) -> Option { - self.inner.metrics() - } - - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } - - fn try_swapping_with_projection( - &self, - projection: &ProjectionExec, - ) -> Result>> { - self.inner.try_swapping_with_projection(projection) - } -} - /// A Config for [`CsvOpener`] /// /// # Example: create a `DataSourceExec` for CSV @@ -443,6 +92,7 @@ pub struct CsvSource { comment: Option, metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, } impl CsvSource { @@ -564,6 +214,12 @@ impl CsvOpener { } } +impl From for Arc { + fn from(source: CsvSource) -> Self { + as_file_source(source) + } +} + impl FileSource for CsvSource { fn create_file_opener( &self, @@ -626,6 +282,20 @@ impl FileSource for CsvSource { DisplayFormatType::TreeRender => Ok(()), } } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } impl FileOpener for CsvOpener { @@ -743,6 +413,11 @@ pub async fn plan_to_csv( let parsed = ListingTableUrl::parse(path)?; let object_store_url = parsed.object_store(); let store = task_ctx.runtime_env().object_store(&object_store_url)?; + let writer_buffer_size = task_ctx + .session_config() + 
.options() + .execution + .objectstore_writer_buffer_size; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { let storeref = Arc::clone(&store); @@ -752,7 +427,8 @@ pub async fn plan_to_csv( let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { - let mut buf_writer = BufWriter::new(storeref, file.clone()); + let mut buf_writer = + BufWriter::with_capacity(storeref, file.clone(), writer_buffer_size); let mut buffer = Vec::with_capacity(1024); //only write headers on first iteration let mut write_headers = true; diff --git a/datafusion/datasource-json/src/file_format.rs b/datafusion/datasource-json/src/file_format.rs index 8d0515804fc7..2b5f65c87a15 100644 --- a/datafusion/datasource-json/src/file_format.rs +++ b/datafusion/datasource-json/src/file_format.rs @@ -52,7 +52,6 @@ use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join; use datafusion_datasource::write::BatchSerializer; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; @@ -61,6 +60,7 @@ use async_trait::async_trait; use bytes::{Buf, Bytes}; use datafusion_datasource::source::DataSourceExec; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; #[derive(Default)] /// Factory struct used to create [JsonFormat] diff --git a/datafusion/datasource-json/src/source.rs b/datafusion/datasource-json/src/source.rs index ee96d050966d..af37e1033ef1 100644 --- a/datafusion/datasource-json/src/source.rs +++ b/datafusion/datasource-json/src/source.rs @@ -30,198 +30,25 @@ use datafusion_datasource::decoder::{deserialize_stream, DecoderDeserializer}; use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_meta::FileMeta; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; -use datafusion_datasource::{calculate_range, ListingTableUrl, RangeCalculation}; +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; +use datafusion_datasource::{ + as_file_source, calculate_range, ListingTableUrl, RangeCalculation, +}; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use arrow::json::ReaderBuilder; use arrow::{datatypes::SchemaRef, json}; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::Statistics; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::{DisplayAs, DisplayFormatType, PlanProperties}; - -use datafusion_datasource::file_groups::FileGroup; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; + use futures::{StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use 
tokio::io::AsyncWriteExt; -/// Execution plan for scanning NdJson data source -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct NdJsonExec { - inner: DataSourceExec, - base_config: FileScanConfig, - file_compression_type: FileCompressionType, -} - -#[allow(unused, deprecated)] -impl NdJsonExec { - /// Create a new JSON reader execution plan provided base configurations - pub fn new( - base_config: FileScanConfig, - file_compression_type: FileCompressionType, - ) -> Self { - let ( - projected_schema, - projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = Self::compute_properties( - projected_schema, - &projected_output_ordering, - projected_constraints, - &base_config, - ); - - let json = JsonSource::default(); - let base_config = base_config - .with_file_compression_type(file_compression_type) - .with_source(Arc::new(json)); - - Self { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - file_compression_type: base_config.file_compression_type, - base_config, - } - } - - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - /// Ref to file compression type - pub fn file_compression_type(&self) -> &FileCompressionType { - &self.file_compression_type - } - - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn json_source(&self) -> JsonSource { - let source = self.file_scan_config(); - source - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
- fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints); - - PlanProperties::new( - eq_properties, - Self::output_partitioning_helper(file_scan_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.base_config.file_groups = file_groups.clone(); - let mut file_source = self.file_scan_config(); - file_source = file_source.with_file_groups(file_groups); - self.inner = self.inner.with_data_source(Arc::new(file_source)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for NdJsonExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for NdJsonExec { - fn name(&self) -> &'static str { - "NdJsonExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - Vec::new() - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - fn repartitioned( - &self, - target_partitions: usize, - config: &datafusion_common::config::ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn metrics(&self) -> Option { - self.inner.metrics() - } - - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - /// A [`FileOpener`] that opens a JSON file and yields a [`FileOpenFuture`] pub struct JsonOpener { batch_size: usize, @@ -253,6 +80,7 @@ pub struct JsonSource { batch_size: Option, metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, } impl JsonSource { @@ -262,6 +90,12 @@ impl JsonSource { } } +impl From for Arc { + fn from(source: JsonSource) -> Self { + as_file_source(source) + } +} + impl FileSource for JsonSource { fn create_file_opener( &self, @@ -316,6 +150,20 @@ impl FileSource for JsonSource { fn file_type(&self) -> &str { "json" } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } impl FileOpener for JsonOpener { @@ -399,6 +247,11 @@ pub async fn plan_to_json( let parsed = ListingTableUrl::parse(path)?; let object_store_url = parsed.object_store(); let store = task_ctx.runtime_env().object_store(&object_store_url)?; + let writer_buffer_size = task_ctx + .session_config() + .options() + .execution + .objectstore_writer_buffer_size; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { let storeref = Arc::clone(&store); @@ -408,7 +261,8 @@ pub async fn plan_to_json( let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { - let mut buf_writer = BufWriter::new(storeref, file.clone()); + let mut buf_writer = + 
BufWriter::with_capacity(storeref, file.clone(), writer_buffer_size); let mut buffer = Vec::with_capacity(1024); while let Some(batch) = stream.next().await.transpose()? { diff --git a/datafusion/datasource-parquet/src/file_format.rs b/datafusion/datasource-parquet/src/file_format.rs index 2ef4f236f278..c3f322319789 100644 --- a/datafusion/datasource-parquet/src/file_format.rs +++ b/datafusion/datasource-parquet/src/file_format.rs @@ -18,20 +18,22 @@ //! [`ParquetFormat`]: Parquet [`FileFormat`] abstractions use std::any::Any; +use std::cell::RefCell; use std::fmt; use std::fmt::Debug; use std::ops::Range; +use std::rc::Rc; use std::sync::Arc; use arrow::array::RecordBatch; use arrow::datatypes::{Fields, Schema, SchemaRef, TimeUnit}; use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; -use datafusion_datasource::write::{create_writer, get_writer_schema, SharedBuffer}; - -use datafusion_datasource::file_format::{ - FileFormat, FileFormatFactory, FilePushdownSupport, +use datafusion_datasource::write::{ + get_writer_schema, ObjectWriterBuilder, SharedBuffer, }; + +use datafusion_datasource::file_format::{FileFormat, FileFormatFactory}; use datafusion_datasource::write::demux::DemuxedStreamReceiver; use arrow::compute::sum; @@ -41,7 +43,7 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, ColumnStatistics, - DataFusionError, GetExt, Result, DEFAULT_PARQUET_EXTENSION, + DataFusionError, GetExt, HashSet, Result, DEFAULT_PARQUET_EXTENSION, }; use datafusion_common::{HashMap, Statistics}; use datafusion_common_runtime::{JoinSet, SpawnedTask}; @@ -52,16 +54,13 @@ use datafusion_datasource::sink::{DataSink, DataSinkExec}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_expr::Expr; use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::Accumulator; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; -use crate::can_expr_be_pushed_down_with_schemas; -use crate::source::ParquetSource; +use crate::source::{parse_coerce_int96_string, ParquetSource}; use async_trait::async_trait; use bytes::Bytes; use datafusion_datasource::source::DataSourceExec; @@ -87,6 +86,7 @@ use parquet::format::FileMetaData; use parquet::schema::types::SchemaDescriptor; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// Initial writing buffer size. Note this is just a size hint for efficiency. It /// will grow beyond the set value if needed. 
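The same buffering change recurs in `plan_to_csv`, `plan_to_json`, and (further down) `ParquetSink`: the writer capacity now comes from the `execution.objectstore_writer_buffer_size` session option instead of `BufWriter::new`'s default. Below is a minimal sketch of that pattern, assuming only the `object_store` and DataFusion APIs already referenced in this diff; the helper name `make_output_writer` is illustrative, not part of the change.

```rust
use std::sync::Arc;

use datafusion_execution::TaskContext;
use object_store::buffered::BufWriter;
use object_store::path::Path;
use object_store::ObjectStore;

// Sketch of the writer-buffer-sizing pattern applied at the call sites above:
// read the configured capacity from the session options, then buffer uploads
// to the object store with that capacity instead of the default.
fn make_output_writer(
    task_ctx: &Arc<TaskContext>,
    store: Arc<dyn ObjectStore>,
    file: Path,
) -> BufWriter {
    // Buffer size (in bytes) taken from `execution.objectstore_writer_buffer_size`.
    let writer_buffer_size = task_ctx
        .session_config()
        .options()
        .execution
        .objectstore_writer_buffer_size;

    BufWriter::with_capacity(store, file, writer_buffer_size)
}
```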
@@ -304,9 +304,10 @@ async fn fetch_schema_with_location( store: &dyn ObjectStore, file: &ObjectMeta, metadata_size_hint: Option, + coerce_int96: Option, ) -> Result<(Path, Schema)> { let loc_path = file.location.clone(); - let schema = fetch_schema(store, file, metadata_size_hint).await?; + let schema = fetch_schema(store, file, metadata_size_hint, coerce_int96).await?; Ok((loc_path, schema)) } @@ -337,12 +338,17 @@ impl FileFormat for ParquetFormat { store: &Arc, objects: &[ObjectMeta], ) -> Result { + let coerce_int96 = match self.coerce_int96() { + Some(time_unit) => Some(parse_coerce_int96_string(time_unit.as_str())?), + None => None, + }; let mut schemas: Vec<_> = futures::stream::iter(objects) .map(|object| { fetch_schema_with_location( store.as_ref(), object, self.metadata_size_hint(), + coerce_int96, ) }) .boxed() // Workaround https://github.com/rust-lang/rust/issues/64552 @@ -430,9 +436,11 @@ impl FileFormat for ParquetFormat { if let Some(metadata_size_hint) = metadata_size_hint { source = source.with_metadata_size_hint(metadata_size_hint) } + // Apply schema adapter factory before building the new config + let file_source = source.apply_schema_adapter(&conf)?; let conf = FileScanConfigBuilder::from(conf) - .with_source(Arc::new(source)) + .with_source(file_source) .build(); Ok(DataSourceExec::from_data_source(conf)) } @@ -453,27 +461,6 @@ impl FileFormat for ParquetFormat { Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } - fn supports_filters_pushdown( - &self, - file_schema: &Schema, - table_schema: &Schema, - filters: &[&Expr], - ) -> Result { - if !self.options().global.pushdown_filters { - return Ok(FilePushdownSupport::NoSupport); - } - - let all_supported = filters.iter().all(|filter| { - can_expr_be_pushed_down_with_schemas(filter, file_schema, table_schema) - }); - - Ok(if all_supported { - FilePushdownSupport::Supported - } else { - FilePushdownSupport::NotSupportedForFilter - }) - } - fn file_source(&self) -> Arc { Arc::new(ParquetSource::default()) } @@ -588,38 +575,186 @@ pub fn coerce_int96_to_resolution( file_schema: &Schema, time_unit: &TimeUnit, ) -> Option { - let mut transform = false; - let parquet_fields: HashMap<_, _> = parquet_schema + // Traverse the parquet_schema columns looking for int96 physical types. If encountered, insert + // the field's full path into a set. + let int96_fields: HashSet<_> = parquet_schema .columns() .iter() - .map(|f| { - let dt = f.physical_type(); - if dt.eq(&Type::INT96) { - transform = true; - } - (f.name(), dt) - }) + .filter(|f| f.physical_type() == Type::INT96) + .map(|f| f.path().string()) .collect(); - if !transform { + if int96_fields.is_empty() { + // The schema doesn't contain any int96 fields, so skip the remaining logic. return None; } - let transformed_fields: Vec> = file_schema - .fields - .iter() - .map(|field| match parquet_fields.get(field.name().as_str()) { - Some(Type::INT96) => { - field_with_new_type(field, DataType::Timestamp(*time_unit, None)) + // Do a DFS into the schema using a stack, looking for timestamp(nanos) fields that originated + // as int96 to coerce to the provided time_unit. + + type NestedFields = Rc>>; + type StackContext<'a> = ( + Vec<&'a str>, // The Parquet column path (e.g., "c0.list.element.c1") for the current field. + &'a FieldRef, // The current field to be processed. + NestedFields, // The parent's fields that this field will be (possibly) type-coerced and + // inserted into. All fields have a parent, so this is not an Option type. 
+ Option, // Nested types need to create their own vector of fields for their + // children. For primitive types this will remain None. For nested + // types it is None the first time they are processed. Then, we + // instantiate a vector for its children, push the field back onto the + // stack to be processed again, and DFS into its children. The next + // time we process the field, we know we have DFS'd into the children + // because this field is Some. + ); + + // This is our top-level fields from which we will construct our schema. We pass this into our + // initial stack context as the parent fields, and the DFS populates it. + let fields = Rc::new(RefCell::new(Vec::with_capacity(file_schema.fields.len()))); + + // TODO: It might be possible to only DFS into nested fields that we know contain an int96 if we + // use some sort of LPM data structure to check if we're currently DFS'ing nested types that are + // in a column path that contains an int96. That can be a future optimization for large schemas. + let transformed_schema = { + // Populate the stack with our top-level fields. + let mut stack: Vec = file_schema + .fields() + .iter() + .rev() + .map(|f| (vec![f.name().as_str()], f, Rc::clone(&fields), None)) + .collect(); + + // Pop fields to DFS into until we have exhausted the stack. + while let Some((parquet_path, current_field, parent_fields, child_fields)) = + stack.pop() + { + match (current_field.data_type(), child_fields) { + (DataType::Struct(unprocessed_children), None) => { + // This is the first time popping off this struct. We don't yet know the + // correct types of its children (i.e., if they need coercing) so we create + // a vector for child_fields, push the struct node back onto the stack to be + // processed again (see below) after processing all its children. + let child_fields = Rc::new(RefCell::new(Vec::with_capacity( + unprocessed_children.len(), + ))); + // Note that here we push the struct back onto the stack with its + // parent_fields in the same position, now with Some(child_fields). + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + // Push all the children in reverse to maintain original schema order due to + // stack processing. + for child in unprocessed_children.into_iter().rev() { + let mut child_path = parquet_path.clone(); + // Build up a normalized path that we'll use as a key into the original + // int96_fields set above to test if this originated as int96. + child_path.push("."); + child_path.push(child.name()); + // Note that here we push the field onto the stack using the struct's + // new child_fields vector as the field's parent_fields. + stack.push((child_path, child, Rc::clone(&child_fields), None)); + } + } + (DataType::Struct(unprocessed_children), Some(processed_children)) => { + // This is the second time popping off this struct. The child_fields vector + // now contains each field that has been DFS'd into, and we can construct + // the resulting struct with correct child types. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), unprocessed_children.len()); + let processed_struct = Field::new_struct( + current_field.name(), + processed_children.as_slice(), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_struct)); + } + (DataType::List(unprocessed_child), None) => { + // This is the first time popping off this list. See struct docs above. 
+ let child_fields = Rc::new(RefCell::new(Vec::with_capacity(1))); + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + let mut child_path = parquet_path.clone(); + // Spark uses a definition for arrays/lists that results in a group + // named "list" that is not maintained when parsing to Arrow. We just push + // this name into the path. + child_path.push(".list."); + child_path.push(unprocessed_child.name()); + stack.push(( + child_path.clone(), + unprocessed_child, + Rc::clone(&child_fields), + None, + )); + } + (DataType::List(_), Some(processed_children)) => { + // This is the second time popping off this list. See struct docs above. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), 1); + let processed_list = Field::new_list( + current_field.name(), + Arc::clone(&processed_children[0]), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_list)); + } + (DataType::Map(unprocessed_child, _), None) => { + // This is the first time popping off this map. See struct docs above. + let child_fields = Rc::new(RefCell::new(Vec::with_capacity(1))); + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + let mut child_path = parquet_path.clone(); + child_path.push("."); + child_path.push(unprocessed_child.name()); + stack.push(( + child_path.clone(), + unprocessed_child, + Rc::clone(&child_fields), + None, + )); + } + (DataType::Map(_, sorted), Some(processed_children)) => { + // This is the second time popping off this map. See struct docs above. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), 1); + let processed_map = Field::new( + current_field.name(), + DataType::Map(Arc::clone(&processed_children[0]), *sorted), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_map)); + } + (DataType::Timestamp(TimeUnit::Nanosecond, None), None) + if int96_fields.contains(parquet_path.concat().as_str()) => + // We found a timestamp(nanos) and it originated as int96. Coerce it to the correct + // time_unit. + { + parent_fields.borrow_mut().push(field_with_new_type( + current_field, + DataType::Timestamp(*time_unit, None), + )); + } + // Other types can be cloned as they are. + _ => parent_fields.borrow_mut().push(Arc::clone(current_field)), } - _ => Arc::clone(field), - }) - .collect(); + } + assert_eq!(fields.borrow().len(), file_schema.fields.len()); + Schema::new_with_metadata( + fields.borrow_mut().clone(), + file_schema.metadata.clone(), + ) + }; - Some(Schema::new_with_metadata( - transformed_fields, - file_schema.metadata.clone(), - )) + Some(transformed_schema) } /// Coerces the file schema if the table schema uses a view type. 
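The rewritten `coerce_int96_to_resolution` above keys its DFS on dot-joined Parquet column paths (e.g. `c0.list.element`) collected into a set of INT96 columns, then rewrites only the matching `Timestamp(Nanosecond)` leaves to the requested unit. The sketch below shows just that matching step for a flat, top-level schema, assuming the `parquet` and `arrow` APIs already used in this hunk; it deliberately omits the nested struct/list/map stack handling that the real function implements.

```rust
use std::collections::HashSet;
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use parquet::basic::Type;
use parquet::schema::types::SchemaDescriptor;

// Flat-schema illustration of the INT96 matching step: collect the Parquet
// column paths whose physical type is INT96, then coerce only those
// Timestamp(Nanosecond) fields to `time_unit`.
fn coerce_top_level_int96(
    parquet_schema: &SchemaDescriptor,
    file_schema: &Schema,
    time_unit: &TimeUnit,
) -> Schema {
    // Paths such as "c0" (or "c4.list.element.c1" in the nested case) identify
    // columns that were INT96 on disk.
    let int96_fields: HashSet<String> = parquet_schema
        .columns()
        .iter()
        .filter(|c| c.physical_type() == Type::INT96)
        .map(|c| c.path().string())
        .collect();

    let fields: Vec<Arc<Field>> = file_schema
        .fields()
        .iter()
        .map(|field| match field.data_type() {
            // INT96 always parses to Timestamp(Nanosecond, None); coerce only
            // the leaves whose path appears in the INT96 set.
            DataType::Timestamp(TimeUnit::Nanosecond, None)
                if int96_fields.contains(field.name().as_str()) =>
            {
                Arc::new(
                    field
                        .as_ref()
                        .clone()
                        .with_data_type(DataType::Timestamp(*time_unit, None)),
                )
            }
            // Other fields are kept as-is.
            _ => Arc::clone(field),
        })
        .collect();

    Schema::new_with_metadata(fields, file_schema.metadata().clone())
}
```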
@@ -825,6 +960,7 @@ async fn fetch_schema( store: &dyn ObjectStore, file: &ObjectMeta, metadata_size_hint: Option, + coerce_int96: Option, ) -> Result { let metadata = fetch_parquet_metadata(store, file, metadata_size_hint).await?; let file_metadata = metadata.file_metadata(); @@ -832,6 +968,11 @@ async fn fetch_schema( file_metadata.schema_descr(), file_metadata.key_value_metadata(), )?; + let schema = coerce_int96 + .and_then(|time_unit| { + coerce_int96_to_resolution(file_metadata.schema_descr(), &schema, &time_unit) + }) + .unwrap_or(schema); Ok(schema) } @@ -938,7 +1079,7 @@ pub fn statistics_from_parquet_meta_calc( .ok(); } Err(e) => { - debug!("Failed to create statistics converter: {}", e); + debug!("Failed to create statistics converter: {e}"); null_counts_array[idx] = Precision::Exact(num_rows); } } @@ -1090,9 +1231,18 @@ impl ParquetSink { &self, location: &Path, object_store: Arc, + context: &Arc, parquet_props: WriterProperties, ) -> Result> { - let buf_writer = BufWriter::new(object_store, location.clone()); + let buf_writer = BufWriter::with_capacity( + object_store, + location.clone(), + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + ); let options = ArrowWriterOptions::new() .with_properties(parquet_props) .with_skip_arrow_metadata(self.parquet_options.global.skip_arrow_metadata); @@ -1148,12 +1298,12 @@ impl FileSink for ParquetSink { .create_async_arrow_writer( &path, Arc::clone(&object_store), + context, parquet_props.clone(), ) .await?; - let mut reservation = - MemoryConsumer::new(format!("ParquetSink[{}]", path)) - .register(context.memory_pool()); + let mut reservation = MemoryConsumer::new(format!("ParquetSink[{path}]")) + .register(context.memory_pool()); file_write_tasks.spawn(async move { while let Some(batch) = rx.recv().await { writer.write(&batch).await?; @@ -1166,14 +1316,21 @@ impl FileSink for ParquetSink { Ok((path, file_metadata)) }); } else { - let writer = create_writer( + let writer = ObjectWriterBuilder::new( // Parquet files as a whole are never compressed, since they // manage compressed blocks themselves. FileCompressionType::UNCOMPRESSED, &path, Arc::clone(&object_store), ) - .await?; + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; let schema = get_writer_schema(&self.config); let props = parquet_props.clone(); let parallel_options_clone = parallel_options.clone(); @@ -1585,3 +1742,220 @@ fn create_max_min_accs( .collect(); (max_values, min_values) } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + + use arrow::datatypes::DataType; + use parquet::schema::parser::parse_message_type; + + #[test] + fn coerce_int96_to_resolution_with_mixed_timestamps() { + // Unclear if Spark (or other writer) could generate a file with mixed timestamps like this, + // but we want to test the scenario just in case since it's at least a valid schema as far + // as the Parquet spec is concerned. 
+ let spark_schema = " + message spark_schema { + optional int96 c0; + optional int64 c1 (TIMESTAMP(NANOS,true)); + optional int64 c2 (TIMESTAMP(NANOS,false)); + optional int64 c3 (TIMESTAMP(MILLIS,true)); + optional int64 c4 (TIMESTAMP(MILLIS,false)); + optional int64 c5 (TIMESTAMP(MICROS,true)); + optional int64 c6 (TIMESTAMP(MICROS,false)); + } + "; + + let schema = parse_message_type(spark_schema).expect("should parse schema"); + let descr = SchemaDescriptor::new(Arc::new(schema)); + + let arrow_schema = parquet_to_arrow_schema(&descr, None).unwrap(); + + let result = + coerce_int96_to_resolution(&descr, &arrow_schema, &TimeUnit::Microsecond) + .unwrap(); + + // Only the first field (c0) should be converted to a microsecond timestamp because it's the + // only timestamp that originated from an INT96. + let expected_schema = Schema::new(vec![ + Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + true, + ), + Field::new("c2", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + Field::new( + "c3", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + true, + ), + Field::new("c4", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new( + "c5", + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), + true, + ), + Field::new("c6", DataType::Timestamp(TimeUnit::Microsecond, None), true), + ]); + + assert_eq!(result, expected_schema); + } + + #[test] + fn coerce_int96_to_resolution_with_nested_types() { + // This schema is derived from Comet's CometFuzzTestSuite ParquetGenerator only using int96 + // primitive types with generateStruct, generateArray, and generateMap set to true, with one + // additional field added to c4's struct to make sure all fields in a struct get modified. 
+ // https://github.com/apache/datafusion-comet/blob/main/spark/src/main/scala/org/apache/comet/testing/ParquetGenerator.scala + let spark_schema = " + message spark_schema { + optional int96 c0; + optional group c1 { + optional int96 c0; + } + optional group c2 { + optional group c0 (LIST) { + repeated group list { + optional int96 element; + } + } + } + optional group c3 (LIST) { + repeated group list { + optional int96 element; + } + } + optional group c4 (LIST) { + repeated group list { + optional group element { + optional int96 c0; + optional int96 c1; + } + } + } + optional group c5 (MAP) { + repeated group key_value { + required int96 key; + optional int96 value; + } + } + optional group c6 (LIST) { + repeated group list { + optional group element (MAP) { + repeated group key_value { + required int96 key; + optional int96 value; + } + } + } + } + } + "; + + let schema = parse_message_type(spark_schema).expect("should parse schema"); + let descr = SchemaDescriptor::new(Arc::new(schema)); + + let arrow_schema = parquet_to_arrow_schema(&descr, None).unwrap(); + + let result = + coerce_int96_to_resolution(&descr, &arrow_schema, &TimeUnit::Microsecond) + .unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new_struct( + "c1", + vec![Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )], + true, + ), + Field::new_struct( + "c2", + vec![Field::new_list( + "c0", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + )], + true, + ), + Field::new_list( + "c3", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + ), + Field::new_list( + "c4", + Field::new_struct( + "element", + vec![ + Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + ], + true, + ), + true, + ), + Field::new_map( + "c5", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + Field::new_list( + "c6", + Field::new_map( + "element", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + true, + ), + ]); + + assert_eq!(result, expected_schema); + } +} diff --git a/datafusion/datasource-parquet/src/mod.rs b/datafusion/datasource-parquet/src/mod.rs index 516b13792189..0b4e86240383 100644 --- a/datafusion/datasource-parquet/src/mod.rs +++ b/datafusion/datasource-parquet/src/mod.rs @@ -19,8 +19,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -//! 
[`ParquetExec`] FileSource for reading Parquet files - pub mod access_plan; pub mod file_format; mod metrics; @@ -32,28 +30,7 @@ mod row_group_filter; pub mod source; mod writer; -use std::any::Any; -use std::fmt::Formatter; -use std::sync::Arc; - pub use access_plan::{ParquetAccessPlan, RowGroupAccess}; -use arrow::datatypes::SchemaRef; -use datafusion_common::config::{ConfigOptions, TableParquetOptions}; -use datafusion_common::Result; -use datafusion_common::{Constraints, Statistics}; -use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::schema_adapter::SchemaAdapterFactory; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{ - EquivalenceProperties, LexOrdering, Partitioning, PhysicalExpr, -}; -use datafusion_physical_optimizer::pruning::PruningPredicate; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::MetricsSet; -use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, -}; pub use file_format::*; pub use metrics::ParquetFileMetrics; pub use page_filter::PagePruningAccessPlanFilter; @@ -61,491 +38,4 @@ pub use reader::{DefaultParquetFileReaderFactory, ParquetFileReaderFactory}; pub use row_filter::build_row_filter; pub use row_filter::can_expr_be_pushed_down_with_schemas; pub use row_group_filter::RowGroupAccessPlanFilter; -use source::ParquetSource; pub use writer::plan_to_parquet; - -use datafusion_datasource::file_groups::FileGroup; -use log::debug; - -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -/// Deprecated Execution plan replaced with DataSourceExec -pub struct ParquetExec { - inner: DataSourceExec, - base_config: FileScanConfig, - table_parquet_options: TableParquetOptions, - /// Optional predicate for row filtering during parquet scan - predicate: Option>, - /// Optional predicate for pruning row groups (derived from `predicate`) - pruning_predicate: Option>, - /// Optional user defined parquet file reader factory - parquet_file_reader_factory: Option>, - /// Optional user defined schema adapter - schema_adapter_factory: Option>, -} - -#[allow(unused, deprecated)] -impl From for ParquetExecBuilder { - fn from(exec: ParquetExec) -> Self { - exec.into_builder() - } -} - -/// [`ParquetExecBuilder`], deprecated builder for [`ParquetExec`]. -/// -/// ParquetExec is replaced with `DataSourceExec` and it includes `ParquetSource` -/// -/// See example on [`ParquetSource`]. -#[deprecated( - since = "46.0.0", - note = "use DataSourceExec with ParquetSource instead" -)] -#[allow(unused, deprecated)] -pub struct ParquetExecBuilder { - file_scan_config: FileScanConfig, - predicate: Option>, - metadata_size_hint: Option, - table_parquet_options: TableParquetOptions, - parquet_file_reader_factory: Option>, - schema_adapter_factory: Option>, -} - -#[allow(unused, deprecated)] -impl ParquetExecBuilder { - /// Create a new builder to read the provided file scan configuration - pub fn new(file_scan_config: FileScanConfig) -> Self { - Self::new_with_options(file_scan_config, TableParquetOptions::default()) - } - - /// Create a new builder to read the data specified in the file scan - /// configuration with the provided `TableParquetOptions`. 
- pub fn new_with_options( - file_scan_config: FileScanConfig, - table_parquet_options: TableParquetOptions, - ) -> Self { - Self { - file_scan_config, - predicate: None, - metadata_size_hint: None, - table_parquet_options, - parquet_file_reader_factory: None, - schema_adapter_factory: None, - } - } - - /// Update the list of files groups to read - pub fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.file_scan_config.file_groups = file_groups; - self - } - - /// Set the filter predicate when reading. - /// - /// See the "Predicate Pushdown" section of the [`ParquetExec`] documentation - /// for more details. - pub fn with_predicate(mut self, predicate: Arc) -> Self { - self.predicate = Some(predicate); - self - } - - /// Set the metadata size hint - /// - /// This value determines how many bytes at the end of the file the default - /// [`ParquetFileReaderFactory`] will request in the initial IO. If this is - /// too small, the ParquetExec will need to make additional IO requests to - /// read the footer. - pub fn with_metadata_size_hint(mut self, metadata_size_hint: usize) -> Self { - self.metadata_size_hint = Some(metadata_size_hint); - self - } - - /// Set the options for controlling how the ParquetExec reads parquet files. - /// - /// See also [`Self::new_with_options`] - pub fn with_table_parquet_options( - mut self, - table_parquet_options: TableParquetOptions, - ) -> Self { - self.table_parquet_options = table_parquet_options; - self - } - - /// Set optional user defined parquet file reader factory. - /// - /// You can use [`ParquetFileReaderFactory`] to more precisely control how - /// data is read from parquet files (e.g. skip re-reading metadata, coalesce - /// I/O operations, etc). - /// - /// The default reader factory reads directly from an [`ObjectStore`] - /// instance using individual I/O operations for the footer and each page. - /// - /// If a custom `ParquetFileReaderFactory` is provided, then data access - /// operations will be routed to this factory instead of [`ObjectStore`]. - /// - /// [`ObjectStore`]: object_store::ObjectStore - pub fn with_parquet_file_reader_factory( - mut self, - parquet_file_reader_factory: Arc, - ) -> Self { - self.parquet_file_reader_factory = Some(parquet_file_reader_factory); - self - } - - /// Set optional schema adapter factory. - /// - /// [`SchemaAdapterFactory`] allows user to specify how fields from the - /// parquet file get mapped to that of the table schema. The default schema - /// adapter uses arrow's cast library to map the parquet fields to the table - /// schema. 
- pub fn with_schema_adapter_factory( - mut self, - schema_adapter_factory: Arc, - ) -> Self { - self.schema_adapter_factory = Some(schema_adapter_factory); - self - } - - /// Convenience: build an `Arc`d `ParquetExec` from this builder - pub fn build_arc(self) -> Arc { - Arc::new(self.build()) - } - - /// Build a [`ParquetExec`] - #[must_use] - pub fn build(self) -> ParquetExec { - let Self { - file_scan_config, - predicate, - metadata_size_hint, - table_parquet_options, - parquet_file_reader_factory, - schema_adapter_factory, - } = self; - let mut parquet = ParquetSource::new(table_parquet_options); - if let Some(predicate) = predicate.clone() { - parquet = parquet - .with_predicate(Arc::clone(&file_scan_config.file_schema), predicate); - } - if let Some(metadata_size_hint) = metadata_size_hint { - parquet = parquet.with_metadata_size_hint(metadata_size_hint) - } - if let Some(parquet_reader_factory) = parquet_file_reader_factory { - parquet = parquet.with_parquet_file_reader_factory(parquet_reader_factory) - } - if let Some(schema_factory) = schema_adapter_factory { - parquet = parquet.with_schema_adapter_factory(schema_factory); - } - - let base_config = file_scan_config.with_source(Arc::new(parquet.clone())); - debug!("Creating ParquetExec, files: {:?}, projection {:?}, predicate: {:?}, limit: {:?}", - base_config.file_groups, base_config.projection, predicate, base_config.limit); - - ParquetExec { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - predicate, - pruning_predicate: parquet.pruning_predicate, - schema_adapter_factory: parquet.schema_adapter_factory, - parquet_file_reader_factory: parquet.parquet_file_reader_factory, - table_parquet_options: parquet.table_parquet_options, - } - } -} - -#[allow(unused, deprecated)] -impl ParquetExec { - /// Create a new Parquet reader execution plan provided file list and schema. - pub fn new( - base_config: FileScanConfig, - predicate: Option>, - metadata_size_hint: Option, - table_parquet_options: TableParquetOptions, - ) -> Self { - let mut builder = - ParquetExecBuilder::new_with_options(base_config, table_parquet_options); - if let Some(predicate) = predicate { - builder = builder.with_predicate(predicate); - } - if let Some(metadata_size_hint) = metadata_size_hint { - builder = builder.with_metadata_size_hint(metadata_size_hint); - } - builder.build() - } - /// Return a [`ParquetExecBuilder`]. - /// - /// See example on [`ParquetExec`] and [`ParquetExecBuilder`] for specifying - /// parquet table options. 
- pub fn builder(file_scan_config: FileScanConfig) -> ParquetExecBuilder { - ParquetExecBuilder::new(file_scan_config) - } - - /// Convert this `ParquetExec` into a builder for modification - pub fn into_builder(self) -> ParquetExecBuilder { - // list out fields so it is clear what is being dropped - // (note the fields which are dropped are re-created as part of calling - // `build` on the builder) - let file_scan_config = self.file_scan_config(); - let parquet = self.parquet_source(); - - ParquetExecBuilder { - file_scan_config, - predicate: parquet.predicate, - metadata_size_hint: parquet.metadata_size_hint, - table_parquet_options: parquet.table_parquet_options, - parquet_file_reader_factory: parquet.parquet_file_reader_factory, - schema_adapter_factory: parquet.schema_adapter_factory, - } - } - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn parquet_source(&self) -> ParquetSource { - self.file_scan_config() - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - /// [`FileScanConfig`] that controls this scan (such as which files to read) - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - /// Options passed to the parquet reader for this scan - pub fn table_parquet_options(&self) -> &TableParquetOptions { - &self.table_parquet_options - } - /// Optional predicate. - pub fn predicate(&self) -> Option<&Arc> { - self.predicate.as_ref() - } - /// Optional reference to this parquet scan's pruning predicate - pub fn pruning_predicate(&self) -> Option<&Arc> { - self.pruning_predicate.as_ref() - } - /// return the optional file reader factory - pub fn parquet_file_reader_factory( - &self, - ) -> Option<&Arc> { - self.parquet_file_reader_factory.as_ref() - } - /// Optional user defined parquet file reader factory. - pub fn with_parquet_file_reader_factory( - mut self, - parquet_file_reader_factory: Arc, - ) -> Self { - let mut parquet = self.parquet_source(); - parquet.parquet_file_reader_factory = - Some(Arc::clone(&parquet_file_reader_factory)); - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.parquet_file_reader_factory = Some(parquet_file_reader_factory); - self - } - /// return the optional schema adapter factory - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() - } - /// Set optional schema adapter factory. - /// - /// [`SchemaAdapterFactory`] allows user to specify how fields from the - /// parquet file get mapped to that of the table schema. The default schema - /// adapter uses arrow's cast library to map the parquet fields to the table - /// schema. - pub fn with_schema_adapter_factory( - mut self, - schema_adapter_factory: Arc, - ) -> Self { - let mut parquet = self.parquet_source(); - parquet.schema_adapter_factory = Some(Arc::clone(&schema_adapter_factory)); - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.schema_adapter_factory = Some(schema_adapter_factory); - self - } - /// If true, the predicate will be used during the parquet scan. 
- /// Defaults to false - /// - /// [`Expr`]: datafusion_expr::Expr - pub fn with_pushdown_filters(mut self, pushdown_filters: bool) -> Self { - let mut parquet = self.parquet_source(); - parquet.table_parquet_options.global.pushdown_filters = pushdown_filters; - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.table_parquet_options.global.pushdown_filters = pushdown_filters; - self - } - - /// Return the value described in [`Self::with_pushdown_filters`] - fn pushdown_filters(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .pushdown_filters - } - /// If true, the `RowFilter` made by `pushdown_filters` may try to - /// minimize the cost of filter evaluation by reordering the - /// predicate [`Expr`]s. If false, the predicates are applied in - /// the same order as specified in the query. Defaults to false. - /// - /// [`Expr`]: datafusion_expr::Expr - pub fn with_reorder_filters(mut self, reorder_filters: bool) -> Self { - let mut parquet = self.parquet_source(); - parquet.table_parquet_options.global.reorder_filters = reorder_filters; - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.table_parquet_options.global.reorder_filters = reorder_filters; - self - } - /// Return the value described in [`Self::with_reorder_filters`] - fn reorder_filters(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .reorder_filters - } - /// If enabled, the reader will read the page index - /// This is used to optimize filter pushdown - /// via `RowSelector` and `RowFilter` by - /// eliminating unnecessary IO and decoding - fn bloom_filter_on_read(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .bloom_filter_on_read - } - /// Return the value described in [`ParquetSource::with_enable_page_index`] - fn enable_page_index(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .enable_page_index - } - - fn output_partitioning_helper(file_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_config: &FileScanConfig, - ) -> PlanProperties { - PlanProperties::new( - EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints), - Self::output_partitioning_helper(file_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - /// Updates the file groups to read and recalculates the output partitioning - /// - /// Note this function does not update statistics or other properties - /// that depend on the file groups. 
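Hedged sketch of the equivalent configuration after these `ParquetExec` setters go away: the same flags are plain fields on `TableParquetOptions`, which is what the removed methods mutated internally, and the options are handed to `ParquetSource::new`:

// enable filter pushdown and filter reordering via the options struct
let mut options = TableParquetOptions::default();
options.global.pushdown_filters = true;
options.global.reorder_filters = true;
let source = ParquetSource::new(options);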
- fn with_file_groups_and_update_partitioning( - mut self, - file_groups: Vec, - ) -> Self { - let mut config = self.file_scan_config(); - config.file_groups = file_groups; - self.inner = self.inner.with_data_source(Arc::new(config)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for ParquetExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for ParquetExec { - fn name(&self) -> &'static str { - "ParquetExec" - } - - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - // this is a leaf node and has no children - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - /// Redistribute files across partitions according to their size - /// See comments on `FileGroupPartitioner` for more detail. - fn repartitioned( - &self, - target_partitions: usize, - config: &ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - - fn execute( - &self, - partition_index: usize, - ctx: Arc, - ) -> Result { - self.inner.execute(partition_index, ctx) - } - fn metrics(&self) -> Option { - self.inner.metrics() - } - fn statistics(&self) -> Result { - self.inner.statistics() - } - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - -fn should_enable_page_index( - enable_page_index: bool, - page_pruning_predicate: &Option>, -) -> bool { - enable_page_index - && page_pruning_predicate.is_some() - && page_pruning_predicate - .as_ref() - .map(|p| p.filter_number() > 0) - .unwrap_or(false) -} diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 4517ed885a20..9e14425074f7 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -23,8 +23,7 @@ use crate::page_filter::PagePruningAccessPlanFilter; use crate::row_group_filter::RowGroupAccessPlanFilter; use crate::{ apply_file_schema_type_coercions, coerce_int96_to_resolution, row_filter, - should_enable_page_index, ParquetAccessPlan, ParquetFileMetrics, - ParquetFileReaderFactory, + ParquetAccessPlan, ParquetFileMetrics, ParquetFileReaderFactory, }; use datafusion_datasource::file_meta::FileMeta; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; @@ -182,7 +181,7 @@ impl FileOpener for ParquetOpener { // Build predicates for this specific file let (pruning_predicate, page_pruning_predicate) = build_pruning_predicates( - &predicate, + predicate.as_ref(), &logical_file_schema, &predicate_creation_errors, ); @@ -234,8 +233,7 @@ impl FileOpener for ParquetOpener { Ok(None) => {} Err(e) => { debug!( - "Ignoring error building row filter for '{:?}': {}", - predicate, e + "Ignoring error building row filter for '{predicate:?}': {e}" ); } }; @@ -394,8 +392,8 @@ pub(crate) fn build_page_pruning_predicate( )) } -fn build_pruning_predicates( - predicate: &Option>, +pub(crate) fn build_pruning_predicates( + predicate: Option<&Arc>, file_schema: &SchemaRef, predicate_creation_errors: &Count, ) -> ( @@ -444,3 +442,15 @@ async fn load_page_index( Ok(reader_metadata) } } + +fn should_enable_page_index( + enable_page_index: bool, + page_pruning_predicate: &Option>, +) -> bool { + 
enable_page_index + && page_pruning_predicate.is_some() + && page_pruning_predicate + .as_ref() + .map(|p| p.filter_number() > 0) + .unwrap_or(false) +} diff --git a/datafusion/datasource-parquet/src/page_filter.rs b/datafusion/datasource-parquet/src/page_filter.rs index 148527998ab5..84f5c4c2d6d5 100644 --- a/datafusion/datasource-parquet/src/page_filter.rs +++ b/datafusion/datasource-parquet/src/page_filter.rs @@ -28,9 +28,10 @@ use arrow::{ array::ArrayRef, datatypes::{Schema, SchemaRef}, }; +use datafusion_common::pruning::PruningStatistics; use datafusion_common::ScalarValue; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; -use datafusion_physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion_physical_optimizer::pruning::PruningPredicate; use log::{debug, trace}; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; @@ -333,7 +334,7 @@ fn prune_pages_in_one_row_group( assert_eq!(page_row_counts.len(), values.len()); let mut sum_row = *page_row_counts.first().unwrap(); let mut selected = *values.first().unwrap(); - trace!("Pruned to {:?} using {:?}", values, pruning_stats); + trace!("Pruned to {values:?} using {pruning_stats:?}"); for (i, &f) in values.iter().enumerate().skip(1) { if f == selected { sum_row += *page_row_counts.get(i).unwrap(); diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 366ad058ecc6..db455fed6160 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -299,6 +299,7 @@ struct PushdownChecker<'schema> { non_primitive_columns: bool, /// Does the expression reference any columns that are in the table /// schema but not in the file schema? + /// This includes partition columns and projected columns. projected_columns: bool, // Indices into the table schema of the columns required to evaluate the expression required_columns: BTreeSet, @@ -366,44 +367,19 @@ fn pushdown_columns( .then_some(checker.required_columns.into_iter().collect())) } -/// creates a PushdownChecker for a single use to check a given column with the given schemes. Used -/// to check preemptively if a column name would prevent pushdowning. -/// effectively does the inverse of [`pushdown_columns`] does, but with a single given column -/// (instead of traversing the entire tree to determine this) -fn would_column_prevent_pushdown(column_name: &str, table_schema: &Schema) -> bool { - let mut checker = PushdownChecker::new(table_schema); - - // the return of this is only used for [`PushdownChecker::f_down()`], so we can safely ignore - // it here. I'm just verifying we know the return type of this so nobody accidentally changes - // the return type of this fn and it gets implicitly ignored here. - let _: Option = checker.check_single_column(column_name); - - // and then return a value based on the state of the checker - checker.prevents_pushdown() -} - /// Recurses through expr as a tree, finds all `column`s, and checks if any of them would prevent /// this expression from being predicate pushed down. If any of them would, this returns false. /// Otherwise, true. +/// Note that the schema passed in here is *not* the physical file schema (as it is not available at that point in time); +/// it is the schema of the table that this expression is being evaluated against minus any projected columns and partition columns. 
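Sketch of the new call shape (the old `table_schema` / `file_schema` pair collapses to a single schema argument, as described in the doc comment above); `expr` is assumed to be an `Arc<dyn PhysicalExpr>` and `file_schema` the logical file schema:

if can_expr_be_pushed_down_with_schemas(&expr, &file_schema) {
    // the expression only references primitive columns present in the file
    // schema, so it is a candidate for the parquet row filter
}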
pub fn can_expr_be_pushed_down_with_schemas( - expr: &datafusion_expr::Expr, - _file_schema: &Schema, - table_schema: &Schema, + expr: &Arc, + file_schema: &Schema, ) -> bool { - let mut can_be_pushed = true; - expr.apply(|expr| match expr { - datafusion_expr::Expr::Column(column) => { - can_be_pushed &= !would_column_prevent_pushdown(column.name(), table_schema); - Ok(if can_be_pushed { - TreeNodeRecursion::Jump - } else { - TreeNodeRecursion::Stop - }) - } - _ => Ok(TreeNodeRecursion::Continue), - }) - .unwrap(); // we never return an Err, so we can safely unwrap this - can_be_pushed + match pushdown_columns(expr, file_schema) { + Ok(Some(_)) => true, + Ok(None) | Err(_) => false, + } } /// Calculate the total compressed size of all `Column`'s required for @@ -516,7 +492,7 @@ mod test { use super::*; use datafusion_common::ScalarValue; - use arrow::datatypes::{Field, Fields, TimeUnit::Nanosecond}; + use arrow::datatypes::{Field, TimeUnit::Nanosecond}; use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory; use datafusion_expr::{col, Expr}; use datafusion_physical_expr::planner::logical2physical; @@ -581,6 +557,7 @@ mod test { // Test all should fail let expr = col("timestamp_col").lt(Expr::Literal( ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))), + None, )); let expr = logical2physical(&expr, &table_schema); let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); @@ -621,6 +598,7 @@ mod test { // Test all should pass let expr = col("timestamp_col").gt(Expr::Literal( ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))), + None, )); let expr = logical2physical(&expr, &table_schema); let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); @@ -649,73 +627,45 @@ mod test { #[test] fn nested_data_structures_prevent_pushdown() { - let table_schema = get_basic_table_schema(); - - let file_schema = Schema::new(vec![Field::new( - "list_col", - DataType::Struct(Fields::empty()), - true, - )]); + let table_schema = Arc::new(get_lists_table_schema()); - let expr = col("list_col").is_not_null(); + let expr = col("utf8_list").is_not_null(); + let expr = logical2physical(&expr, &table_schema); + check_expression_can_evaluate_against_schema(&expr, &table_schema); - assert!(!can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } #[test] fn projected_columns_prevent_pushdown() { let table_schema = get_basic_table_schema(); - let file_schema = - Schema::new(vec![Field::new("existing_col", DataType::Int64, true)]); + let expr = + Arc::new(Column::new("nonexistent_column", 0)) as Arc; - let expr = col("nonexistent_column").is_null(); - - assert!(!can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } #[test] fn basic_expr_doesnt_prevent_pushdown() { let table_schema = get_basic_table_schema(); - let file_schema = - Schema::new(vec![Field::new("string_col", DataType::Utf8, true)]); - let expr = col("string_col").is_null(); + let expr = logical2physical(&expr, &table_schema); - assert!(can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } #[test] fn complex_expr_doesnt_prevent_pushdown() { let table_schema = get_basic_table_schema(); - let file_schema = Schema::new(vec![ - Field::new("string_col", DataType::Utf8, true), - 
Field::new("bigint_col", DataType::Int64, true), - ]); - let expr = col("string_col") .is_not_null() - .or(col("bigint_col").gt(Expr::Literal(ScalarValue::Int64(Some(5))))); + .or(col("bigint_col").gt(Expr::Literal(ScalarValue::Int64(Some(5)), None))); + let expr = logical2physical(&expr, &table_schema); - assert!(can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } fn get_basic_table_schema() -> Schema { @@ -730,4 +680,27 @@ mod test { parquet_to_arrow_schema(metadata.file_metadata().schema_descr(), None) .expect("parsing schema") } + + fn get_lists_table_schema() -> Schema { + let testdata = datafusion_common::test_util::parquet_test_data(); + let file = std::fs::File::open(format!("{testdata}/list_columns.parquet")) + .expect("opening file"); + + let reader = SerializedFileReader::new(file).expect("creating reader"); + + let metadata = reader.metadata(); + + parquet_to_arrow_schema(metadata.file_metadata().schema_descr(), None) + .expect("parsing schema") + } + + /// Sanity check that the given expression could be evaluated against the given schema without any errors. + /// This will fail if the expression references columns that are not in the schema or if the types of the columns are incompatible, etc. + fn check_expression_can_evaluate_against_schema( + expr: &Arc, + table_schema: &Arc, + ) -> bool { + let batch = RecordBatch::new_empty(Arc::clone(table_schema)); + expr.evaluate(&batch).is_ok() + } } diff --git a/datafusion/datasource-parquet/src/row_group_filter.rs b/datafusion/datasource-parquet/src/row_group_filter.rs index 13418cdeee22..f9fb9214429d 100644 --- a/datafusion/datasource-parquet/src/row_group_filter.rs +++ b/datafusion/datasource-parquet/src/row_group_filter.rs @@ -21,9 +21,10 @@ use std::sync::Arc; use super::{ParquetAccessPlan, ParquetFileMetrics}; use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::Schema; +use datafusion_common::pruning::PruningStatistics; use datafusion_common::{Column, Result, ScalarValue}; use datafusion_datasource::FileRange; -use datafusion_physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion_physical_optimizer::pruning::PruningPredicate; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; use parquet::arrow::parquet_column; use parquet::basic::Type; @@ -1241,12 +1242,16 @@ mod tests { .run( lit("1").eq(lit("1")).and( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( - "Hello_Not_Exists", - ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("Hello_Not_Exists2")), - )))), + .eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("Hello_Not_Exists"))), + None, + )) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from( + "Hello_Not_Exists2", + ))), + None, + ))), ), ) .await @@ -1265,7 +1270,7 @@ mod tests { let expr = col(r#""String""#).in_list( (1..25) - .map(|i| lit(format!("Hello_Not_Exists{}", i))) + .map(|i| lit(format!("Hello_Not_Exists{i}"))) .collect::>(), false, ); @@ -1326,15 +1331,18 @@ mod tests { // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` .run( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( - "Hello", - ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("the quick")), - )))) - 
.or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("are you")), - )))), + .eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("Hello"))), + None, + )) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("the quick"))), + None, + ))) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("are you"))), + None, + ))), ) .await } diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 9034afa89d8c..c3658280ecb4 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -21,12 +21,13 @@ use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; -use crate::opener::build_page_pruning_predicate; -use crate::opener::build_pruning_predicate; +use crate::opener::{build_page_pruning_predicate, build_pruning_predicate, build_pruning_predicates}; use crate::opener::ParquetOpener; -use crate::page_filter::PagePruningAccessPlanFilter; -use crate::DefaultParquetFileReaderFactory; +use crate::row_filter::can_expr_be_pushed_down_with_schemas; +use crate::{DefaultParquetFileReaderFactory, PagePruningAccessPlanFilter}; use crate::ParquetFileReaderFactory; +use datafusion_common::config::ConfigOptions; +use datafusion_datasource::as_file_source; use datafusion_datasource::file_stream::FileOpener; use datafusion_datasource::schema_adapter::{ DefaultSchemaAdapterFactory, SchemaAdapterFactory, @@ -37,14 +38,19 @@ use datafusion_common::config::TableParquetOptions; use datafusion_common::{DataFusionError, Statistics}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; +use datafusion_physical_expr::conjunction; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_optimizer::pruning::PruningPredicate; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; +use datafusion_physical_plan::filter_pushdown::FilterPushdownPropagation; +use datafusion_physical_plan::filter_pushdown::PredicateSupport; +use datafusion_physical_plan::filter_pushdown::PredicateSupports; +use datafusion_physical_plan::metrics::{Count, MetricBuilder}; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::DisplayFormatType; use itertools::Itertools; use object_store::ObjectStore; +use datafusion_physical_optimizer::pruning::PruningPredicate; /// Execution plan for reading one or more Parquet files. /// @@ -92,7 +98,7 @@ use object_store::ObjectStore; /// # let predicate = lit(true); /// let source = Arc::new( /// ParquetSource::default() -/// .with_predicate(Arc::clone(&file_schema), predicate) +/// .with_predicate(predicate) /// ); /// // Create a DataSourceExec for reading `file1.parquet` with a file size of 100MB /// let config = FileScanConfigBuilder::new(object_store_url, file_schema, source) @@ -259,6 +265,10 @@ pub struct ParquetSource { pub(crate) table_parquet_options: TableParquetOptions, /// Optional metrics pub(crate) metrics: ExecutionPlanMetricsSet, + /// The schema of the file. + /// In particular, this is the schema of the table without partition columns, + /// *not* the physical schema of the file. 
+ pub(crate) file_schema: Option, /// Optional predicate for row filtering during parquet scan pub(crate) predicate: Option>, /// Optional predicate for pruning row groups (derived from `predicate`) @@ -325,7 +335,6 @@ impl ParquetSource { conf } - /// Options passed to the parquet reader for this scan pub fn table_parquet_options(&self) -> &TableParquetOptions { &self.table_parquet_options @@ -353,29 +362,8 @@ impl ParquetSource { self } - /// return the optional schema adapter factory - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() - } - - /// Set optional schema adapter factory. - /// - /// [`SchemaAdapterFactory`] allows user to specify how fields from the - /// parquet file get mapped to that of the table schema. The default schema - /// adapter uses arrow's cast library to map the parquet fields to the table - /// schema. - pub fn with_schema_adapter_factory( - mut self, - schema_adapter_factory: Arc, - ) -> Self { - self.schema_adapter_factory = Some(schema_adapter_factory); - self - } - /// If true, the predicate will be used during the parquet scan. - /// Defaults to false - /// - /// [`Expr`]: datafusion_expr::Expr + /// Defaults to false. pub fn with_pushdown_filters(mut self, pushdown_filters: bool) -> Self { self.table_parquet_options.global.pushdown_filters = pushdown_filters; self @@ -436,10 +424,34 @@ impl ParquetSource { fn bloom_filter_on_read(&self) -> bool { self.table_parquet_options.global.bloom_filter_on_read } + + /// Applies schema adapter factory from the FileScanConfig if present. + /// + /// # Arguments + /// * `conf` - FileScanConfig that may contain a schema adapter factory + /// # Returns + /// The converted FileSource with schema adapter factory applied if provided + pub fn apply_schema_adapter( + self, + conf: &FileScanConfig, + ) -> datafusion_common::Result> { + let file_source: Arc = self.into(); + + // If the FileScanConfig.file_source() has a schema adapter factory, apply it + if let Some(factory) = conf.file_source().schema_adapter_factory() { + file_source.with_schema_adapter_factory( + Arc::::clone(&factory), + ) + } else { + Ok(file_source) + } + } } /// Parses datafusion.common.config.ParquetOptions.coerce_int96 String to a arrow_schema.datatype.TimeUnit -fn parse_coerce_int96_string(str_setting: &str) -> datafusion_common::Result { +pub(crate) fn parse_coerce_int96_string( + str_setting: &str, +) -> datafusion_common::Result { let str_setting_lower: &str = &str_setting.to_lowercase(); match str_setting_lower { @@ -454,6 +466,13 @@ fn parse_coerce_int96_string(str_setting: &str) -> datafusion_common::Result for Arc { + fn from(source: ParquetSource) -> Self { + as_file_source(source) + } +} + impl FileSource for ParquetSource { fn create_file_opener( &self, @@ -513,8 +532,11 @@ impl FileSource for ParquetSource { Arc::new(conf) } - fn with_schema(&self, _schema: SchemaRef) -> Arc { - Arc::new(Self { ..self.clone() }) + fn with_schema(&self, schema: SchemaRef) -> Arc { + Arc::new(Self { + file_schema: Some(schema), + ..self.clone() + }) } fn with_statistics(&self, statistics: Statistics) -> Arc { @@ -559,25 +581,41 @@ impl FileSource for ParquetSource { .predicate() .map(|p| format!(", predicate={p}")) .unwrap_or_default(); - let pruning_predicate_string = self - .pruning_predicate - .as_ref() - .map(|pre| { - let mut guarantees = pre + + write!(f, "{predicate_string}")?; + + // Try to build a the pruning predicates. 
+ // These are only generated here because it's useful to have *some* + // idea of what pushdown is happening when viewing plans. + // However it is important to note that these predicates are *not* + // necessarily the predicates that are actually evaluated: + // the actual predicates are built in reference to the physical schema of + // each file, which we do not have at this point and hence cannot use. + // Instead we use the logical schema of the file (the table schema without partition columns). + if let (Some(file_schema), Some(predicate)) = + (&self.file_schema, &self.predicate) + { + let predicate_creation_errors = Count::new(); + if let (Some(pruning_predicate), _) = build_pruning_predicates( + Some(predicate), + file_schema, + &predicate_creation_errors, + ) { + let mut guarantees = pruning_predicate .literal_guarantees() .iter() - .map(|item| format!("{}", item)) + .map(|item| format!("{item}")) .collect_vec(); guarantees.sort(); - format!( + writeln!( + f, ", pruning_predicate={}, required_guarantees=[{}]", - pre.predicate_expr(), + pruning_predicate.predicate_expr(), guarantees.join(", ") - ) - }) - .unwrap_or_default(); - - write!(f, "{}{}", predicate_string, pruning_predicate_string) + )?; + } + }; + Ok(()) } DisplayFormatType::TreeRender => { if let Some(predicate) = self.predicate() { @@ -587,4 +625,83 @@ impl FileSource for ParquetSource { } } } + + fn try_pushdown_filters( + &self, + filters: Vec>, + config: &ConfigOptions, + ) -> datafusion_common::Result>> { + let Some(file_schema) = self.file_schema.clone() else { + return Ok(FilterPushdownPropagation::unsupported(filters)); + }; + // Determine if based on configs we should push filters down. + // If either the table / scan itself or the config has pushdown enabled, + // we will push down the filters. + // If both are disabled, we will not push down the filters. + // By default they are both disabled. + // Regardless of pushdown, we will update the predicate to include the filters + // because even if scan pushdown is disabled we can still use the filters for stats pruning. + let config_pushdown_enabled = config.execution.parquet.pushdown_filters; + let table_pushdown_enabled = self.pushdown_filters(); + let pushdown_filters = table_pushdown_enabled || config_pushdown_enabled; + + let mut source = self.clone(); + let mut allowed_filters = vec![]; + let mut remaining_filters = vec![]; + for filter in &filters { + if can_expr_be_pushed_down_with_schemas(filter, &file_schema) { + // This filter can be pushed down + allowed_filters.push(Arc::clone(filter)); + } else { + // This filter cannot be pushed down + remaining_filters.push(Arc::clone(filter)); + } + } + if allowed_filters.is_empty() { + // No filters can be pushed down, so we can just return the remaining filters + // and avoid replacing the source in the physical plan. 
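Illustrative aside: `conjunction` (from `datafusion_physical_expr`, imported earlier in this hunk) folds an iterator of physical predicates into a single AND expression, which is how the pre-existing predicate and the newly accepted filters are merged just below. `expr_a` and `expr_b` are assumed `Arc<dyn PhysicalExpr>` values such as `a@0 = 1` and `b@1 > 2`:

// yields a predicate equivalent to `a@0 = 1 AND b@1 > 2`
let merged = conjunction(vec![expr_a, expr_b]);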
+ return Ok(FilterPushdownPropagation::unsupported(filters)); + } + let predicate = match source.predicate { + Some(predicate) => conjunction( + std::iter::once(predicate).chain(allowed_filters.iter().cloned()), + ), + None => conjunction(allowed_filters.iter().cloned()), + }; + source.predicate = Some(predicate); + source = source.with_pushdown_filters(pushdown_filters); + let source = Arc::new(source); + let filters = PredicateSupports::new( + allowed_filters + .into_iter() + .map(|f| { + if pushdown_filters { + PredicateSupport::Supported(f) + } else { + PredicateSupport::Unsupported(f) + } + }) + .chain( + remaining_filters + .into_iter() + .map(PredicateSupport::Unsupported), + ) + .collect(), + ); + Ok(FilterPushdownPropagation::with_filters(filters).with_updated_node(source)) + } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> datafusion_common::Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } diff --git a/datafusion/datasource-parquet/src/writer.rs b/datafusion/datasource-parquet/src/writer.rs index 64eb37c81f5d..d37b6e26a753 100644 --- a/datafusion/datasource-parquet/src/writer.rs +++ b/datafusion/datasource-parquet/src/writer.rs @@ -46,7 +46,15 @@ pub async fn plan_to_parquet( let propclone = writer_properties.clone(); let storeref = Arc::clone(&store); - let buf_writer = BufWriter::new(storeref, file.clone()); + let buf_writer = BufWriter::with_capacity( + storeref, + file.clone(), + task_ctx + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + ); let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { let mut writer = diff --git a/datafusion/datasource-parquet/tests/apply_schema_adapter_tests.rs b/datafusion/datasource-parquet/tests/apply_schema_adapter_tests.rs new file mode 100644 index 000000000000..955cd224e6a4 --- /dev/null +++ b/datafusion/datasource-parquet/tests/apply_schema_adapter_tests.rs @@ -0,0 +1,206 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod parquet_adapter_tests { + use arrow::{ + datatypes::{DataType, Field, Schema, SchemaRef}, + record_batch::RecordBatch, + }; + use datafusion_common::{ColumnStatistics, DataFusionError, Result}; + use datafusion_datasource::{ + file::FileSource, + file_scan_config::FileScanConfigBuilder, + schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper}, + }; + use datafusion_datasource_parquet::source::ParquetSource; + use datafusion_execution::object_store::ObjectStoreUrl; + use std::{fmt::Debug, sync::Arc}; + + /// A test schema adapter factory that adds prefix to column names + #[derive(Debug)] + struct PrefixAdapterFactory { + prefix: String, + } + + impl SchemaAdapterFactory for PrefixAdapterFactory { + fn create( + &self, + projected_table_schema: SchemaRef, + _table_schema: SchemaRef, + ) -> Box { + Box::new(PrefixAdapter { + input_schema: projected_table_schema, + prefix: self.prefix.clone(), + }) + } + } + + /// A test schema adapter that adds prefix to column names + #[derive(Debug)] + struct PrefixAdapter { + input_schema: SchemaRef, + prefix: String, + } + + impl SchemaAdapter for PrefixAdapter { + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { + let field = self.input_schema.field(index); + file_schema.fields.find(field.name()).map(|(i, _)| i) + } + + fn map_schema( + &self, + file_schema: &Schema, + ) -> Result<(Arc, Vec)> { + let mut projection = Vec::with_capacity(file_schema.fields().len()); + for (file_idx, file_field) in file_schema.fields().iter().enumerate() { + if self.input_schema.fields().find(file_field.name()).is_some() { + projection.push(file_idx); + } + } + + // Create a schema mapper that adds a prefix to column names + #[derive(Debug)] + struct PrefixSchemaMapping { + // Keep only the prefix field which is actually used in the implementation + prefix: String, + } + + impl SchemaMapper for PrefixSchemaMapping { + fn map_batch(&self, batch: RecordBatch) -> Result { + // Create a new schema with prefixed field names + let prefixed_fields: Vec = batch + .schema() + .fields() + .iter() + .map(|field| { + Field::new( + format!("{}{}", self.prefix, field.name()), + field.data_type().clone(), + field.is_nullable(), + ) + }) + .collect(); + let prefixed_schema = Arc::new(Schema::new(prefixed_fields)); + + // Create a new batch with the prefixed schema but the same data + let options = arrow::record_batch::RecordBatchOptions::default(); + RecordBatch::try_new_with_options( + prefixed_schema, + batch.columns().to_vec(), + &options, + ) + .map_err(|e| DataFusionError::ArrowError(e, None)) + } + + fn map_column_statistics( + &self, + stats: &[ColumnStatistics], + ) -> Result> { + // For testing, just return the input statistics + Ok(stats.to_vec()) + } + } + + Ok(( + Arc::new(PrefixSchemaMapping { + prefix: self.prefix.clone(), + }), + projection, + )) + } + } + + #[test] + fn test_apply_schema_adapter_with_factory() { + // Create a schema + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Create a parquet source + let source = ParquetSource::default(); + + // Create a file scan config with source that has a schema adapter factory + let factory = Arc::new(PrefixAdapterFactory { + prefix: "test_".to_string(), + }); + + let file_source = source.clone().with_schema_adapter_factory(factory).unwrap(); + + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::local_filesystem(), + schema.clone(), + file_source, + ) + .build(); + + // Apply 
schema adapter to a new source + let result_source = source.apply_schema_adapter(&config).unwrap(); + + // Verify the adapter was applied + assert!(result_source.schema_adapter_factory().is_some()); + + // Create adapter and test it produces expected schema + let adapter_factory = result_source.schema_adapter_factory().unwrap(); + let adapter = adapter_factory.create(schema.clone(), schema.clone()); + + // Create a dummy batch to test the schema mapping + let dummy_batch = RecordBatch::new_empty(schema.clone()); + + // Get the file schema (which is the same as the table schema in this test) + let (mapper, _) = adapter.map_schema(&schema).unwrap(); + + // Apply the mapping to get the output schema + let mapped_batch = mapper.map_batch(dummy_batch).unwrap(); + let output_schema = mapped_batch.schema(); + + // Check the column names have the prefix + assert_eq!(output_schema.field(0).name(), "test_id"); + assert_eq!(output_schema.field(1).name(), "test_name"); + } + + #[test] + fn test_apply_schema_adapter_without_factory() { + // Create a schema + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Create a parquet source + let source = ParquetSource::default(); + + // Convert to Arc + let file_source: Arc = Arc::new(source.clone()); + + // Create a file scan config without a schema adapter factory + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::local_filesystem(), + schema.clone(), + file_source, + ) + .build(); + + // Apply schema adapter function - should pass through the source unchanged + let result_source = source.apply_schema_adapter(&config).unwrap(); + + // Verify no adapter was applied + assert!(result_source.schema_adapter_factory().is_none()); + } +} diff --git a/datafusion/datasource/Cargo.toml b/datafusion/datasource/Cargo.toml index 1088efc268c9..c936e4c1004c 100644 --- a/datafusion/datasource/Cargo.toml +++ b/datafusion/datasource/Cargo.toml @@ -66,7 +66,7 @@ parquet = { workspace = true, optional = true } rand = { workspace = true } tempfile = { workspace = true, optional = true } tokio = { workspace = true } -tokio-util = { version = "0.7.14", features = ["io"], optional = true } +tokio-util = { version = "0.7.15", features = ["io"], optional = true } url = { workspace = true } xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } diff --git a/datafusion/datasource/benches/split_groups_by_statistics.rs b/datafusion/datasource/benches/split_groups_by_statistics.rs index f7c5e1b44ae0..3876b0b1217b 100644 --- a/datafusion/datasource/benches/split_groups_by_statistics.rs +++ b/datafusion/datasource/benches/split_groups_by_statistics.rs @@ -55,7 +55,7 @@ pub fn compare_split_groups_by_statistics_algorithms(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new( "original", - format!("files={},overlap={:.1}", num_files, overlap), + format!("files={num_files},overlap={overlap:.1}"), ), &( file_groups.clone(), @@ -77,8 +77,8 @@ pub fn compare_split_groups_by_statistics_algorithms(c: &mut Criterion) { for &tp in &target_partitions { group.bench_with_input( BenchmarkId::new( - format!("v2_partitions={}", tp), - format!("files={},overlap={:.1}", num_files, overlap), + format!("v2_partitions={tp}"), + format!("files={num_files},overlap={overlap:.1}"), ), &( file_groups.clone(), diff --git a/datafusion/datasource/src/file.rs b/datafusion/datasource/src/file.rs index 0066f39801a1..c5f21ebf1a0f 
100644 --- a/datafusion/datasource/src/file.rs +++ b/datafusion/datasource/src/file.rs @@ -25,17 +25,32 @@ use std::sync::Arc; use crate::file_groups::FileGroupPartitioner; use crate::file_scan_config::FileScanConfig; use crate::file_stream::FileOpener; +use crate::schema_adapter::SchemaAdapterFactory; use arrow::datatypes::SchemaRef; -use datafusion_common::Statistics; -use datafusion_physical_expr::LexOrdering; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{not_impl_err, Result, Statistics}; +use datafusion_physical_expr::{LexOrdering, PhysicalExpr}; +use datafusion_physical_plan::filter_pushdown::FilterPushdownPropagation; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::DisplayFormatType; use object_store::ObjectStore; -/// Common file format behaviors needs to implement. +/// Helper function to convert any type implementing FileSource to Arc<dyn FileSource> +pub fn as_file_source(source: T) -> Arc { + Arc::new(source) +} + +/// file format specific behaviors for elements in [`DataSource`] /// -/// See implementation examples such as `ParquetSource`, `CsvSource` +/// See more details on specific implementations: +/// * [`ArrowSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.ArrowSource.html) +/// * [`AvroSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.AvroSource.html) +/// * [`CsvSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.CsvSource.html) +/// * [`JsonSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.JsonSource.html) +/// * [`ParquetSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.ParquetSource.html) +/// +/// [`DataSource`]: crate::source::DataSource pub trait FileSource: Send + Sync { /// Creates a `dyn FileOpener` based on given parameters fn create_file_opener( @@ -57,7 +72,7 @@ pub trait FileSource: Send + Sync { /// Return execution plan metrics fn metrics(&self) -> &ExecutionPlanMetricsSet; /// Return projected statistics - fn statistics(&self) -> datafusion_common::Result; + fn statistics(&self) -> Result; /// String representation of file source such as "csv", "json", "parquet" fn file_type(&self) -> &str; /// Format FileType specific information @@ -65,17 +80,19 @@ pub trait FileSource: Send + Sync { Ok(()) } - /// If supported by the [`FileSource`], redistribute files across partitions according to their size. - /// Allows custom file formats to implement their own repartitioning logic. + /// If supported by the [`FileSource`], redistribute files across partitions + /// according to their size. Allows custom file formats to implement their + /// own repartitioning logic. /// - /// Provides a default repartitioning behavior, see comments on [`FileGroupPartitioner`] for more detail. + /// The default implementation uses [`FileGroupPartitioner`]. See that + /// struct for more details. fn repartitioned( &self, target_partitions: usize, repartition_file_min_size: usize, output_ordering: Option, config: &FileScanConfig, - ) -> datafusion_common::Result> { + ) -> Result> { if config.file_compression_type.is_compressed() || config.new_lines_in_values { return Ok(None); } @@ -93,4 +110,42 @@ pub trait FileSource: Send + Sync { } Ok(None) } + + /// Try to push down filters into this FileSource. + /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. 
+ /// + /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result + fn try_pushdown_filters( + &self, + filters: Vec>, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::unsupported(filters)) + } + + /// Set optional schema adapter factory. + /// + /// [`SchemaAdapterFactory`] allows user to specify how fields from the + /// file get mapped to that of the table schema. If you implement this + /// method, you should also implement [`schema_adapter_factory`]. + /// + /// The default implementation returns a not implemented error. + /// + /// [`schema_adapter_factory`]: Self::schema_adapter_factory + fn with_schema_adapter_factory( + &self, + _factory: Arc, + ) -> Result> { + not_impl_err!( + "FileSource {} does not support schema adapter factory", + self.file_type() + ) + } + + /// Returns the current schema adapter factory if set + /// + /// Default implementation returns `None`. + fn schema_adapter_factory(&self) -> Option> { + None + } } diff --git a/datafusion/datasource/src/file_format.rs b/datafusion/datasource/src/file_format.rs index 0e0b7b12e16a..c5b0846c5992 100644 --- a/datafusion/datasource/src/file_format.rs +++ b/datafusion/datasource/src/file_format.rs @@ -28,16 +28,16 @@ use crate::file_compression_type::FileCompressionType; use crate::file_scan_config::FileScanConfig; use crate::file_sink_config::FileSinkConfig; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{internal_err, not_impl_err, GetExt, Result, Statistics}; -use datafusion_expr::Expr; -use datafusion_physical_expr::{LexRequirement, PhysicalExpr}; +use datafusion_physical_expr::LexRequirement; use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; use async_trait::async_trait; use object_store::{ObjectMeta, ObjectStore}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// Default max records to scan to infer the schema pub const DEFAULT_SCHEMA_INFER_MAX_RECORD: usize = 1000; @@ -109,37 +109,10 @@ pub trait FileFormat: Send + Sync + fmt::Debug { not_impl_err!("Writer not implemented for this format") } - /// Check if the specified file format has support for pushing down the provided filters within - /// the given schemas. Added initially to support the Parquet file format's ability to do this. - fn supports_filters_pushdown( - &self, - _file_schema: &Schema, - _table_schema: &Schema, - _filters: &[&Expr], - ) -> Result { - Ok(FilePushdownSupport::NoSupport) - } - /// Return the related FileSource such as `CsvSource`, `JsonSource`, etc. fn file_source(&self) -> Arc; } -/// An enum to distinguish between different states when determining if certain filters can be -/// pushed down to file scanning -#[derive(Debug, PartialEq)] -pub enum FilePushdownSupport { - /// The file format/system being asked does not support any sort of pushdown. This should be - /// used even if the file format theoretically supports some sort of pushdown, but it's not - /// enabled or implemented yet. 
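Caller-side sketch of the new trait method introduced here: because the default `FileSource::with_schema_adapter_factory` returns a not-implemented error, attaching a factory is fallible and the `Result` must be propagated or unwrapped (`MyAdapterFactory` is a hypothetical `SchemaAdapterFactory` implementation, as in the test added earlier in this diff):

let adapted: Arc<dyn FileSource> = ParquetSource::default()
    .with_schema_adapter_factory(Arc::new(MyAdapterFactory))?;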
- NoSupport, - /// The file format/system being asked *does* support pushdown, but it can't make it work for - /// the provided filter/expression - NotSupportedForFilter, - /// The file format/system being asked *does* support pushdown and *can* make it work for the - /// provided filter/expression - Supported, -} - /// Factory for creating [`FileFormat`] instances based on session and command level options /// /// Users can provide their own `FileFormatFactory` to support arbitrary file formats diff --git a/datafusion/datasource/src/file_groups.rs b/datafusion/datasource/src/file_groups.rs index 15c86427ed00..8bfadbef775c 100644 --- a/datafusion/datasource/src/file_groups.rs +++ b/datafusion/datasource/src/file_groups.rs @@ -420,9 +420,19 @@ impl FileGroup { self.files.push(file); } - /// Get the statistics for this group - pub fn statistics(&self) -> Option<&Statistics> { - self.statistics.as_deref() + /// Get the specific file statistics for the given index + /// If the index is None, return the `FileGroup` statistics + pub fn file_statistics(&self, index: Option) -> Option<&Statistics> { + if let Some(index) = index { + self.files.get(index).and_then(|f| f.statistics.as_deref()) + } else { + self.statistics.as_deref() + } + } + + /// Get the mutable reference to the statistics for this group + pub fn statistics_mut(&mut self) -> Option<&mut Statistics> { + self.statistics.as_mut().map(Arc::make_mut) } /// Partition the list of files into `n` groups @@ -953,8 +963,8 @@ mod test { (Some(_), None) => panic!("Expected Some, got None"), (None, Some(_)) => panic!("Expected None, got Some"), (Some(expected), Some(actual)) => { - let expected_string = format!("{:#?}", expected); - let actual_string = format!("{:#?}", actual); + let expected_string = format!("{expected:#?}"); + let actual_string = format!("{actual:#?}"); assert_eq!(expected_string, actual_string); } } diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index eecdbdb59a62..ffcb6301280e 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -20,9 +20,21 @@ use std::{ any::Any, borrow::Cow, collections::HashMap, fmt::Debug, fmt::Formatter, - fmt::Result as FmtResult, marker::PhantomData, mem::size_of, sync::Arc, vec, + fmt::Result as FmtResult, marker::PhantomData, sync::Arc, }; +use crate::file_groups::FileGroup; +#[allow(unused_imports)] +use crate::schema_adapter::SchemaAdapterFactory; +use crate::{ + display::FileGroupsDisplay, + file::FileSource, + file_compression_type::FileCompressionType, + file_stream::FileStream, + source::{DataSource, DataSourceExec}, + statistics::MinMaxStatistics, + PartitionedFile, +}; use arrow::{ array::{ ArrayData, ArrayRef, BufferBuilder, DictionaryArray, RecordBatch, @@ -31,15 +43,19 @@ use arrow::{ buffer::Buffer, datatypes::{ArrowNativeType, DataType, Field, Schema, SchemaRef, UInt16Type}, }; -use datafusion_common::{exec_err, ColumnStatistics, Constraints, Result, Statistics}; +use datafusion_common::{ + config::ConfigOptions, exec_err, ColumnStatistics, Constraints, Result, Statistics, +}; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_execution::{ object_store::ObjectStoreUrl, SendableRecordBatchStream, TaskContext, }; +use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::{ expressions::Column, EquivalenceProperties, LexOrdering, Partitioning, PhysicalSortExpr, }; +use 
datafusion_physical_plan::filter_pushdown::FilterPushdownPropagation; use datafusion_physical_plan::{ display::{display_orderings, ProjectSchemaDisplay}, metrics::ExecutionPlanMetricsSet, @@ -48,18 +64,7 @@ use datafusion_physical_plan::{ }; use log::{debug, warn}; use object_store::ObjectMeta; - -use crate::file_groups::FileGroup; -use crate::{ - display::FileGroupsDisplay, - file::FileSource, - file_compression_type::FileCompressionType, - file_stream::FileStream, - metadata::MetadataColumn, - source::{DataSource, DataSourceExec}, - statistics::MinMaxStatistics, - PartitionedFile, -}; +use crate::metadata::MetadataColumn; /// The base configurations for a [`DataSourceExec`], the a physical plan for /// any given file format. @@ -73,6 +78,7 @@ use crate::{ /// # use arrow::datatypes::{Field, Fields, DataType, Schema, SchemaRef}; /// # use object_store::ObjectStore; /// # use datafusion_common::Statistics; +/// # use datafusion_common::Result; /// # use datafusion_datasource::file::FileSource; /// # use datafusion_datasource::file_groups::FileGroup; /// # use datafusion_datasource::PartitionedFile; @@ -82,6 +88,7 @@ use crate::{ /// # use datafusion_execution::object_store::ObjectStoreUrl; /// # use datafusion_physical_plan::ExecutionPlan; /// # use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +/// # use datafusion_datasource::schema_adapter::SchemaAdapterFactory; /// # let file_schema = Arc::new(Schema::new(vec![ /// # Field::new("c1", DataType::Int32, false), /// # Field::new("c2", DataType::Int32, false), @@ -89,22 +96,26 @@ use crate::{ /// # Field::new("c4", DataType::Int32, false), /// # ])); /// # // Note: crate mock ParquetSource, as ParquetSource is not in the datasource crate +/// #[derive(Clone)] /// # struct ParquetSource { -/// # projected_statistics: Option +/// # projected_statistics: Option, +/// # schema_adapter_factory: Option> /// # }; /// # impl FileSource for ParquetSource { /// # fn create_file_opener(&self, _: Arc, _: &FileScanConfig, _: usize) -> Arc { unimplemented!() } /// # fn as_any(&self) -> &dyn Any { self } /// # fn with_batch_size(&self, _: usize) -> Arc { unimplemented!() } -/// # fn with_schema(&self, _: SchemaRef) -> Arc { unimplemented!() } +/// # fn with_schema(&self, _: SchemaRef) -> Arc { Arc::new(self.clone()) as Arc } /// # fn with_projection(&self, _: &FileScanConfig) -> Arc { unimplemented!() } -/// # fn with_statistics(&self, statistics: Statistics) -> Arc { Arc::new(Self {projected_statistics: Some(statistics)} ) } +/// # fn with_statistics(&self, statistics: Statistics) -> Arc { Arc::new(Self {projected_statistics: Some(statistics), schema_adapter_factory: self.schema_adapter_factory.clone()} ) } /// # fn metrics(&self) -> &ExecutionPlanMetricsSet { unimplemented!() } -/// # fn statistics(&self) -> datafusion_common::Result { Ok(self.projected_statistics.clone().expect("projected_statistics should be set")) } +/// # fn statistics(&self) -> Result { Ok(self.projected_statistics.clone().expect("projected_statistics should be set")) } /// # fn file_type(&self) -> &str { "parquet" } +/// # fn with_schema_adapter_factory(&self, factory: Arc) -> Result> { Ok(Arc::new(Self {projected_statistics: self.projected_statistics.clone(), schema_adapter_factory: Some(factory)} )) } +/// # fn schema_adapter_factory(&self) -> Option> { self.schema_adapter_factory.clone() } /// # } /// # impl ParquetSource { -/// # fn new() -> Self { Self {projected_statistics: None} } +/// # fn new() -> Self { Self {projected_statistics: None, 
schema_adapter_factory: None} } /// # } /// // create FileScan config for reading parquet files from file:// /// let object_store_url = ObjectStoreUrl::local_filesystem(); @@ -232,9 +243,15 @@ pub struct FileScanConfig { pub struct FileScanConfigBuilder { object_store_url: ObjectStoreUrl, /// Table schema before any projections or partition columns are applied. - /// This schema is used to read the files, but is **not** necessarily the schema of the physical files. - /// Rather this is the schema that the physical file schema will be mapped onto, and the schema that the + /// + /// This schema is used to read the files, but is **not** necessarily the + /// schema of the physical files. Rather this is the schema that the + /// physical file schema will be mapped onto, and the schema that the /// [`DataSourceExec`] will return. + /// + /// This is usually the same as the table schema as specified by the `TableProvider` minus any partition columns. + /// + /// This probably would be better named `table_schema` file_schema: SchemaRef, file_source: Arc, @@ -415,7 +432,9 @@ impl FileScanConfigBuilder { let statistics = statistics.unwrap_or_else(|| Statistics::new_unknown(&file_schema)); - let file_source = file_source.with_statistics(statistics.clone()); + let file_source = file_source + .with_statistics(statistics.clone()) + .with_schema(Arc::clone(&file_schema)); let file_compression_type = file_compression_type.unwrap_or(FileCompressionType::UNCOMPRESSED); let new_lines_in_values = new_lines_in_values.unwrap_or(false); @@ -473,7 +492,6 @@ impl DataSource for FileScanConfig { let source = self .file_source .with_batch_size(batch_size) - .with_schema(Arc::clone(&self.file_schema)) .with_projection(self); let opener = source.create_file_opener(object_store, self, partition); @@ -489,7 +507,8 @@ impl DataSource for FileScanConfig { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - let (schema, _, _, orderings) = self.project(); + let schema = self.projected_schema(); + let orderings = get_projected_output_ordering(self, &schema); write!(f, "file_groups=")?; FileGroupsDisplay(&self.file_groups).fmt_as(t, f)?; @@ -591,7 +610,7 @@ impl DataSource for FileScanConfig { &file_scan .projection .clone() - .unwrap_or((0..self.file_schema.fields().len()).collect()), + .unwrap_or_else(|| (0..self.file_schema.fields().len()).collect()), ); DataSourceExec::from_data_source( FileScanConfigBuilder::from(file_scan) @@ -602,6 +621,32 @@ impl DataSource for FileScanConfig { ) as _ })) } + + fn try_pushdown_filters( + &self, + filters: Vec>, + config: &ConfigOptions, + ) -> Result>> { + let result = self.file_source.try_pushdown_filters(filters, config)?; + match result.updated_node { + Some(new_file_source) => { + let file_scan_config = FileScanConfigBuilder::from(self.clone()) + .with_source(new_file_source) + .build(); + Ok(FilterPushdownPropagation { + filters: result.filters, + updated_node: Some(Arc::new(file_scan_config) as _), + }) + } + None => { + // If the file source does not support filter pushdown, return the original config + Ok(FilterPushdownPropagation { + filters: result.filters, + updated_node: None, + }) + } + } + } } impl FileScanConfig { @@ -622,7 +667,9 @@ impl FileScanConfig { file_source: Arc, ) -> Self { let statistics = Statistics::new_unknown(&file_schema); - let file_source = file_source.with_statistics(statistics.clone()); + let file_source = file_source + 
.with_statistics(statistics.clone()) + .with_schema(Arc::clone(&file_schema)); Self { object_store_url, file_schema, @@ -666,8 +713,8 @@ impl FileScanConfig { match &self.projection { Some(proj) => proj.clone(), None => (0..self.file_schema.fields().len() - + self.table_partition_cols.len() - + self.metadata_cols.len()) + + self.metadata_cols.len() + + self.table_partition_cols.len()) .collect(), } } @@ -786,6 +833,7 @@ impl FileScanConfig { } /// Set the metadata columns of the files + #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] pub fn with_metadata_cols(mut self, metadata_cols: Vec) -> Self { self.metadata_cols = metadata_cols; self @@ -954,8 +1002,8 @@ impl FileScanConfig { .filter(|(_, group)| { group.is_empty() || min - > statistics - .max(*group.last().expect("groups should not be empty")) + > statistics + .max(*group.last().expect("groups should not be empty")) }) .min_by_key(|(_, group)| group.len()) { @@ -1016,9 +1064,9 @@ impl FileScanConfig { None, flattened_files.iter().copied(), ) - .map_err(|e| { - e.context("construct min/max statistics for split_groups_by_statistics") - })?; + .map_err(|e| { + e.context("construct min/max statistics for split_groups_by_statistics") + })?; let indices_sorted_by_min = statistics.min_values_sorted(); let mut file_groups_indices: Vec> = vec![]; @@ -1087,7 +1135,8 @@ impl Debug for FileScanConfig { impl DisplayAs for FileScanConfig { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { - let (schema, _, _, orderings) = self.project(); + let schema = self.projected_schema(); + let orderings = get_projected_output_ordering(self, &schema); write!(f, "file_groups=")?; FileGroupsDisplay(&self.file_groups).fmt_as(t, f)?; @@ -1110,13 +1159,13 @@ impl DisplayAs for FileScanConfig { } } -/// A helper that projects extended (i.e. partition/metadata) columns into the file record batches. +/// A helper that projects partition columns into the file record batches. /// /// One interesting trick is the usage of a cache for the key buffers of the partition column /// dictionaries. Indeed, the partition columns are constant, so the dictionaries that represent them /// have all their keys equal to 0. This enables us to re-use the same "all-zero" buffer across batches, /// which makes the space consumption of the partition columns O(batch_size) instead of O(record_count). -pub struct ExtendedColumnProjector { +pub struct PartitionColumnProjector { /// An Arrow buffer initialized to zeros that represents the key array of all partition /// columns (partition columns are materialized by dictionary arrays with only one /// value in the dictionary, thus all the keys are equal to zero). 
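The doc comment above describes the projector's space-saving trick: each partition column is materialized as a dictionary array with a single dictionary entry, so every key is zero and the same key buffer can be reused from batch to batch. A minimal standalone sketch of that representation, assuming only the `arrow` crate; the `constant_partition_column` helper is illustrative and not part of the projector's API:

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, DictionaryArray, StringArray, UInt16Array};

/// Materialize a constant partition value for `num_rows` rows as a dictionary
/// array: one entry in the dictionary, and a key buffer that is all zeros.
fn constant_partition_column(value: &str, num_rows: usize) -> ArrayRef {
    let keys = UInt16Array::from(vec![0u16; num_rows]); // every row points at entry 0
    let values: ArrayRef = Arc::new(StringArray::from(vec![value])); // single dictionary value
    Arc::new(DictionaryArray::try_new(keys, values).expect("zero keys are always valid"))
}

fn main() {
    let year = constant_partition_column("2023", 4);
    assert_eq!(year.len(), 4);
    println!("{year:?}");
}
```

Because every row's key is zero, the per-batch cost is just the key buffer, which is what keeps the memory use of partition columns proportional to the batch size rather than the total record count.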
@@ -1133,9 +1182,9 @@ pub struct ExtendedColumnProjector { metadata_map: HashMap, } -impl ExtendedColumnProjector { - // Create a projector to insert the partitioning/metadata columns into batches read from files - // - `projected_schema`: the target schema with file, partitioning and metadata columns +impl PartitionColumnProjector { + // Create a projector to insert the partitioning columns into batches read from files + // - `projected_schema`: the target schema with both file and partitioning columns // - `table_partition_cols`: all the partitioning column names // - `metadata_cols`: all the metadata column names pub fn new( @@ -1178,7 +1227,7 @@ impl ExtendedColumnProjector { } } - // Transform the batch read from the file by inserting both partitioning and metadata columns + // Transform the batch read from the file by inserting the partitioning columns // to the right positions as deduced from `projected_schema` // - `file_batch`: batch read from the file, with internal projection applied // - `partition_values`: the list of partition values, one for each partition column @@ -1189,12 +1238,9 @@ impl ExtendedColumnProjector { partition_values: &[ScalarValue], metadata: &ObjectMeta, ) -> Result { - // Calculate expected number of columns from the file (excluding partition and metadata columns) - let expected_cols = self.projected_schema.fields().len() - - self.projected_partition_indexes.len() - - self.projected_metadata_indexes.len(); + let expected_cols = + self.projected_schema.fields().len() - self.projected_partition_indexes.len() - self.projected_metadata_indexes.len(); - // Verify the file batch has the expected number of columns if file_batch.columns().len() != expected_cols { return exec_err!( "Unexpected batch schema from file, expected {} cols but got {}", @@ -1203,21 +1249,18 @@ impl ExtendedColumnProjector { ); } - // Start with the columns from the file batch let mut cols = file_batch.columns().to_vec(); - - // Insert partition columns for &(pidx, sidx) in &self.projected_partition_indexes { - // Get the partition value from the provided values - let p_value = partition_values.get(pidx).ok_or_else(|| { - DataFusionError::Execution( - "Invalid partitioning found on disk".to_string(), - ) - })?; + let p_value = + partition_values + .get(pidx) + .ok_or(DataFusionError::Execution( + "Invalid partitioning found on disk".to_string(), + ))?; let mut partition_value = Cow::Borrowed(p_value); - // Check if user forgot to dict-encode the partition value and apply auto-fix if needed + // check if user forgot to dict-encode the partition value let field = self.projected_schema.field(sidx); let expected_data_type = field.data_type(); let actual_data_type = partition_value.data_type(); @@ -1231,7 +1274,6 @@ impl ExtendedColumnProjector { } } - // Create array and insert at the correct schema position cols.insert( sidx, create_output_array( @@ -1239,7 +1281,7 @@ impl ExtendedColumnProjector { partition_value.as_ref(), file_batch.num_rows(), )?, - ); + ) } // Insert metadata columns @@ -1260,13 +1302,12 @@ impl ExtendedColumnProjector { cols.insert(sidx, scalar_value.to_array_of_size(file_batch.num_rows())?); } - // Create a new record batch with all columns in the correct order RecordBatch::try_new_with_options( Arc::clone(&self.projected_schema), cols, &RecordBatchOptions::new().with_row_count(Some(file_batch.num_rows())), ) - .map_err(Into::into) + .map_err(Into::into) } } @@ -1565,15 +1606,11 @@ mod tests { generate_test_files, test_util::MockSource, tests::aggr_test_schema, 
verify_sort_integrity, }; - use object_store::{path::Path, ObjectMeta}; use super::*; use arrow::{ - array::{ - Int32Array, RecordBatch, StringArray, TimestampMicrosecondArray, UInt64Array, - }, + array::{Int32Array, RecordBatch}, compute::SortOptions, - datatypes::TimeUnit, }; use datafusion_common::stats::Precision; @@ -1581,6 +1618,9 @@ mod tests { use datafusion_expr::{execution_props::ExecutionProps, SortExpr}; use datafusion_physical_expr::create_physical_expr; use std::collections::HashMap; + use arrow::array::{StringArray, TimestampMicrosecondArray, UInt64Array}; + use arrow::datatypes::TimeUnit; + use object_store::path::Path; fn create_physical_sort_expr( e: &SortExpr, @@ -1606,6 +1646,7 @@ mod tests { schema.fields().iter().map(|f| f.name().clone()).collect() } + fn test_object_meta() -> ObjectMeta { ObjectMeta { location: Path::from("test"), @@ -1669,7 +1710,7 @@ mod tests { ); // verify the proj_schema includes the last column and exactly the same the field it is defined - let (proj_schema, _, _, _) = conf.project(); + let proj_schema = conf.projected_schema(); assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); assert_eq!( *proj_schema.field(file_schema.fields().len()), @@ -1775,15 +1816,15 @@ mod tests { assert_eq!(source_statistics, statistics); assert_eq!(source_statistics.column_statistics.len(), 3); - let (proj_schema, ..) = conf.project(); + let proj_schema = conf.projected_schema(); // created a projector for that projected schema - let mut proj = ExtendedColumnProjector::new( + let mut proj = PartitionColumnProjector::new( proj_schema, &partition_cols .iter() .map(|x| x.0.clone()) .collect::>(), - &[], + &[] ); // project first batch @@ -1796,7 +1837,7 @@ mod tests { wrap_partition_value_in_dict(ScalarValue::from("10")), wrap_partition_value_in_dict(ScalarValue::from("26")), ], - &test_object_meta(), + &test_object_meta() ) .expect("Projection of partition columns into record batch failed"); let expected = [ @@ -1825,7 +1866,7 @@ mod tests { wrap_partition_value_in_dict(ScalarValue::from("10")), wrap_partition_value_in_dict(ScalarValue::from("27")), ], - &test_object_meta(), + &test_object_meta() ) .expect("Projection of partition columns into record batch failed"); let expected = [ @@ -1856,7 +1897,7 @@ mod tests { wrap_partition_value_in_dict(ScalarValue::from("10")), wrap_partition_value_in_dict(ScalarValue::from("28")), ], - &test_object_meta(), + &test_object_meta() ) .expect("Projection of partition columns into record batch failed"); let expected = [ @@ -1885,7 +1926,7 @@ mod tests { ScalarValue::from("10"), ScalarValue::from("26"), ], - &test_object_meta(), + &test_object_meta() ) .expect("Projection of partition columns into record batch failed"); let expected = [ @@ -1921,7 +1962,7 @@ mod tests { Statistics::new_unknown(&schema), to_partition_cols(partition_cols), ) - .projected_file_schema(); + .projected_file_schema(); // Assert partition column filtered out in projected file schema let expected_columns = vec!["c1", "c4", "c6"]; @@ -1954,7 +1995,7 @@ mod tests { Statistics::new_unknown(&schema), to_partition_cols(partition_cols), ) - .projected_file_schema(); + .projected_file_schema(); // Assert projected file schema is equal to file schema assert_eq!(projection.fields(), schema.fields()); @@ -2246,10 +2287,10 @@ mod tests { file_schema, Arc::new(MockSource::default()), ) - .with_projection(projection) - .with_statistics(statistics) - .with_table_partition_cols(table_partition_cols) - .build() + .with_projection(projection) + 
.with_statistics(statistics) + .with_table_partition_cols(table_partition_cols) + .build() } /// Convert partition columns from Vec to Vec @@ -2280,7 +2321,7 @@ mod tests { Arc::new(Int32Array::from(c.1.clone())), ], ) - .unwrap() + .unwrap() } /// Create a test ObjectMeta with given path, size and a fixed timestamp @@ -2309,9 +2350,9 @@ mod tests { file_schema.clone(), Arc::new(MockSource::default()), ) - .with_projection(projection) - .with_metadata_cols(metadata_cols) - .build() + .with_projection(projection) + .with_metadata_cols(metadata_cols) + .build() } #[test] @@ -2380,7 +2421,7 @@ mod tests { Arc::clone(&file_schema), Arc::clone(&file_source), ) - .build(); + .build(); // Verify default values assert_eq!(config.object_store_url, object_store_url); @@ -2511,9 +2552,9 @@ mod tests { file_schema.clone(), Arc::new(MockSource::default()), ) - .with_table_partition_cols(partition_cols) - .with_metadata_cols(metadata_cols) - .build(); + .with_table_partition_cols(partition_cols) + .with_metadata_cols(metadata_cols) + .build(); // Get projected schema let schema = conf.projected_schema(); @@ -2558,7 +2599,7 @@ mod tests { let projected_schema = Arc::new(Schema::new(projected_fields)); // Create projector - let mut projector = ExtendedColumnProjector::new( + let mut projector = PartitionColumnProjector::new( Arc::clone(&projected_schema), &[], // No partition columns &metadata_cols, @@ -2574,7 +2615,7 @@ mod tests { ])), vec![Arc::new(a_values), Arc::new(b_values)], ) - .unwrap(); + .unwrap(); // Apply projection let result = projector.project(file_batch, &[], &object_meta).unwrap(); @@ -2660,7 +2701,7 @@ mod tests { let projected_schema = Arc::new(Schema::new(projected_fields)); // Create projector - let mut projector = ExtendedColumnProjector::new( + let mut projector = PartitionColumnProjector::new( Arc::clone(&projected_schema), &partition_cols, &metadata_cols, @@ -2672,7 +2713,7 @@ mod tests { Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), vec![Arc::new(a_values)], ) - .unwrap(); + .unwrap(); // Apply projection let result = projector @@ -2724,13 +2765,13 @@ mod tests { Arc::clone(&schema), Arc::clone(&file_source), ) - .with_projection(Some(vec![0, 2])) - .with_limit(Some(10)) - .with_table_partition_cols(partition_cols.clone()) - .with_file(file.clone()) - .with_constraints(Constraints::default()) - .with_newlines_in_values(true) - .build(); + .with_projection(Some(vec![0, 2])) + .with_limit(Some(10)) + .with_table_partition_cols(partition_cols.clone()) + .with_file(file.clone()) + .with_constraints(Constraints::default()) + .with_newlines_in_values(true) + .build(); // Create a new builder from the config let new_builder = FileScanConfigBuilder::from(original_config); @@ -2768,7 +2809,7 @@ mod tests { // Setup sort expression let exec_props = ExecutionProps::new(); let df_schema = DFSchema::try_from_qualified_schema("test", schema.as_ref())?; - let sort_expr = vec![col("value").sort(true, false)]; + let sort_expr = [col("value").sort(true, false)]; let physical_sort_exprs: Vec<_> = sort_expr .iter() @@ -2887,10 +2928,7 @@ mod tests { avg_files_per_partition ); - println!( - "Distribution - min files: {}, max files: {}", - min_size, max_size - ); + println!("Distribution - min files: {min_size}, max files: {max_size}"); } } @@ -2902,7 +2940,7 @@ mod tests { &sort_ordering, 0, ) - .unwrap_err(); + .unwrap_err(); assert!( err.to_string() diff --git a/datafusion/datasource/src/file_stream.rs b/datafusion/datasource/src/file_stream.rs index 
307ee66e9b4e..acc8edd3c0d0 100644 --- a/datafusion/datasource/src/file_stream.rs +++ b/datafusion/datasource/src/file_stream.rs @@ -28,7 +28,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::file_meta::FileMeta; -use crate::file_scan_config::{ExtendedColumnProjector, FileScanConfig}; +use crate::file_scan_config::{PartitionColumnProjector, FileScanConfig}; use crate::PartitionedFile; use arrow::datatypes::SchemaRef; use datafusion_common::error::Result; @@ -60,7 +60,7 @@ pub struct FileStream { /// which can be resolved to a stream of `RecordBatch`. file_opener: Arc, /// The extended (partitioning + metadata) column projector - col_projector: ExtendedColumnProjector, + pc_projector: PartitionColumnProjector, /// The stream state state: FileStreamState, /// File stream specific metrics @@ -79,15 +79,15 @@ impl FileStream { file_opener: Arc, metrics: &ExecutionPlanMetricsSet, ) -> Result { - let (projected_schema, ..) = config.project(); - let col_projector = ExtendedColumnProjector::new( + let projected_schema = config.projected_schema(); + let pc_projector = PartitionColumnProjector::new( Arc::clone(&projected_schema), &config .table_partition_cols .iter() .map(|x| x.name().clone()) .collect::>(), - &config.metadata_cols, + &config.metadata_cols ); let file_group = config.file_groups[partition].clone(); @@ -97,7 +97,7 @@ impl FileStream { projected_schema, remain: config.limit, file_opener, - col_projector, + pc_projector, state: FileStreamState::Idle, file_stream_metrics: FileStreamMetrics::new(metrics, partition), baseline_metrics: BaselineMetrics::new(metrics, partition), @@ -237,7 +237,7 @@ impl FileStream { self.file_stream_metrics.time_scanning_until_data.stop(); self.file_stream_metrics.time_scanning_total.stop(); let result = self - .col_projector + .pc_projector .project(batch, partition_values, object_meta) .map_err(|e| ArrowError::ExternalError(e.into())) .map(|batch| match &mut self.remain { @@ -1009,7 +1009,7 @@ mod tests { #[tokio::test] async fn test_extended_column_projector() -> Result<()> { - use crate::file_scan_config::ExtendedColumnProjector; + use crate::file_scan_config::PartitionColumnProjector; use crate::metadata::MetadataColumn; use arrow::array::{StringArray, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema}; @@ -1056,7 +1056,7 @@ mod tests { let partition_values = vec![ScalarValue::Utf8(Some("2023".to_string()))]; // Create the projector - let mut projector = ExtendedColumnProjector::new( + let mut projector = PartitionColumnProjector::new( schema_with_partition.clone(), &["year".to_string()], &[], @@ -1091,7 +1091,7 @@ mod tests { ])); // Create the projector with metadata columns - let mut projector = ExtendedColumnProjector::new( + let mut projector = PartitionColumnProjector::new( schema_with_metadata.clone(), &[], &[MetadataColumn::Location(None), MetadataColumn::Size], @@ -1129,7 +1129,7 @@ mod tests { ])); // Create the projector - let mut projector = ExtendedColumnProjector::new( + let mut projector = PartitionColumnProjector::new( schema_combined.clone(), &["year".to_string()], &[MetadataColumn::Location(None), MetadataColumn::Size], @@ -1169,7 +1169,7 @@ mod tests { ])); // Create the projector - let mut projector = ExtendedColumnProjector::new( + let mut projector = PartitionColumnProjector::new( schema_mixed.clone(), &["year".to_string()], &[MetadataColumn::Location(None), MetadataColumn::Size], @@ -1261,7 +1261,7 @@ mod tests { )]; // Create the projector - let mut projector = ExtendedColumnProjector::new( + let mut 
projector = PartitionColumnProjector::new( schema_with_dict.clone(), &["year".to_string()], &[], @@ -1291,7 +1291,7 @@ mod tests { // Test 6: Auto-fix for non-dictionary partition values // Create a projector expecting dictionary-encoded values - let mut projector = ExtendedColumnProjector::new( + let mut projector = PartitionColumnProjector::new( schema_with_dict.clone(), &["year".to_string()], &[], diff --git a/datafusion/datasource/src/memory.rs b/datafusion/datasource/src/memory.rs index 6d0e16ef4b91..54cea71843ee 100644 --- a/datafusion/datasource/src/memory.rs +++ b/datafusion/datasource/src/memory.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! Execution plan for reading in-memory batches of data - use std::any::Any; +use std::cmp::Ordering; +use std::collections::BinaryHeap; use std::fmt; use std::fmt::Debug; use std::sync::Arc; @@ -25,335 +25,33 @@ use std::sync::Arc; use crate::sink::DataSink; use crate::source::{DataSource, DataSourceExec}; use async_trait::async_trait; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::memory::MemoryStream; use datafusion_physical_plan::projection::{ all_alias_free_columns, new_projections_for_columns, ProjectionExec, }; use datafusion_physical_plan::{ common, ColumnarValue, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - PhysicalExpr, PlanProperties, SendableRecordBatchStream, Statistics, + PhysicalExpr, SendableRecordBatchStream, Statistics, }; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::{ - internal_err, plan_err, project_schema, Constraints, Result, ScalarValue, -}; +use datafusion_common::{internal_err, plan_err, project_schema, Result, ScalarValue}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use futures::StreamExt; +use itertools::Itertools; use tokio::sync::RwLock; -/// Execution plan for reading in-memory batches of data -#[derive(Clone)] -#[deprecated( - since = "46.0.0", - note = "use MemorySourceConfig and DataSourceExec instead" -)] -pub struct MemoryExec { - inner: DataSourceExec, - /// The partitions to query - partitions: Vec>, - /// Optional projection - projection: Option>, - // Sort information: one or more equivalent orderings - sort_information: Vec, - /// if partition sizes should be displayed - show_sizes: bool, -} - -#[allow(unused, deprecated)] -impl Debug for MemoryExec { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.inner.fmt_as(DisplayFormatType::Default, f) - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for MemoryExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for MemoryExec { - fn name(&self) -> &'static str { - "MemoryExec" - } - - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - // This is a leaf node and has no children - vec![] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - // MemoryExec has no children - if children.is_empty() { - 
Ok(self) - } else { - internal_err!("Children cannot be replaced in {self:?}") - } - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - - /// We recompute the statistics dynamically from the arrow metadata as it is pretty cheap to do so - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn try_swapping_with_projection( - &self, - projection: &ProjectionExec, - ) -> Result>> { - self.inner.try_swapping_with_projection(projection) - } -} - -#[allow(unused, deprecated)] -impl MemoryExec { - /// Create a new execution plan for reading in-memory record batches - /// The provided `schema` should not have the projection applied. - pub fn try_new( - partitions: &[Vec], - schema: SchemaRef, - projection: Option>, - ) -> Result { - let source = MemorySourceConfig::try_new(partitions, schema, projection.clone())?; - let data_source = DataSourceExec::new(Arc::new(source)); - Ok(Self { - inner: data_source, - partitions: partitions.to_vec(), - projection, - sort_information: vec![], - show_sizes: true, - }) - } - - /// Create a new execution plan from a list of constant values (`ValuesExec`) - pub fn try_new_as_values( - schema: SchemaRef, - data: Vec>>, - ) -> Result { - if data.is_empty() { - return plan_err!("Values list cannot be empty"); - } - - let n_row = data.len(); - let n_col = schema.fields().len(); - - // We have this single row batch as a placeholder to satisfy evaluation argument - // and generate a single output row - let placeholder_schema = Arc::new(Schema::empty()); - let placeholder_batch = RecordBatch::try_new_with_options( - Arc::clone(&placeholder_schema), - vec![], - &RecordBatchOptions::new().with_row_count(Some(1)), - )?; - - // Evaluate each column - let arrays = (0..n_col) - .map(|j| { - (0..n_row) - .map(|i| { - let expr = &data[i][j]; - let result = expr.evaluate(&placeholder_batch)?; - - match result { - ColumnarValue::Scalar(scalar) => Ok(scalar), - ColumnarValue::Array(array) if array.len() == 1 => { - ScalarValue::try_from_array(&array, 0) - } - ColumnarValue::Array(_) => { - plan_err!("Cannot have array values in a values list") - } - } - }) - .collect::>>() - .and_then(ScalarValue::iter_to_array) - }) - .collect::>>()?; - - let batch = RecordBatch::try_new_with_options( - Arc::clone(&schema), - arrays, - &RecordBatchOptions::new().with_row_count(Some(n_row)), - )?; - - let partitions = vec![batch]; - Self::try_new_from_batches(Arc::clone(&schema), partitions) - } - - /// Create a new plan using the provided schema and batches. - /// - /// Errors if any of the batches don't match the provided schema, or if no - /// batches are provided. - pub fn try_new_from_batches( - schema: SchemaRef, - batches: Vec, - ) -> Result { - if batches.is_empty() { - return plan_err!("Values list cannot be empty"); - } - - for batch in &batches { - let batch_schema = batch.schema(); - if batch_schema != schema { - return plan_err!( - "Batch has invalid schema. 
Expected: {}, got: {}", - schema, - batch_schema - ); - } - } - - let partitions = vec![batches]; - let source = MemorySourceConfig { - partitions: partitions.clone(), - schema: Arc::clone(&schema), - projected_schema: Arc::clone(&schema), - projection: None, - sort_information: vec![], - show_sizes: true, - fetch: None, - }; - let data_source = DataSourceExec::new(Arc::new(source)); - Ok(Self { - inner: data_source, - partitions, - projection: None, - sort_information: vec![], - show_sizes: true, - }) - } - - fn memory_source_config(&self) -> MemorySourceConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - pub fn with_constraints(mut self, constraints: Constraints) -> Self { - self.inner = self.inner.with_constraints(constraints); - self - } - - /// Set `show_sizes` to determine whether to display partition sizes - pub fn with_show_sizes(mut self, show_sizes: bool) -> Self { - let mut memory_source = self.memory_source_config(); - memory_source.show_sizes = show_sizes; - self.show_sizes = show_sizes; - self.inner = DataSourceExec::new(Arc::new(memory_source)); - self - } - - /// Ref to constraints - pub fn constraints(&self) -> &Constraints { - self.properties().equivalence_properties().constraints() - } - - /// Ref to partitions - pub fn partitions(&self) -> &[Vec] { - &self.partitions - } - - /// Ref to projection - pub fn projection(&self) -> &Option> { - &self.projection - } - - /// Show sizes - pub fn show_sizes(&self) -> bool { - self.show_sizes - } - - /// Ref to sort information - pub fn sort_information(&self) -> &[LexOrdering] { - &self.sort_information - } - - /// A memory table can be ordered by multiple expressions simultaneously. - /// [`EquivalenceProperties`] keeps track of expressions that describe the - /// global ordering of the schema. These columns are not necessarily same; e.g. - /// ```text - /// ┌-------┐ - /// | a | b | - /// |---|---| - /// | 1 | 9 | - /// | 2 | 8 | - /// | 3 | 7 | - /// | 5 | 5 | - /// └---┴---┘ - /// ``` - /// where both `a ASC` and `b DESC` can describe the table ordering. With - /// [`EquivalenceProperties`], we can keep track of these equivalences - /// and treat `a ASC` and `b DESC` as the same ordering requirement. - /// - /// Note that if there is an internal projection, that projection will be - /// also applied to the given `sort_information`. - pub fn try_with_sort_information( - mut self, - sort_information: Vec, - ) -> Result { - self.sort_information = sort_information.clone(); - let mut memory_source = self.memory_source_config(); - memory_source = memory_source.try_with_sort_information(sort_information)?; - self.inner = DataSourceExec::new(Arc::new(memory_source)); - Ok(self) - } - - /// Arc clone of ref to original schema - pub fn original_schema(&self) -> SchemaRef { - Arc::clone(&self.inner.schema()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
- fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - partitions: &[Vec], - ) -> PlanProperties { - PlanProperties::new( - EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints), - Partitioning::UnknownPartitioning(partitions.len()), - EmissionType::Incremental, - Boundedness::Bounded, - ) - } -} - /// Data source configuration for reading in-memory batches of data #[derive(Clone, Debug)] pub struct MemorySourceConfig { - /// The partitions to query + /// The partitions to query. + /// + /// Each partition is a `Vec`. partitions: Vec>, /// Schema representing the data before projection schema: SchemaRef, @@ -399,9 +97,7 @@ impl DataSource for MemorySourceConfig { let output_ordering = self .sort_information .first() - .map(|output_ordering| { - format!(", output_ordering={}", output_ordering) - }) + .map(|output_ordering| format!(", output_ordering={output_ordering}")) .unwrap_or_default(); let eq_properties = self.eq_properties(); @@ -409,12 +105,12 @@ impl DataSource for MemorySourceConfig { let constraints = if constraints.is_empty() { String::new() } else { - format!(", {}", constraints) + format!(", {constraints}") }; let limit = self .fetch - .map_or(String::new(), |limit| format!(", fetch={}", limit)); + .map_or(String::new(), |limit| format!(", fetch={limit}")); if self.show_sizes { write!( f, @@ -445,6 +141,39 @@ impl DataSource for MemorySourceConfig { } } + /// If possible, redistribute batches across partitions according to their size. + /// + /// Returns `Ok(None)` if unable to repartition. Preserve output ordering if exists. + /// Refer to [`DataSource::repartitioned`] for further details. + fn repartitioned( + &self, + target_partitions: usize, + _repartition_file_min_size: usize, + output_ordering: Option, + ) -> Result>> { + if self.partitions.is_empty() || self.partitions.len() >= target_partitions + // if have no partitions, or already have more partitions than desired, do not repartition + { + return Ok(None); + } + + let maybe_repartitioned = if let Some(output_ordering) = output_ordering { + self.repartition_preserving_order(target_partitions, output_ordering)? + } else { + self.repartition_evenly_by_size(target_partitions)? + }; + + if let Some(repartitioned) = maybe_repartitioned { + Ok(Some(Arc::new(Self::try_new( + &repartitioned, + self.original_schema(), + self.projection.clone(), + )?))) + } else { + Ok(None) + } + } + fn output_partitioning(&self) -> Partitioning { Partitioning::UnknownPartitioning(self.partitions.len()) } @@ -723,6 +452,226 @@ impl MemorySourceConfig { pub fn original_schema(&self) -> SchemaRef { Arc::clone(&self.schema) } + + /// Repartition while preserving order. + /// + /// Returns `Ok(None)` if cannot fulfill the requested repartitioning, such + /// as having too few batches to fulfill the `target_partitions` or if unable + /// to preserve output ordering. + fn repartition_preserving_order( + &self, + target_partitions: usize, + output_ordering: LexOrdering, + ) -> Result>>> { + if !self.eq_properties().ordering_satisfy(&output_ordering) { + Ok(None) + } else { + let total_num_batches = + self.partitions.iter().map(|b| b.len()).sum::(); + if total_num_batches < target_partitions { + // no way to create the desired repartitioning + return Ok(None); + } + + let cnt_to_repartition = target_partitions - self.partitions.len(); + + // Label the current partitions and their order. 
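As a standalone illustration of the order-preserving strategy this method implements, the sketch below works on plain row counts instead of `RecordBatch`es: existing partitions get spaced indices so split halves can slot in between them, the largest partition (by row count) is popped from a max-heap and halved, and the result is re-sorted by index so the original output ordering survives. The `Part` type and the `main` input are illustrative only, not DataFusion APIs, and the sketch does not retry with the next-largest partition when a split is impossible the way the real implementation does.

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

struct Part {
    idx: usize,          // position in the original output ordering (spaced to leave gaps)
    batches: Vec<usize>, // row counts of the batches held by this partition
}

impl Part {
    fn rows(&self) -> usize {
        self.batches.iter().sum()
    }

    /// Split into two parts at roughly half the rows; a single batch cannot be split.
    fn split(self) -> Vec<Part> {
        if self.batches.len() == 1 {
            return vec![self];
        }
        let half = self.rows() / 2;
        let (mut head, mut tail, mut acc) = (vec![], vec![], 0);
        for rows in self.batches {
            if acc < half {
                acc += rows;
                head.push(rows);
            } else {
                tail.push(rows);
            }
        }
        vec![
            Part { idx: self.idx, batches: head },
            Part { idx: self.idx + 1, batches: tail },
        ]
    }
}

// Heap order: largest row count first, so the biggest partition is split first.
impl Ord for Part {
    fn cmp(&self, other: &Self) -> Ordering {
        self.rows().cmp(&other.rows())
    }
}
impl PartialOrd for Part {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl PartialEq for Part {
    fn eq(&self, other: &Self) -> bool {
        self.rows() == other.rows()
    }
}
impl Eq for Part {}

fn main() {
    // Two ordered partitions, three requested: one extra split is needed.
    let input = vec![vec![100_000, 10_000, 1_000], vec![2_000, 20]];
    let extra = 3 - input.len();

    let mut heap: BinaryHeap<Part> = input
        .into_iter()
        .enumerate()
        .map(|(i, batches)| Part { idx: i + extra * i, batches }) // leave gaps in the ordering
        .collect();

    for _ in 0..extra {
        if let Some(largest) = heap.pop() {
            heap.extend(largest.split()); // put both halves back, keyed by idx and idx + 1
        }
    }

    let mut parts: Vec<Part> = heap.into_vec();
    parts.sort_by_key(|p| p.idx); // restore the original ordering, now with the split halves in place
    for p in &parts {
        println!("idx={} rows={} batches={:?}", p.idx, p.rows(), p.batches);
    }
}
```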
+ // Such that when we later split up the partitions into smaller sizes, we are maintaining the order. + let to_repartition = self + .partitions + .iter() + .enumerate() + .map(|(idx, batches)| RePartition { + idx: idx + (cnt_to_repartition * idx), // make space in ordering for split partitions + row_count: batches.iter().map(|batch| batch.num_rows()).sum(), + batches: batches.clone(), + }) + .collect_vec(); + + // Put all of the partitions into a heap ordered by `RePartition::partial_cmp`, which sizes + // by count of rows. + let mut max_heap = BinaryHeap::with_capacity(target_partitions); + for rep in to_repartition { + max_heap.push(rep); + } + + // Split the largest partitions into smaller partitions. Maintaining the output + // order of the partitions & newly created partitions. + let mut cannot_split_further = Vec::with_capacity(target_partitions); + for _ in 0..cnt_to_repartition { + // triggers loop for the cnt_to_repartition. So if need another 4 partitions, it attempts to split 4 times. + loop { + // Take the largest item off the heap, and attempt to split. + let Some(to_split) = max_heap.pop() else { + // Nothing left to attempt repartition. Break inner loop. + break; + }; + + // Split the partition. The new partitions will be ordered with idx and idx+1. + let mut new_partitions = to_split.split(); + if new_partitions.len() > 1 { + for new_partition in new_partitions { + max_heap.push(new_partition); + } + // Successful repartition. Break inner loop, and return to outer `cnt_to_repartition` loop. + break; + } else { + cannot_split_further.push(new_partitions.remove(0)); + } + } + } + let mut partitions = max_heap.drain().collect_vec(); + partitions.extend(cannot_split_further); + + // Finally, sort all partitions by the output ordering. + // This was the original ordering of the batches within the partition. We are maintaining this ordering. + partitions.sort_by_key(|p| p.idx); + let partitions = partitions.into_iter().map(|rep| rep.batches).collect_vec(); + + Ok(Some(partitions)) + } + } + + /// Repartition into evenly sized chunks (as much as possible without batch splitting), + /// disregarding any ordering. + /// + /// Current implementation uses a first-fit-decreasing bin packing, modified to enable + /// us to still return the desired count of `target_partitions`. + /// + /// Returns `Ok(None)` if cannot fulfill the requested repartitioning, such + /// as having too few batches to fulfill the `target_partitions`. + fn repartition_evenly_by_size( + &self, + target_partitions: usize, + ) -> Result>>> { + // determine if we have enough total batches to fulfill request + let mut flatten_batches = + self.partitions.clone().into_iter().flatten().collect_vec(); + if flatten_batches.len() < target_partitions { + return Ok(None); + } + + // Take all flattened batches (all in 1 partition/vec) and divide evenly into the desired number of `target_partitions`. + let total_num_rows = flatten_batches.iter().map(|b| b.num_rows()).sum::(); + // sort by size, so we pack multiple smaller batches into the same partition + flatten_batches.sort_by_key(|b| std::cmp::Reverse(b.num_rows())); + + // Divide.
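The size-based path can be sketched the same way on bare row counts: flatten every batch into a single list, sort descending, and fill each output partition until it reaches a target size that is recomputed from the rows still remaining, which is what keeps the result at exactly `target_partitions` even for lopsided inputs. A minimal sketch; `pack_evenly` is an illustrative name, not the DataFusion method:

```rust
/// Distribute batch row counts into `target` partitions of roughly equal total size,
/// without splitting any individual batch.
fn pack_evenly(mut batch_rows: Vec<usize>, target: usize) -> Vec<Vec<usize>> {
    // Largest batches first, so smaller ones can top up partly filled partitions.
    batch_rows.sort_by_key(|rows| std::cmp::Reverse(*rows));

    let total: usize = batch_rows.iter().sum();
    let mut partitions = vec![Vec::new(); target];
    let mut target_size = total.div_ceil(target);
    let (mut seen, mut in_bin, mut idx) = (0usize, 0usize, 0usize);

    for rows in batch_rows {
        let bin = idx.min(target - 1); // anything left over lands in the last partition
        partitions[bin].push(rows);
        in_bin += rows;
        seen += rows;

        if in_bin >= target_size {
            idx += 1;
            in_bin = 0;
            // Re-aim at the remaining rows so we still end up with `target` partitions.
            if seen < total && idx < target {
                target_size = (total - seen).div_ceil(target - idx);
            }
        }
    }
    partitions
}

fn main() {
    // Row counts used in the tests below: 100_000, 10_000, 1_000, 2_000, 20 into 3 partitions.
    let packed = pack_evenly(vec![100_000usize, 10_000, 1_000, 2_000, 20], 3);
    assert_eq!(packed, vec![vec![100_000usize], vec![10_000], vec![2_000, 1_000, 20]]);
    println!("{packed:?}");
}
```

On that input the sketch produces [100_000], [10_000] and [2_000, 1_000, 20], the same grouping asserted in `test_repartition_no_sort_information_no_output_ordering` further down.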
+ let mut partitions = + vec![Vec::with_capacity(flatten_batches.len()); target_partitions]; + let mut target_partition_size = total_num_rows.div_ceil(target_partitions); + let mut total_rows_seen = 0; + let mut curr_bin_row_count = 0; + let mut idx = 0; + for batch in flatten_batches { + let row_cnt = batch.num_rows(); + idx = std::cmp::min(idx, target_partitions - 1); + + partitions[idx].push(batch); + curr_bin_row_count += row_cnt; + total_rows_seen += row_cnt; + + if curr_bin_row_count >= target_partition_size { + idx += 1; + curr_bin_row_count = 0; + + // update target_partition_size, to handle very lopsided batch distributions + // while still returning the count of `target_partitions` + if total_rows_seen < total_num_rows { + target_partition_size = (total_num_rows - total_rows_seen) + .div_ceil(target_partitions - idx); + } + } + } + + Ok(Some(partitions)) + } +} + +/// For use in repartitioning, track the total size and original partition index. +/// +/// Do not implement clone, in order to avoid unnecessary copying during repartitioning. +struct RePartition { + /// Original output ordering for the partition. + idx: usize, + /// Total size of the partition, for use in heap ordering + /// (a.k.a. splitting up the largest partitions). + row_count: usize, + /// A partition containing record batches. + batches: Vec, +} + +impl RePartition { + /// Split [`RePartition`] into 2 pieces, consuming self. + /// + /// Returns only 1 partition if cannot be split further. + fn split(self) -> Vec { + if self.batches.len() == 1 { + return vec![self]; + } + + let new_0 = RePartition { + idx: self.idx, // output ordering + row_count: 0, + batches: vec![], + }; + let new_1 = RePartition { + idx: self.idx + 1, // output ordering +1 + row_count: 0, + batches: vec![], + }; + let split_pt = self.row_count / 2; + + let [new_0, new_1] = self.batches.into_iter().fold( + [new_0, new_1], + |[mut new0, mut new1], batch| { + if new0.row_count < split_pt { + new0.add_batch(batch); + } else { + new1.add_batch(batch); + } + [new0, new1] + }, + ); + vec![new_0, new_1] + } + + fn add_batch(&mut self, batch: RecordBatch) { + self.row_count += batch.num_rows(); + self.batches.push(batch); + } +} + +impl PartialOrd for RePartition { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.row_count.cmp(&other.row_count)) + } +} + +impl Ord for RePartition { + fn cmp(&self, other: &Self) -> Ordering { + self.row_count.cmp(&other.row_count) + } +} + +impl PartialEq for RePartition { + fn eq(&self, other: &Self) -> bool { + self.row_count.eq(&other.row_count) + } +} + +impl Eq for RePartition {} + +impl fmt::Display for RePartition { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}rows-in-{}batches@{}", + self.row_count, + self.batches.len(), + self.idx + ) + } } /// Type alias for partition data @@ -868,10 +817,14 @@ mod memory_source_tests { #[cfg(test)] mod tests { + use crate::test_util::col; use crate::tests::{aggr_test_schema, make_partition}; use super::*; + use arrow::array::{ArrayRef, Int32Array, Int64Array, StringArray}; + use arrow::compute::SortOptions; + use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::expressions::lit; use arrow::datatypes::{DataType, Field}; @@ -976,7 +929,7 @@ mod tests { )?; assert_eq!( - values.statistics()?, + values.partition_statistics(None)?, Statistics { num_rows: Precision::Exact(rows), total_byte_size: Precision::Exact(8), // not important @@ -992,4 +945,462 @@ mod tests { Ok(()) } + + fn batch(row_size: usize) -> 
RecordBatch { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("foo"); row_size])); + let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![1; row_size])); + RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() + } + + fn schema() -> SchemaRef { + batch(1).schema() + } + + fn memorysrcconfig_no_partitions( + sort_information: Vec, + ) -> Result { + let partitions = vec![]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_1_partition_1_batch( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_3_partitions_1_batch_each( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100)], vec![batch(100)], vec![batch(100)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_3_partitions_with_2_batches_each( + sort_information: Vec, + ) -> Result { + let partitions = vec![ + vec![batch(100), batch(100)], + vec![batch(100), batch(100)], + vec![batch(100), batch(100)], + ]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + /// Batches of different sizes, with batches ordered by size (100_000, 10_000, 100, 1) + /// in the Memtable partition (a.k.a. vector of batches). + fn memorysrcconfig_1_partition_with_different_sized_batches( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100_000), batch(10_000), batch(100), batch(1)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + /// Same as [`memorysrcconfig_1_partition_with_different_sized_batches`], + /// but the batches are ordered differently (not by size) + /// in the Memtable partition (a.k.a. vector of batches). + fn memorysrcconfig_1_partition_with_ordering_not_matching_size( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100_000), batch(1), batch(100), batch(10_000)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_2_partition_with_different_sized_batches( + sort_information: Vec, + ) -> Result { + let partitions = vec![ + vec![batch(100_000), batch(10_000), batch(1_000)], + vec![batch(2_000), batch(20)], + ]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_2_partition_with_extreme_sized_batches( + sort_information: Vec, + ) -> Result { + let partitions = vec![ + vec![ + batch(100_000), + batch(1), + batch(1), + batch(1), + batch(1), + batch(0), + ], + vec![batch(1), batch(1), batch(1), batch(1), batch(0), batch(100)], + ]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + /// Assert that we get the expected count of partitions after repartitioning. + /// + /// If None, then we expected the [`DataSource::repartitioned`] to return None. 
+ fn assert_partitioning( + partitioned_datasrc: Option>, + partition_cnt: Option, + ) { + let should_exist = if let Some(partition_cnt) = partition_cnt { + format!("new datasource should exist and have {partition_cnt:?} partitions") + } else { + "new datasource should not exist".into() + }; + + let actual = partitioned_datasrc + .map(|datasrc| datasrc.output_partitioning().partition_count()); + assert_eq!( + actual, + partition_cnt, + "partitioned datasrc does not match expected, we expected {should_exist}, instead found {actual:?}" + ); + } + + fn run_all_test_scenarios( + output_ordering: Option, + sort_information_on_config: Vec, + ) -> Result<()> { + let not_used = usize::MAX; + + // src has no partitions + let mem_src_config = + memorysrcconfig_no_partitions(sort_information_on_config.clone())?; + let partitioned_datasrc = + mem_src_config.repartitioned(1, not_used, output_ordering.clone())?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions == target partitions (=1) + let target_partitions = 1; + let mem_src_config = + memorysrcconfig_1_partition_1_batch(sort_information_on_config.clone())?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions == target partitions (=3) + let target_partitions = 3; + let mem_src_config = memorysrcconfig_3_partitions_1_batch_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions > target partitions, but we don't merge them + let target_partitions = 2; + let mem_src_config = memorysrcconfig_3_partitions_1_batch_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions < target partitions, but not enough batches (per partition) to split into more partitions + let target_partitions = 4; + let mem_src_config = memorysrcconfig_3_partitions_1_batch_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions < target partitions, and can split to sufficient amount + // has 6 batches across 3 partitions. Will need to split 2 of it's partitions. + let target_partitions = 5; + let mem_src_config = memorysrcconfig_3_partitions_with_2_batches_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, Some(5)); + + // src has partitions < target partitions, and can split to sufficient amount + // has 6 batches across 3 partitions. Will need to split all of it's partitions. 
+ let target_partitions = 6; + let mem_src_config = memorysrcconfig_3_partitions_with_2_batches_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, Some(6)); + + // src has partitions < target partitions, but not enough total batches to fulfill the split (desired target_partitions) + let target_partitions = 3 * 2 + 1; + let mem_src_config = memorysrcconfig_3_partitions_with_2_batches_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has 1 partition with many batches of lopsided sizes + // make sure it handles the split properly + let target_partitions = 2; + let mem_src_config = memorysrcconfig_1_partition_with_different_sized_batches( + sort_information_on_config, + )?; + let partitioned_datasrc = mem_src_config.clone().repartitioned( + target_partitions, + not_used, + output_ordering, + )?; + assert_partitioning(partitioned_datasrc.clone(), Some(2)); + // Starting = batch(100_000), batch(10_000), batch(100), batch(1). + // It should have split as p1=batch(100_000), p2=[batch(10_000), batch(100), batch(1)] + let partitioned_datasrc = partitioned_datasrc.unwrap(); + let Some(mem_src_config) = partitioned_datasrc + .as_any() + .downcast_ref::() + else { + unreachable!() + }; + let repartitioned_raw_batches = mem_src_config.partitions.clone(); + assert_eq!(repartitioned_raw_batches.len(), 2); + let [ref p1, ref p2] = repartitioned_raw_batches[..] else { + unreachable!() + }; + // p1=batch(100_000) + assert_eq!(p1.len(), 1); + assert_eq!(p1[0].num_rows(), 100_000); + // p2=[batch(10_000), batch(100), batch(1)] + assert_eq!(p2.len(), 3); + assert_eq!(p2[0].num_rows(), 10_000); + assert_eq!(p2[1].num_rows(), 100); + assert_eq!(p2[2].num_rows(), 1); + + Ok(()) + } + + #[test] + fn test_repartition_no_sort_information_no_output_ordering() -> Result<()> { + let no_sort = vec![]; + let no_output_ordering = None; + + // Test: common set of functionality + run_all_test_scenarios(no_output_ordering.clone(), no_sort.clone())?; + + // Test: how no-sort-order divides differently. + // * does not preserve separate partitions (with own internal ordering) on even split, + // * nor does it preserve ordering (re-orders batch(2_000) vs batch(1_000)). + let target_partitions = 3; + let mem_src_config = + memorysrcconfig_2_partition_with_different_sized_batches(no_sort)?; + let partitioned_datasrc = mem_src_config.clone().repartitioned( + target_partitions, + usize::MAX, + no_output_ordering, + )?; + assert_partitioning(partitioned_datasrc.clone(), Some(3)); + // Starting = batch(100_000), batch(10_000), batch(1_000), batch(2_000), batch(20) + // It should have split as p1=batch(100_000), p2=batch(10_000), p3=rest(mixed across original partitions) + let repartitioned_raw_batches = mem_src_config + .repartition_evenly_by_size(target_partitions)? + .unwrap(); + assert_eq!(repartitioned_raw_batches.len(), 3); + let [ref p1, ref p2, ref p3] = repartitioned_raw_batches[..] 
else { + unreachable!() + }; + // p1=batch(100_000) + assert_eq!(p1.len(), 1); + assert_eq!(p1[0].num_rows(), 100_000); + // p2=batch(10_000) + assert_eq!(p2.len(), 1); + assert_eq!(p2[0].num_rows(), 10_000); + // p3= batch(2_000), batch(1_000), batch(20) + assert_eq!(p3.len(), 3); + assert_eq!(p3[0].num_rows(), 2_000); + assert_eq!(p3[1].num_rows(), 1_000); + assert_eq!(p3[2].num_rows(), 20); + + Ok(()) + } + + #[test] + fn test_repartition_no_sort_information_no_output_ordering_lopsized_batches( + ) -> Result<()> { + let no_sort = vec![]; + let no_output_ordering = None; + + // Test: case has two input partitions: + // b(100_000), b(1), b(1), b(1), b(1), b(0) + // b(1), b(1), b(1), b(1), b(0), b(100) + // + // We want an output with target_partitions=5, which means the ideal division is: + // b(100_000) + // b(100) + // b(1), b(1), b(1) + // b(1), b(1), b(1) + // b(1), b(1), b(0) + let target_partitions = 5; + let mem_src_config = + memorysrcconfig_2_partition_with_extreme_sized_batches(no_sort)?; + let partitioned_datasrc = mem_src_config.clone().repartitioned( + target_partitions, + usize::MAX, + no_output_ordering, + )?; + assert_partitioning(partitioned_datasrc.clone(), Some(5)); + // Starting partition 1 = batch(100_000), batch(1), batch(1), batch(1), batch(1), batch(0) + // Starting partition 1 = batch(1), batch(1), batch(1), batch(1), batch(0), batch(100) + // It should have split as p1=batch(100_000), p2=batch(100), p3=[batch(1),batch(1)], p4=[batch(1),batch(1)], p5=[batch(1),batch(1),batch(0),batch(0)] + let repartitioned_raw_batches = mem_src_config + .repartition_evenly_by_size(target_partitions)? + .unwrap(); + assert_eq!(repartitioned_raw_batches.len(), 5); + let [ref p1, ref p2, ref p3, ref p4, ref p5] = repartitioned_raw_batches[..] 
+ else { + unreachable!() + }; + // p1=batch(100_000) + assert_eq!(p1.len(), 1); + assert_eq!(p1[0].num_rows(), 100_000); + // p2=batch(100) + assert_eq!(p2.len(), 1); + assert_eq!(p2[0].num_rows(), 100); + // p3=[batch(1),batch(1),batch(1)] + assert_eq!(p3.len(), 3); + assert_eq!(p3[0].num_rows(), 1); + assert_eq!(p3[1].num_rows(), 1); + assert_eq!(p3[2].num_rows(), 1); + // p4=[batch(1),batch(1),batch(1)] + assert_eq!(p4.len(), 3); + assert_eq!(p4[0].num_rows(), 1); + assert_eq!(p4[1].num_rows(), 1); + assert_eq!(p4[2].num_rows(), 1); + // p5=[batch(1),batch(1),batch(0),batch(0)] + assert_eq!(p5.len(), 4); + assert_eq!(p5[0].num_rows(), 1); + assert_eq!(p5[1].num_rows(), 1); + assert_eq!(p5[2].num_rows(), 0); + assert_eq!(p5[3].num_rows(), 0); + + Ok(()) + } + + #[test] + fn test_repartition_with_sort_information() -> Result<()> { + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]); + let has_sort = vec![sort_key.clone()]; + let output_ordering = Some(sort_key); + + // Test: common set of functionality + run_all_test_scenarios(output_ordering.clone(), has_sort.clone())?; + + // Test: DOES preserve separate partitions (with own internal ordering) + let target_partitions = 3; + let mem_src_config = + memorysrcconfig_2_partition_with_different_sized_batches(has_sort)?; + let partitioned_datasrc = mem_src_config.clone().repartitioned( + target_partitions, + usize::MAX, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc.clone(), Some(3)); + // Starting = batch(100_000), batch(10_000), batch(1_000), batch(2_000), batch(20) + // It should have split as p1=batch(100_000), p2=[batch(10_000),batch(1_000)], p3= + let Some(output_ord) = output_ordering else { + unreachable!() + }; + let repartitioned_raw_batches = mem_src_config + .repartition_preserving_order(target_partitions, output_ord)? + .unwrap(); + assert_eq!(repartitioned_raw_batches.len(), 3); + let [ref p1, ref p2, ref p3] = repartitioned_raw_batches[..] else { + unreachable!() + }; + // p1=batch(100_000) + assert_eq!(p1.len(), 1); + assert_eq!(p1[0].num_rows(), 100_000); + // p2=[batch(10_000),batch(1_000)] + assert_eq!(p2.len(), 2); + assert_eq!(p2[0].num_rows(), 10_000); + assert_eq!(p2[1].num_rows(), 1_000); + // p3=batch(2_000), batch(20) + assert_eq!(p3.len(), 2); + assert_eq!(p3[0].num_rows(), 2_000); + assert_eq!(p3[1].num_rows(), 20); + + Ok(()) + } + + #[test] + fn test_repartition_with_batch_ordering_not_matching_sizing() -> Result<()> { + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]); + let has_sort = vec![sort_key.clone()]; + let output_ordering = Some(sort_key); + + // src has 1 partition with many batches of lopsided sizes + // note that the input vector of batches are not ordered by decreasing size + let target_partitions = 2; + let mem_src_config = + memorysrcconfig_1_partition_with_ordering_not_matching_size(has_sort)?; + let partitioned_datasrc = mem_src_config.clone().repartitioned( + target_partitions, + usize::MAX, + output_ordering, + )?; + assert_partitioning(partitioned_datasrc.clone(), Some(2)); + // Starting = batch(100_000), batch(1), batch(100), batch(10_000). 
+ // It should have split as p1=batch(100_000), p2=[batch(1), batch(100), batch(10_000)] + let partitioned_datasrc = partitioned_datasrc.unwrap(); + let Some(mem_src_config) = partitioned_datasrc + .as_any() + .downcast_ref::() + else { + unreachable!() + }; + let repartitioned_raw_batches = mem_src_config.partitions.clone(); + assert_eq!(repartitioned_raw_batches.len(), 2); + let [ref p1, ref p2] = repartitioned_raw_batches[..] else { + unreachable!() + }; + // p1=batch(100_000) + assert_eq!(p1.len(), 1); + assert_eq!(p1[0].num_rows(), 100_000); + // p2=[batch(1), batch(100), batch(10_000)] -- **this is preserving the partition order** + assert_eq!(p2.len(), 3); + assert_eq!(p2[0].num_rows(), 1); + assert_eq!(p2[1].num_rows(), 100); + assert_eq!(p2[2].num_rows(), 10_000); + + Ok(()) + } } diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs index a9915156c879..d7042f45e14f 100644 --- a/datafusion/datasource/src/mod.rs +++ b/datafusion/datasource/src/mod.rs @@ -49,6 +49,7 @@ pub mod test_util; pub mod url; pub mod write; +pub use self::file::as_file_source; pub use self::url::ListingTableUrl; use crate::file_groups::FileGroup; use chrono::TimeZone; @@ -380,7 +381,7 @@ pub fn generate_test_files(num_files: usize, overlap_factor: f64) -> Vec datafusion_common::Result; + + /// Adapts file-level column `Statistics` to match the `table_schema` + fn map_column_statistics( + &self, + file_col_statistics: &[ColumnStatistics], + ) -> datafusion_common::Result>; } /// Default [`SchemaAdapterFactory`] for mapping schemas. @@ -219,6 +225,25 @@ pub(crate) struct DefaultSchemaAdapter { projected_table_schema: SchemaRef, } +/// Checks if a file field can be cast to a table field +/// +/// Returns Ok(true) if casting is possible, or an error explaining why casting is not possible +pub(crate) fn can_cast_field( + file_field: &Field, + table_field: &Field, +) -> datafusion_common::Result { + if can_cast_types(file_field.data_type(), table_field.data_type()) { + Ok(true) + } else { + plan_err!( + "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}", + file_field.name(), + file_field.data_type(), + table_field.data_type() + ) + } +} + impl SchemaAdapter for DefaultSchemaAdapter { /// Map a column index in the table schema to a column index in a particular /// file schema @@ -242,40 +267,53 @@ impl SchemaAdapter for DefaultSchemaAdapter { &self, file_schema: &Schema, ) -> datafusion_common::Result<(Arc, Vec)> { - let mut projection = Vec::with_capacity(file_schema.fields().len()); - let mut field_mappings = vec![None; self.projected_table_schema.fields().len()]; - - for (file_idx, file_field) in file_schema.fields.iter().enumerate() { - if let Some((table_idx, table_field)) = - self.projected_table_schema.fields().find(file_field.name()) - { - match can_cast_types(file_field.data_type(), table_field.data_type()) { - true => { - field_mappings[table_idx] = Some(projection.len()); - projection.push(file_idx); - } - false => { - return plan_err!( - "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}", - file_field.name(), - file_field.data_type(), - table_field.data_type() - ) - } - } - } - } + let (field_mappings, projection) = create_field_mapping( + file_schema, + &self.projected_table_schema, + can_cast_field, + )?; Ok(( - Arc::new(SchemaMapping { - projected_table_schema: Arc::clone(&self.projected_table_schema), + Arc::new(SchemaMapping::new( + Arc::clone(&self.projected_table_schema), field_mappings, - }), + )), 
projection, )) } } +/// Helper function that creates field mappings between file schema and table schema +/// +/// Maps columns from the file schema to their corresponding positions in the table schema, +/// applying type compatibility checking via the provided predicate function. +/// +/// Returns field mappings (for column reordering) and a projection (for field selection). +pub(crate) fn create_field_mapping( + file_schema: &Schema, + projected_table_schema: &SchemaRef, + can_map_field: F, +) -> datafusion_common::Result<(Vec>, Vec)> +where + F: Fn(&Field, &Field) -> datafusion_common::Result, +{ + let mut projection = Vec::with_capacity(file_schema.fields().len()); + let mut field_mappings = vec![None; projected_table_schema.fields().len()]; + + for (file_idx, file_field) in file_schema.fields.iter().enumerate() { + if let Some((table_idx, table_field)) = + projected_table_schema.fields().find(file_field.name()) + { + if can_map_field(file_field, table_field)? { + field_mappings[table_idx] = Some(projection.len()); + projection.push(file_idx); + } + } + } + + Ok((field_mappings, projection)) +} + /// The SchemaMapping struct holds a mapping from the file schema to the table /// schema and any necessary type conversions. /// @@ -298,6 +336,21 @@ pub struct SchemaMapping { field_mappings: Vec>, } +impl SchemaMapping { + /// Creates a new SchemaMapping instance + /// + /// Initializes the field mappings needed to transform file data to the projected table schema + pub fn new( + projected_table_schema: SchemaRef, + field_mappings: Vec>, + ) -> Self { + Self { + projected_table_schema, + field_mappings, + } + } +} + impl SchemaMapper for SchemaMapping { /// Adapts a `RecordBatch` to match the `projected_table_schema` using the stored mapping and /// conversions. 
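Since the mapping logic is easy to miss inside the diff, here is a self-contained sketch of what `create_field_mapping` computes when paired with the default cast check, assuming only the `arrow` crate. `map_fields` is an illustrative stand-in: the real helper takes the cast predicate as an argument and returns a `Result`.

```rust
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field, Schema};

/// For each table field, find a same-named file field that can be cast to it.
/// Returns (field_mappings, projection): `field_mappings[i]` is the index into the
/// projected file columns that feeds table column `i`, and `projection` lists which
/// file column indices to read at all.
fn map_fields(file: &Schema, table: &Schema) -> (Vec<Option<usize>>, Vec<usize>) {
    let mut projection = Vec::new();
    let mut field_mappings = vec![None; table.fields().len()];

    for (file_idx, file_field) in file.fields().iter().enumerate() {
        if let Some((table_idx, table_field)) = table.fields().find(file_field.name()) {
            if can_cast_types(file_field.data_type(), table_field.data_type()) {
                field_mappings[table_idx] = Some(projection.len());
                projection.push(file_idx);
            }
        }
    }
    (field_mappings, projection)
}

fn main() {
    // Table wants (a: Int32, b: Utf8, c: Float64); the file has (b, a) in a different order.
    let table = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
        Field::new("c", DataType::Float64, true),
    ]);
    let file = Schema::new(vec![
        Field::new("b", DataType::Utf8, true),
        Field::new("a", DataType::Int32, true),
    ]);

    let (mappings, projection) = map_fields(&file, &table);
    assert_eq!(projection, vec![0, 1]);                      // read both file columns
    assert_eq!(mappings, vec![Some(1usize), Some(0), None]); // c is missing and stays None
    println!("mappings={mappings:?} projection={projection:?}");
}
```

The `field_mappings` built this way are what `SchemaMapping` later uses in `map_batch` and `map_column_statistics`: file columns are reordered to table positions, and table columns with no mapping are filled with nulls and reported with unknown statistics.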
@@ -334,4 +387,317 @@ impl SchemaMapper for SchemaMapping { let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; Ok(record_batch) } + + /// Adapts file-level column `Statistics` to match the `table_schema` + fn map_column_statistics( + &self, + file_col_statistics: &[ColumnStatistics], + ) -> datafusion_common::Result> { + let mut table_col_statistics = vec![]; + + // Map the statistics for each field in the file schema to the corresponding field in the + // table schema, if a field is not present in the file schema, we need to fill it with `ColumnStatistics::new_unknown` + for (_, file_col_idx) in self + .projected_table_schema + .fields() + .iter() + .zip(&self.field_mappings) + { + if let Some(file_col_idx) = file_col_idx { + table_col_statistics.push( + file_col_statistics + .get(*file_col_idx) + .cloned() + .unwrap_or_default(), + ); + } else { + table_col_statistics.push(ColumnStatistics::new_unknown()); + } + } + + Ok(table_col_statistics) + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{stats::Precision, Statistics}; + + use super::*; + + #[test] + fn test_schema_mapping_map_statistics_basic() { + // Create table schema (a, b, c) + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float64, true), + ])); + + // Create file schema (b, a) - different order, missing c + let file_schema = Schema::new(vec![ + Field::new("b", DataType::Utf8, true), + Field::new("a", DataType::Int32, true), + ]); + + // Create SchemaAdapter + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + + // Get mapper and projection + let (mapper, projection) = adapter.map_schema(&file_schema).unwrap(); + + // Should project columns 0,1 from file + assert_eq!(projection, vec![0, 1]); + + // Create file statistics + let mut file_stats = Statistics::default(); + + // Statistics for column b (index 0 in file) + let b_stats = ColumnStatistics { + null_count: Precision::Exact(5), + ..Default::default() + }; + + // Statistics for column a (index 1 in file) + let a_stats = ColumnStatistics { + null_count: Precision::Exact(10), + ..Default::default() + }; + + file_stats.column_statistics = vec![b_stats, a_stats]; + + // Map statistics + let table_col_stats = mapper + .map_column_statistics(&file_stats.column_statistics) + .unwrap(); + + // Verify stats + assert_eq!(table_col_stats.len(), 3); + assert_eq!(table_col_stats[0].null_count, Precision::Exact(10)); // a from file idx 1 + assert_eq!(table_col_stats[1].null_count, Precision::Exact(5)); // b from file idx 0 + assert_eq!(table_col_stats[2].null_count, Precision::Absent); // c (unknown) + } + + #[test] + fn test_schema_mapping_map_statistics_empty() { + // Create schemas + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ])); + let file_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]); + + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + let (mapper, _) = adapter.map_schema(&file_schema).unwrap(); + + // Empty file statistics + let file_stats = Statistics::default(); + let table_col_stats = mapper + .map_column_statistics(&file_stats.column_statistics) + .unwrap(); + + // All stats should be unknown + assert_eq!(table_col_stats.len(), 2); + 
assert_eq!(table_col_stats[0], ColumnStatistics::new_unknown()); + assert_eq!(table_col_stats[1], ColumnStatistics::new_unknown()); + } + + #[test] + fn test_can_cast_field() { + // Same type should work + let from_field = Field::new("col", DataType::Int32, true); + let to_field = Field::new("col", DataType::Int32, true); + assert!(can_cast_field(&from_field, &to_field).unwrap()); + + // Casting Int32 to Float64 is allowed + let from_field = Field::new("col", DataType::Int32, true); + let to_field = Field::new("col", DataType::Float64, true); + assert!(can_cast_field(&from_field, &to_field).unwrap()); + + // Casting Float64 to Utf8 should work (converts to string) + let from_field = Field::new("col", DataType::Float64, true); + let to_field = Field::new("col", DataType::Utf8, true); + assert!(can_cast_field(&from_field, &to_field).unwrap()); + + // Binary to Decimal128 is not supported - this is an example of a cast that should fail + // Note: We use Binary instead of Utf8->Int32 because Arrow actually supports that cast + let from_field = Field::new("col", DataType::Binary, true); + let to_field = Field::new("col", DataType::Decimal128(10, 2), true); + let result = can_cast_field(&from_field, &to_field); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast file schema field col")); + } + + #[test] + fn test_create_field_mapping() { + // Define the table schema + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float64, true), + ])); + + // Define file schema: different order, missing column c, and b has different type + let file_schema = Schema::new(vec![ + Field::new("b", DataType::Float64, true), // Different type but castable to Utf8 + Field::new("a", DataType::Int32, true), // Same type + Field::new("d", DataType::Boolean, true), // Not in table schema + ]); + + // Custom can_map_field function that allows all mappings for testing + let allow_all = |_: &Field, _: &Field| Ok(true); + + // Test field mapping + let (field_mappings, projection) = + create_field_mapping(&file_schema, &table_schema, allow_all).unwrap(); + + // Expected: + // - field_mappings[0] (a) maps to projection[1] + // - field_mappings[1] (b) maps to projection[0] + // - field_mappings[2] (c) is None (not in file) + assert_eq!(field_mappings, vec![Some(1), Some(0), None]); + assert_eq!(projection, vec![0, 1]); // Projecting file columns b, a + + // Test with a failing mapper + let fails_all = |_: &Field, _: &Field| Ok(false); + let (field_mappings, projection) = + create_field_mapping(&file_schema, &table_schema, fails_all).unwrap(); + + // Should have no mappings or projections if all cast checks fail + assert_eq!(field_mappings, vec![None, None, None]); + assert_eq!(projection, Vec::<usize>::new()); + + // Test with error-producing mapper + let error_mapper = |_: &Field, _: &Field| plan_err!("Test error"); + let result = create_field_mapping(&file_schema, &table_schema, error_mapper); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Test error")); + } + + #[test] + fn test_schema_mapping_new() { + // Define the projected table schema + let projected_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ])); + + // Define field mappings from table to file + let field_mappings = vec![Some(1), Some(0)]; + + // Create SchemaMapping manually + let mapping =
SchemaMapping::new(Arc::clone(&projected_schema), field_mappings.clone()); + + // Check that fields were set correctly + assert_eq!(*mapping.projected_table_schema, *projected_schema); + assert_eq!(mapping.field_mappings, field_mappings); + + // Test with a batch to ensure it works properly + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("b_file", DataType::Utf8, true), + Field::new("a_file", DataType::Int32, true), + ])), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["hello", "world"])), + Arc::new(arrow::array::Int32Array::from(vec![1, 2])), + ], + ) + .unwrap(); + + // Test that map_batch works with our manually created mapping + let mapped_batch = mapping.map_batch(batch).unwrap(); + + // Verify the mapped batch has the correct schema and data + assert_eq!(*mapped_batch.schema(), *projected_schema); + assert_eq!(mapped_batch.num_columns(), 2); + assert_eq!(mapped_batch.column(0).len(), 2); // a column + assert_eq!(mapped_batch.column(1).len(), 2); // b column + } + + #[test] + fn test_map_schema_error_path() { + // Define the table schema + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Decimal128(10, 2), true), // Use Decimal which has stricter cast rules + ])); + + // Define file schema with incompatible type for column c + let file_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true), // Different but castable + Field::new("c", DataType::Binary, true), // Not castable to Decimal128 + ]); + + // Create DefaultSchemaAdapter + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + + // map_schema should error due to incompatible types + let result = adapter.map_schema(&file_schema); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast file schema field c")); + } + + #[test] + fn test_map_schema_happy_path() { + // Define the table schema + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Decimal128(10, 2), true), + ])); + + // Create DefaultSchemaAdapter + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + + // Define compatible file schema (missing column c) + let compatible_file_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), // Can be cast to Int32 + Field::new("b", DataType::Float64, true), // Can be cast to Utf8 + ]); + + // Test successful schema mapping + let (mapper, projection) = adapter.map_schema(&compatible_file_schema).unwrap(); + + // Verify field_mappings and projection created correctly + assert_eq!(projection, vec![0, 1]); // Projecting a and b + + // Verify the SchemaMapping works with actual data + let file_batch = RecordBatch::try_new( + Arc::new(compatible_file_schema.clone()), + vec![ + Arc::new(arrow::array::Int64Array::from(vec![100, 200])), + Arc::new(arrow::array::Float64Array::from(vec![1.5, 2.5])), + ], + ) + .unwrap(); + + let mapped_batch = mapper.map_batch(file_batch).unwrap(); + + // Verify correct schema mapping + assert_eq!(*mapped_batch.schema(), *table_schema); + assert_eq!(mapped_batch.num_columns(), 3); // a, b, c + + // Column c should be null since it wasn't in the file schema + let c_array = mapped_batch.column(2); + assert_eq!(c_array.len(), 2); + 
assert_eq!(c_array.null_count(), 2); + } } diff --git a/datafusion/datasource/src/source.rs b/datafusion/datasource/src/source.rs index 6c9122ce1ac1..30ecc38709f4 100644 --- a/datafusion/datasource/src/source.rs +++ b/datafusion/datasource/src/source.rs @@ -31,44 +31,119 @@ use datafusion_physical_plan::{ use crate::file_scan_config::FileScanConfig; use datafusion_common::config::ConfigOptions; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::{Constraints, Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::filter_pushdown::{ + ChildPushdownResult, FilterPushdownPropagation, +}; -/// Common behaviors in Data Sources for both from Files and Memory. +/// A source of data, typically a list of files or memory +/// +/// This trait provides common behaviors for abstract sources of data. It has +/// two common implementations: +/// +/// 1. [`FileScanConfig`]: lists of files +/// 2. [`MemorySourceConfig`]: in memory list of `RecordBatch` +/// +/// File format specific behaviors are defined by [`FileSource`] /// /// # See Also -/// * [`DataSourceExec`] for physical plan implementation -/// * [`FileSource`] for file format implementations (Parquet, Json, etc) +/// * [`FileSource`] for file format specific implementations (Parquet, Json, etc) +/// * [`DataSourceExec`]: The [`ExecutionPlan`] that reads from a `DataSource` /// /// # Notes +/// /// Requires `Debug` to assist debugging /// +/// [`FileScanConfig`]: https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.FileScanConfig.html +/// [`MemorySourceConfig`]: https://docs.rs/datafusion/latest/datafusion/datasource/memory/struct.MemorySourceConfig.html /// [`FileSource`]: crate::file::FileSource +/// [`FileFormat``]: https://docs.rs/datafusion/latest/datafusion/datasource/file_format/index.html +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html +/// +/// The following diagram shows how DataSource, FileSource, and DataSourceExec are related +/// ```text +/// ┌─────────────────────┐ -----► execute path +/// │ │ ┄┄┄┄┄► init path +/// │ DataSourceExec │ +/// │ │ +/// └───────▲─────────────┘ +/// ┊ │ +/// ┊ │ +/// ┌──────────▼──────────┐ ┌──────────-──────────┐ +/// │ │ | | +/// │ DataSource(trait) │ | TableProvider(trait)| +/// │ │ | | +/// └───────▲─────────────┘ └─────────────────────┘ +/// ┊ │ ┊ +/// ┌───────────────┿──┴────────────────┐ ┊ +/// | ┌┄┄┄┄┄┄┄┄┄┄┄┘ | ┊ +/// | ┊ | ┊ +/// ┌──────────▼──────────┐ ┌──────────▼──────────┐ ┊ +/// │ │ │ │ ┌──────────▼──────────┐ +/// │ FileScanConfig │ │ MemorySourceConfig │ | | +/// │ │ │ │ | FileFormat(trait) | +/// └──────────────▲──────┘ └─────────────────────┘ | | +/// │ ┊ └─────────────────────┘ +/// │ ┊ ┊ +/// │ ┊ ┊ +/// ┌──────────▼──────────┐ ┌──────────▼──────────┐ +/// │ │ │ ArrowSource │ +/// │ FileSource(trait) ◄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄│ ... │ +/// │ │ │ ParquetSource │ +/// └─────────────────────┘ └─────────────────────┘ +/// │ +/// │ +/// │ +/// │ +/// ┌──────────▼──────────┐ +/// │ ArrowSource │ +/// │ ... 
│ +/// │ ParquetSource │ +/// └─────────────────────┘ +/// | +/// FileOpener (called by FileStream) +/// │ +/// ┌──────────▼──────────┐ +/// │ │ +/// │ RecordBatch │ +/// │ │ +/// └─────────────────────┘ +/// ``` pub trait DataSource: Send + Sync + Debug { fn open( &self, partition: usize, context: Arc, - ) -> datafusion_common::Result; + ) -> Result; fn as_any(&self) -> &dyn Any; /// Format this source for display in explain plans fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result; - /// Return a copy of this DataSource with a new partitioning scheme + /// Return a copy of this DataSource with a new partitioning scheme. + /// + /// Returns `Ok(None)` (the default) if the partitioning cannot be changed. + /// Refer to [`ExecutionPlan::repartitioned`] for details on when None should be returned. + /// + /// Repartitioning should not change the output ordering, if this ordering exists. + /// Refer to [`MemorySourceConfig::repartition_preserving_order`](crate::memory::MemorySourceConfig) + /// and the FileSource's + /// [`FileGroupPartitioner::repartition_file_groups`](crate::file_groups::FileGroupPartitioner::repartition_file_groups) + /// for examples. fn repartitioned( &self, _target_partitions: usize, _repartition_file_min_size: usize, _output_ordering: Option, - ) -> datafusion_common::Result>> { + ) -> Result>> { Ok(None) } fn output_partitioning(&self) -> Partitioning; fn eq_properties(&self) -> EquivalenceProperties; - fn statistics(&self) -> datafusion_common::Result; + fn statistics(&self) -> Result; /// Return a copy of this DataSource with a new fetch limit fn with_fetch(&self, _limit: Option) -> Option>; fn fetch(&self) -> Option; @@ -78,17 +153,30 @@ pub trait DataSource: Send + Sync + Debug { fn try_swapping_with_projection( &self, _projection: &ProjectionExec, - ) -> datafusion_common::Result>>; + ) -> Result>>; + /// Try to push down filters into this DataSource. + /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. + /// + /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result + fn try_pushdown_filters( + &self, + filters: Vec>, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::unsupported(filters)) + } } -/// [`ExecutionPlan`] handles different file formats like JSON, CSV, AVRO, ARROW, PARQUET +/// [`ExecutionPlan`] that reads one or more files /// -/// `DataSourceExec` implements common functionality such as applying projections, -/// and caching plan properties. +/// `DataSourceExec` implements common functionality such as applying +/// projections, and caching plan properties. /// -/// The [`DataSource`] trait describes where to find the data for this data -/// source (for example what files or what in memory partitions). Format -/// specifics are implemented with the [`FileSource`] trait. +/// The [`DataSource`] describes where to find the data for this data source +/// (for example in files or what in memory partitions). +/// +/// For file based [`DataSource`]s, format specific behavior is implemented in +/// the [`FileSource`] trait. /// /// [`FileSource`]: crate::file::FileSource #[derive(Clone, Debug)] @@ -131,15 +219,19 @@ impl ExecutionPlan for DataSourceExec { fn with_new_children( self: Arc, _: Vec>, - ) -> datafusion_common::Result> { + ) -> Result> { Ok(self) } + /// Implementation of [`ExecutionPlan::repartitioned`] which relies upon the inner [`DataSource::repartitioned`]. 
+ /// + /// If the data source does not support changing its partitioning, returns `Ok(None)` (the default). Refer + /// to [`ExecutionPlan::repartitioned`] for more details. fn repartitioned( &self, target_partitions: usize, config: &ConfigOptions, - ) -> datafusion_common::Result>> { + ) -> Result>> { let data_source = self.data_source.repartitioned( target_partitions, config.optimizer.repartition_file_min_size, @@ -163,7 +255,7 @@ impl ExecutionPlan for DataSourceExec { &self, partition: usize, context: Arc, - ) -> datafusion_common::Result { + ) -> Result { self.data_source.open(partition, context) } @@ -171,10 +263,28 @@ impl ExecutionPlan for DataSourceExec { Some(self.data_source.metrics().clone_inner()) } - fn statistics(&self) -> datafusion_common::Result { + fn statistics(&self) -> Result { self.data_source.statistics() } + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(partition) = partition { + let mut statistics = Statistics::new_unknown(&self.schema()); + if let Some(file_config) = + self.data_source.as_any().downcast_ref::() + { + if let Some(file_group) = file_config.file_groups.get(partition) { + if let Some(stat) = file_group.file_statistics(None) { + statistics = stat.clone(); + } + } + } + Ok(statistics) + } else { + Ok(self.data_source.statistics()?) + } + } + fn with_fetch(&self, limit: Option) -> Option> { let data_source = self.data_source.with_fetch(limit)?; let cache = self.cache.clone(); @@ -189,9 +299,37 @@ impl ExecutionPlan for DataSourceExec { fn try_swapping_with_projection( &self, projection: &ProjectionExec, - ) -> datafusion_common::Result>> { + ) -> Result>> { self.data_source.try_swapping_with_projection(projection) } + + fn handle_child_pushdown_result( + &self, + child_pushdown_result: ChildPushdownResult, + config: &ConfigOptions, + ) -> Result>> { + // Push any remaining filters into our data source + let res = self.data_source.try_pushdown_filters( + child_pushdown_result.parent_filters.collect_all(), + config, + )?; + match res.updated_node { + Some(data_source) => { + let mut new_node = self.clone(); + new_node.data_source = data_source; + new_node.cache = + Self::compute_properties(Arc::clone(&new_node.data_source)); + Ok(FilterPushdownPropagation { + filters: res.filters, + updated_node: Some(Arc::new(new_node)), + }) + } + None => Ok(FilterPushdownPropagation { + filters: res.filters, + updated_node: None, + }), + } + } } impl DataSourceExec { @@ -254,3 +392,13 @@ impl DataSourceExec { }) } } + +/// Create a new `DataSourceExec` from a `DataSource` +impl From for DataSourceExec +where + S: DataSource + 'static, +{ + fn from(source: S) -> Self { + Self::new(Arc::new(source)) + } +} diff --git a/datafusion/datasource/src/statistics.rs b/datafusion/datasource/src/statistics.rs index 8a04d77b273d..b42d3bb361b7 100644 --- a/datafusion/datasource/src/statistics.rs +++ b/datafusion/datasource/src/statistics.rs @@ -137,7 +137,9 @@ impl MinMaxStatistics { // Reverse the projection to get the index of the column in the full statistics // The file statistics contains _every_ column , but the sort column's index() // refers to the index in projected_schema - let i = projection.map(|p| p[c.index()]).unwrap_or(c.index()); + let i = projection + .map(|p| p[c.index()]) + .unwrap_or_else(|| c.index()); let (min, max) = get_min_max(i).map_err(|e| { e.context(format!("get min/max for column: '{}'", c.name())) @@ -476,7 +478,7 @@ pub fn compute_all_files_statistics( // Then summary statistics across all file groups let 
file_groups_statistics = file_groups_with_stats .iter() - .filter_map(|file_group| file_group.statistics()); + .filter_map(|file_group| file_group.file_statistics(None)); let mut statistics = Statistics::try_merge_iter(file_groups_statistics, &table_schema)?; diff --git a/datafusion/datasource/src/test_util.rs b/datafusion/datasource/src/test_util.rs index 9a9b98d5041b..e4a5114aa073 100644 --- a/datafusion/datasource/src/test_util.rs +++ b/datafusion/datasource/src/test_util.rs @@ -17,12 +17,14 @@ use crate::{ file::FileSource, file_scan_config::FileScanConfig, file_stream::FileOpener, + schema_adapter::SchemaAdapterFactory, }; use std::sync::Arc; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::{Result, Statistics}; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use object_store::ObjectStore; @@ -31,6 +33,7 @@ use object_store::ObjectStore; pub(crate) struct MockSource { metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, } impl FileSource for MockSource { @@ -80,4 +83,23 @@ impl FileSource for MockSource { fn file_type(&self) -> &str { "mock" } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } +} + +/// Create a column expression +pub(crate) fn col(name: &str, schema: &Schema) -> Result> { + Ok(Arc::new(Column::new_with_schema(name, schema)?)) } diff --git a/datafusion/datasource/src/write/demux.rs b/datafusion/datasource/src/write/demux.rs index 49c3a64d24aa..75fb557b63d2 100644 --- a/datafusion/datasource/src/write/demux.rs +++ b/datafusion/datasource/src/write/demux.rs @@ -45,7 +45,7 @@ use datafusion_execution::TaskContext; use chrono::NaiveDate; use futures::StreamExt; use object_store::path::Path; -use rand::distributions::DistString; +use rand::distr::SampleString; use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; type RecordBatchReceiver = Receiver; @@ -151,8 +151,7 @@ async fn row_count_demuxer( let max_buffered_batches = exec_options.max_buffered_batches_per_output_file; let minimum_parallel_files = exec_options.minimum_parallel_output_files; let mut part_idx = 0; - let write_id = - rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16); let mut open_file_streams = Vec::with_capacity(minimum_parallel_files); @@ -225,7 +224,7 @@ fn generate_file_path( if !single_file_output { base_output_path .prefix() - .child(format!("{}_{}.{}", write_id, part_idx, file_extension)) + .child(format!("{write_id}_{part_idx}.{file_extension}")) } else { base_output_path.prefix().to_owned() } @@ -267,8 +266,7 @@ async fn hive_style_partitions_demuxer( file_extension: String, keep_partition_by_columns: bool, ) -> Result<()> { - let write_id = - rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16); let exec_options = &context.session_config().options().execution; let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file; @@ -513,7 +511,7 @@ fn compute_take_arrays( for vals in all_partition_values.iter() { 
part_key.push(vals[i].clone().into()); } - let builder = take_map.entry(part_key).or_insert(UInt64Builder::new()); + let builder = take_map.entry(part_key).or_insert_with(UInt64Builder::new); builder.append_value(i as u64); } take_map @@ -556,5 +554,5 @@ fn compute_hive_style_file_path( file_path = file_path.child(format!("{}={}", partition_by[j].0, part_key[j])); } - file_path.child(format!("{}.{}", write_id, file_extension)) + file_path.child(format!("{write_id}.{file_extension}")) } diff --git a/datafusion/datasource/src/write/mod.rs b/datafusion/datasource/src/write/mod.rs index f581126095a7..3694568682a5 100644 --- a/datafusion/datasource/src/write/mod.rs +++ b/datafusion/datasource/src/write/mod.rs @@ -77,15 +77,18 @@ pub trait BatchSerializer: Sync + Send { /// Returns an [`AsyncWrite`] which writes to the given object store location /// with the specified compression. +/// +/// The writer will have a default buffer size as chosen by [`BufWriter::new`]. +/// /// We drop the `AbortableWrite` struct and the writer will not try to cleanup on failure. /// Users can configure automatic cleanup with their cloud provider. +#[deprecated(since = "48.0.0", note = "Use ObjectWriterBuilder::new(...) instead")] pub async fn create_writer( file_compression_type: FileCompressionType, location: &Path, object_store: Arc, ) -> Result> { - let buf_writer = BufWriter::new(object_store, location.clone()); - file_compression_type.convert_async_writer(buf_writer) + ObjectWriterBuilder::new(file_compression_type, location, object_store).build() } /// Converts table schema to writer schema, which may differ in the case @@ -109,3 +112,108 @@ pub fn get_writer_schema(config: &FileSinkConfig) -> Arc { Arc::clone(config.output_schema()) } } + +/// A builder for an [`AsyncWrite`] that writes to an object store location. +/// +/// This can be used to specify file compression on the writer. The writer +/// will have a default buffer size unless altered. The specific default size +/// is chosen by [`BufWriter::new`]. +/// +/// We drop the `AbortableWrite` struct and the writer will not try to cleanup on failure. +/// Users can configure automatic cleanup with their cloud provider. +#[derive(Debug)] +pub struct ObjectWriterBuilder { + /// Compression type for object writer. + file_compression_type: FileCompressionType, + /// Output path + location: Path, + /// The related store that handles the given path + object_store: Arc, + /// The size of the buffer for the object writer. + buffer_size: Option, +} + +impl ObjectWriterBuilder { + /// Create a new [`ObjectWriterBuilder`] for the specified path and compression type. + pub fn new( + file_compression_type: FileCompressionType, + location: &Path, + object_store: Arc, + ) -> Self { + Self { + file_compression_type, + location: location.clone(), + object_store, + buffer_size: None, + } + } + + /// Set buffer size in bytes for object writer. 
+ /// + /// # Example + /// ``` + /// # use datafusion_datasource::file_compression_type::FileCompressionType; + /// # use datafusion_datasource::write::ObjectWriterBuilder; + /// # use object_store::memory::InMemory; + /// # use object_store::path::Path; + /// # use std::sync::Arc; + /// # let compression_type = FileCompressionType::UNCOMPRESSED; + /// # let location = Path::from("/foo/bar"); + /// # let object_store = Arc::new(InMemory::new()); + /// let mut builder = ObjectWriterBuilder::new(compression_type, &location, object_store); + /// builder.set_buffer_size(Some(20 * 1024 * 1024)); //20 MiB + /// assert_eq!(builder.get_buffer_size(), Some(20 * 1024 * 1024), "Internal error: Builder buffer size doesn't match"); + /// ``` + pub fn set_buffer_size(&mut self, buffer_size: Option<usize>) { + self.buffer_size = buffer_size; + } + + /// Set buffer size in bytes for object writer, returning the builder. + /// + /// # Example + /// ``` + /// # use datafusion_datasource::file_compression_type::FileCompressionType; + /// # use datafusion_datasource::write::ObjectWriterBuilder; + /// # use object_store::memory::InMemory; + /// # use object_store::path::Path; + /// # use std::sync::Arc; + /// # let compression_type = FileCompressionType::UNCOMPRESSED; + /// # let location = Path::from("/foo/bar"); + /// # let object_store = Arc::new(InMemory::new()); + /// let builder = ObjectWriterBuilder::new(compression_type, &location, object_store) + /// .with_buffer_size(Some(20 * 1024 * 1024)); //20 MiB + /// assert_eq!(builder.get_buffer_size(), Some(20 * 1024 * 1024), "Internal error: Builder buffer size doesn't match"); + /// ``` + pub fn with_buffer_size(mut self, buffer_size: Option<usize>) -> Self { + self.buffer_size = buffer_size; + self + } + + /// Currently specified buffer size in bytes. + pub fn get_buffer_size(&self) -> Option<usize> { + self.buffer_size + } + + /// Return a writer object that writes to the object store location. + /// + /// If a buffer size has not been set, the default buffer size will + /// be used. + /// + /// # Errors + /// If there is an error applying the compression type.
+ pub fn build(self) -> Result> { + let Self { + file_compression_type, + location, + object_store, + buffer_size, + } = self; + + let buf_writer = match buffer_size { + Some(size) => BufWriter::with_capacity(object_store, location, size), + None => BufWriter::new(object_store, location), + }; + + file_compression_type.convert_async_writer(buf_writer) + } +} diff --git a/datafusion/datasource/src/write/orchestration.rs b/datafusion/datasource/src/write/orchestration.rs index 0ac1d26c6cc1..a09509ac5862 100644 --- a/datafusion/datasource/src/write/orchestration.rs +++ b/datafusion/datasource/src/write/orchestration.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use super::demux::DemuxedStreamReceiver; -use super::{create_writer, BatchSerializer}; +use super::{BatchSerializer, ObjectWriterBuilder}; use crate::file_compression_type::FileCompressionType; use datafusion_common::error::Result; @@ -257,7 +257,15 @@ pub async fn spawn_writer_tasks_and_join( }); while let Some((location, rb_stream)) = file_stream_rx.recv().await { let writer = - create_writer(compression, &location, Arc::clone(&object_store)).await?; + ObjectWriterBuilder::new(compression, &location, Arc::clone(&object_store)) + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; if tx_file_bundle .send((rb_stream, Arc::clone(&serializer), writer)) diff --git a/datafusion/doc/src/lib.rs b/datafusion/doc/src/lib.rs index 68ed1e2352ca..f9b916c2b3ab 100644 --- a/datafusion/doc/src/lib.rs +++ b/datafusion/doc/src/lib.rs @@ -93,7 +93,7 @@ impl Documentation { self.doc_section.label, self.doc_section .description - .map(|s| format!(", description = \"{}\"", s)) + .map(|s| format!(", description = \"{s}\"")) .unwrap_or_default(), ) .as_ref(), @@ -110,7 +110,7 @@ impl Documentation { &self .sql_example .clone() - .map(|s| format!("\n sql_example = r#\"{}\"#,", s)) + .map(|s| format!("\n sql_example = r#\"{s}\"#,")) .unwrap_or_default(), ); @@ -120,7 +120,7 @@ impl Documentation { args.iter().for_each(|(name, value)| { if value.contains(st_arg_token) { if name.starts_with("The ") { - result.push_str(format!("\n standard_argument(\n name = \"{}\"),", name).as_ref()); + result.push_str(format!("\n standard_argument(\n name = \"{name}\"),").as_ref()); } else { result.push_str(format!("\n standard_argument(\n name = \"{}\",\n prefix = \"{}\"\n ),", name, value.replace(st_arg_token, "")).as_ref()); } @@ -132,7 +132,7 @@ impl Documentation { if let Some(args) = self.arguments.clone() { args.iter().for_each(|(name, value)| { if !value.contains(st_arg_token) { - result.push_str(format!("\n argument(\n name = \"{}\",\n description = \"{}\"\n ),", name, value).as_ref()); + result.push_str(format!("\n argument(\n name = \"{name}\",\n description = \"{value}\"\n ),").as_ref()); } }); } @@ -140,7 +140,7 @@ impl Documentation { if let Some(alt_syntax) = self.alternative_syntax.clone() { alt_syntax.iter().for_each(|syntax| { result.push_str( - format!("\n alternative_syntax = \"{}\",", syntax).as_ref(), + format!("\n alternative_syntax = \"{syntax}\",").as_ref(), ); }); } @@ -148,8 +148,7 @@ impl Documentation { // Related UDFs if let Some(related_udf) = self.related_udfs.clone() { related_udf.iter().for_each(|udf| { - result - .push_str(format!("\n related_udf(name = \"{}\"),", udf).as_ref()); + result.push_str(format!("\n related_udf(name = \"{udf}\"),").as_ref()); }); } diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index 20e507e98b68..5988d3a33660 
100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -52,3 +52,4 @@ url = { workspace = true } [dev-dependencies] chrono = { workspace = true } +insta = { workspace = true } diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 53646dc5b468..1e00a1ce4725 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -193,9 +193,11 @@ impl SessionConfig { /// /// [`target_partitions`]: datafusion_common::config::ExecutionOptions::target_partitions pub fn with_target_partitions(mut self, n: usize) -> Self { - // partition count must be greater than zero - assert!(n > 0); - self.options.execution.target_partitions = n; + self.options.execution.target_partitions = if n == 0 { + datafusion_common::config::ExecutionOptions::default().target_partitions + } else { + n + }; self } diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index 2b21a6dbf175..1810601fd362 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -22,7 +22,7 @@ use datafusion_common::{ }; use log::debug; use parking_lot::Mutex; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; @@ -32,7 +32,95 @@ use crate::memory_pool::human_readable_size; const DEFAULT_MAX_TEMP_DIRECTORY_SIZE: u64 = 100 * 1024 * 1024 * 1024; // 100GB +/// Builder pattern for the [DiskManager] structure +#[derive(Clone, Debug)] +pub struct DiskManagerBuilder { + /// The storage mode of the disk manager + mode: DiskManagerMode, + /// The maximum amount of data (in bytes) stored inside the temporary directories. + /// Default to 100GB + max_temp_directory_size: u64, +} + +impl Default for DiskManagerBuilder { + fn default() -> Self { + Self { + mode: DiskManagerMode::OsTmpDirectory, + max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, + } + } +} + +impl DiskManagerBuilder { + pub fn set_mode(&mut self, mode: DiskManagerMode) { + self.mode = mode; + } + + pub fn with_mode(mut self, mode: DiskManagerMode) -> Self { + self.set_mode(mode); + self + } + + pub fn set_max_temp_directory_size(&mut self, value: u64) { + self.max_temp_directory_size = value; + } + + pub fn with_max_temp_directory_size(mut self, value: u64) -> Self { + self.set_max_temp_directory_size(value); + self + } + + /// Create a DiskManager given the builder + pub fn build(self) -> Result { + match self.mode { + DiskManagerMode::OsTmpDirectory => Ok(DiskManager { + local_dirs: Mutex::new(Some(vec![])), + max_temp_directory_size: self.max_temp_directory_size, + used_disk_space: Arc::new(AtomicU64::new(0)), + }), + DiskManagerMode::Directories(conf_dirs) => { + let local_dirs = create_local_dirs(conf_dirs)?; + debug!( + "Created local dirs {local_dirs:?} as DataFusion working directory" + ); + Ok(DiskManager { + local_dirs: Mutex::new(Some(local_dirs)), + max_temp_directory_size: self.max_temp_directory_size, + used_disk_space: Arc::new(AtomicU64::new(0)), + }) + } + DiskManagerMode::Disabled => Ok(DiskManager { + local_dirs: Mutex::new(None), + max_temp_directory_size: self.max_temp_directory_size, + used_disk_space: Arc::new(AtomicU64::new(0)), + }), + } + } +} + +#[derive(Clone, Debug)] +pub enum DiskManagerMode { + /// Create a new [DiskManager] that creates temporary files within + /// a temporary directory chosen by the OS + OsTmpDirectory, + + /// Create a new [DiskManager] that creates temporary 
files within + /// the specified directories. One of the directories will be chosen + /// at random for each temporary file created. + Directories(Vec), + + /// Disable disk manager, attempts to create temporary files will error + Disabled, +} + +impl Default for DiskManagerMode { + fn default() -> Self { + Self::OsTmpDirectory + } +} + /// Configuration for temporary disk access +#[deprecated(since = "48.0.0", note = "Use DiskManagerBuilder instead")] #[derive(Debug, Clone)] pub enum DiskManagerConfig { /// Use the provided [DiskManager] instance @@ -50,12 +138,14 @@ pub enum DiskManagerConfig { Disabled, } +#[allow(deprecated)] impl Default for DiskManagerConfig { fn default() -> Self { Self::NewOs } } +#[allow(deprecated)] impl DiskManagerConfig { /// Create temporary files in a temporary directory chosen by the OS pub fn new() -> Self { @@ -91,7 +181,14 @@ pub struct DiskManager { } impl DiskManager { + /// Creates a builder for [DiskManager] + pub fn builder() -> DiskManagerBuilder { + DiskManagerBuilder::default() + } + /// Create a DiskManager given the configuration + #[allow(deprecated)] + #[deprecated(since = "48.0.0", note = "Use DiskManager::builder() instead")] pub fn try_new(config: DiskManagerConfig) -> Result> { match config { DiskManagerConfig::Existing(manager) => Ok(manager), @@ -103,8 +200,7 @@ impl DiskManager { DiskManagerConfig::NewSpecified(conf_dirs) => { let local_dirs = create_local_dirs(conf_dirs)?; debug!( - "Created local dirs {:?} as DataFusion working directory", - local_dirs + "Created local dirs {local_dirs:?} as DataFusion working directory" ); Ok(Arc::new(Self { local_dirs: Mutex::new(Some(local_dirs)), @@ -120,10 +216,10 @@ impl DiskManager { } } - pub fn with_max_temp_directory_size( - mut self, + pub fn set_max_temp_directory_size( + &mut self, max_temp_directory_size: u64, - ) -> Result { + ) -> Result<()> { // If the disk manager is disabled and `max_temp_directory_size` is not 0, // this operation is not meaningful, fail early. 
if self.local_dirs.lock().is_none() && max_temp_directory_size != 0 { @@ -133,6 +229,26 @@ impl DiskManager { } self.max_temp_directory_size = max_temp_directory_size; + Ok(()) + } + + pub fn set_arc_max_temp_directory_size( + this: &mut Arc, + max_temp_directory_size: u64, + ) -> Result<()> { + if let Some(inner) = Arc::get_mut(this) { + inner.set_max_temp_directory_size(max_temp_directory_size)?; + Ok(()) + } else { + config_err!("DiskManager should be a single instance") + } + } + + pub fn with_max_temp_directory_size( + mut self, + max_temp_directory_size: u64, + ) -> Result { + self.set_max_temp_directory_size(max_temp_directory_size)?; Ok(self) } @@ -175,7 +291,7 @@ impl DiskManager { local_dirs.push(Arc::new(tempdir)); } - let dir_index = thread_rng().gen_range(0..local_dirs.len()); + let dir_index = rng().random_range(0..local_dirs.len()); Ok(RefCountedTempFile { _parent_temp_dir: Arc::clone(&local_dirs[dir_index]), tempfile: Builder::new() @@ -286,8 +402,7 @@ mod tests { #[test] fn lazy_temp_dir_creation() -> Result<()> { // A default configuration should not create temp files until requested - let config = DiskManagerConfig::new(); - let dm = DiskManager::try_new(config)?; + let dm = Arc::new(DiskManagerBuilder::default().build()?); assert_eq!(0, local_dir_snapshot(&dm).len()); @@ -319,11 +434,14 @@ mod tests { let local_dir2 = TempDir::new()?; let local_dir3 = TempDir::new()?; let local_dirs = vec![local_dir1.path(), local_dir2.path(), local_dir3.path()]; - let config = DiskManagerConfig::new_specified( - local_dirs.iter().map(|p| p.into()).collect(), + let dm = Arc::new( + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Directories( + local_dirs.iter().map(|p| p.into()).collect(), + )) + .build()?, ); - let dm = DiskManager::try_new(config)?; assert!(dm.tmp_files_enabled()); let actual = dm.create_tmp_file("Testing")?; @@ -335,8 +453,12 @@ mod tests { #[test] fn test_disabled_disk_manager() { - let config = DiskManagerConfig::Disabled; - let manager = DiskManager::try_new(config).unwrap(); + let manager = Arc::new( + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Disabled) + .build() + .unwrap(), + ); assert!(!manager.tmp_files_enabled()); assert_eq!( manager.create_tmp_file("Testing").unwrap_err().strip_backtrace(), @@ -347,11 +469,9 @@ mod tests { #[test] fn test_disk_manager_create_spill_folder() { let dir = TempDir::new().unwrap(); - let config = DiskManagerConfig::new_specified(vec![dir.path().to_owned()]); - - DiskManager::try_new(config) - .unwrap() - .create_tmp_file("Testing") + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Directories(vec![dir.path().to_path_buf()])) + .build() .unwrap(); } @@ -374,8 +494,7 @@ mod tests { #[test] fn test_temp_file_still_alive_after_disk_manager_dropped() -> Result<()> { // Test for the case using OS arranged temporary directory - let config = DiskManagerConfig::new(); - let dm = DiskManager::try_new(config)?; + let dm = Arc::new(DiskManagerBuilder::default().build()?); let temp_file = dm.create_tmp_file("Testing")?; let temp_file_path = temp_file.path().to_owned(); assert!(temp_file_path.exists()); @@ -391,10 +510,13 @@ mod tests { let local_dir2 = TempDir::new()?; let local_dir3 = TempDir::new()?; let local_dirs = [local_dir1.path(), local_dir2.path(), local_dir3.path()]; - let config = DiskManagerConfig::new_specified( - local_dirs.iter().map(|p| p.into()).collect(), + let dm = Arc::new( + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Directories( + local_dirs.iter().map(|p| 
p.into()).collect(), + )) + .build()?, ); - let dm = DiskManager::try_new(config)?; let temp_file = dm.create_tmp_file("Testing")?; let temp_file_path = temp_file.path().to_owned(); assert!(temp_file_path.exists()); diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 625a779b3eea..19e509d263ea 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -141,6 +141,25 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug { /// Return the total amount of memory reserved fn reserved(&self) -> usize; + + /// Return the memory limit of the pool + /// + /// The default implementation of `MemoryPool::memory_limit` + /// returns `MemoryLimit::Unknown`. + /// If you are using a custom memory pool that enforces a known limit, + /// implement this method to report that limit + /// as `MemoryLimit::Finite(limit)`. + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Unknown + } +} + +/// Memory limit of a `MemoryPool` +pub enum MemoryLimit { + Infinite, + /// Bounded memory limit in bytes. + Finite(usize), + Unknown, } /// A memory consumer is a named allocation traced by a particular diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index cd6863939d27..11467f69be1c 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; +use crate::memory_pool::{ + human_readable_size, MemoryConsumer, MemoryLimit, MemoryPool, MemoryReservation, +}; use datafusion_common::HashMap; use datafusion_common::{resources_datafusion_err, DataFusionError, Result}; use log::debug; @@ -48,6 +50,10 @@ impl MemoryPool for UnboundedMemoryPool { fn reserved(&self) -> usize { self.used.load(Ordering::Relaxed) } + + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Infinite + } } /// A [`MemoryPool`] that implements a greedy first-come first-serve limit. @@ -100,6 +106,10 @@ impl MemoryPool for GreedyMemoryPool { fn reserved(&self) -> usize { self.used.load(Ordering::Relaxed) } + + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Finite(self.pool_size) + } } /// A [`MemoryPool`] that prevents spillable reservations from using more than @@ -233,6 +243,10 @@ impl MemoryPool for FairSpillPool { let state = self.state.lock(); state.spillable + state.unspillable } + + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Finite(self.pool_size) + } } /// Constructs a resources error based upon the individual [`MemoryReservation`].
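For context on how a third-party pool would adopt the new `memory_limit` API, here is a minimal, hypothetical sketch. The `FixedBudgetPool` type is illustrative and not part of this change; it assumes the `MemoryPool` trait's other required methods (`grow`, `shrink`, `try_grow`, `reserved`) as they exist upstream, and uses the same `resources_datafusion_err!` helper imported in the hunks above:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

use datafusion_common::{resources_datafusion_err, Result};
use datafusion_execution::memory_pool::{MemoryLimit, MemoryPool, MemoryReservation};

/// Hypothetical pool with a fixed byte budget, shown only to illustrate `memory_limit`.
#[derive(Debug)]
struct FixedBudgetPool {
    limit: usize,
    used: AtomicUsize,
}

impl MemoryPool for FixedBudgetPool {
    fn grow(&self, _reservation: &MemoryReservation, additional: usize) {
        self.used.fetch_add(additional, Ordering::Relaxed);
    }

    fn shrink(&self, _reservation: &MemoryReservation, shrink: usize) {
        self.used.fetch_sub(shrink, Ordering::Relaxed);
    }

    fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> {
        // Reject allocations that would exceed the fixed budget.
        if self.reserved() + additional > self.limit {
            return Err(resources_datafusion_err!(
                "FixedBudgetPool of {} bytes exhausted",
                self.limit
            ));
        }
        self.grow(reservation, additional);
        Ok(())
    }

    fn reserved(&self) -> usize {
        self.used.load(Ordering::Relaxed)
    }

    // The new hook: report the bounded budget so callers can inspect it.
    fn memory_limit(&self) -> MemoryLimit {
        MemoryLimit::Finite(self.limit)
    }
}
```

The built-in pools in the hunks above follow the same pattern: `UnboundedMemoryPool` reports `Infinite`, while `GreedyMemoryPool` and `FairSpillPool` report `Finite(pool_size)`.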
@@ -246,7 +260,8 @@ fn insufficient_capacity_err( additional: usize, available: usize, ) -> DataFusionError { - resources_datafusion_err!("Failed to allocate additional {} bytes for {} with {} bytes already allocated for this reservation - {} bytes remain available for the total pool", additional, reservation.registration.consumer.name, reservation.size, available) + resources_datafusion_err!("Failed to allocate additional {} for {} with {} already allocated for this reservation - {} remain available for the total pool", + human_readable_size(additional), reservation.registration.consumer.name, human_readable_size(reservation.size), human_readable_size(available)) } #[derive(Debug)] @@ -328,10 +343,14 @@ impl TrackConsumersPool { consumers[0..std::cmp::min(top, consumers.len())] .iter() .map(|((id, name, can_spill), size)| { - format!("{name}#{id}(can spill: {can_spill}) consumed {size} bytes") + format!( + " {name}#{id}(can spill: {can_spill}) consumed {}", + human_readable_size(*size) + ) }) .collect::>() - .join(", ") + .join(",\n") + + "." } } @@ -408,20 +427,34 @@ impl MemoryPool for TrackConsumersPool { fn reserved(&self) -> usize { self.inner.reserved() } + + fn memory_limit(&self) -> MemoryLimit { + self.inner.memory_limit() + } } fn provide_top_memory_consumers_to_error_msg( error_msg: String, top_consumers: String, ) -> String { - format!("Additional allocation failed with top memory consumers (across reservations) as: {}. Error: {}", top_consumers, error_msg) + format!("Additional allocation failed with top memory consumers (across reservations) as:\n{top_consumers}\nError: {error_msg}") } #[cfg(test)] mod tests { use super::*; + use insta::{allow_duplicates, assert_snapshot, Settings}; use std::sync::Arc; + fn make_settings() -> Settings { + let mut settings = Settings::clone_current(); + settings.add_filter( + r"([^\s]+)\#\d+\(can spill: (true|false)\)", + "$1#[ID](can spill: $2)", + ); + settings + } + #[test] fn test_fair() { let pool = Arc::new(FairSpillPool::new(100)) as _; @@ -440,10 +473,10 @@ mod tests { assert_eq!(pool.reserved(), 4000); let err = r2.try_grow(1).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated for this reservation - 0 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 1.0 B for r2 with 2000.0 B already allocated for this reservation - 0.0 B remain available for the total pool"); let err = r2.try_grow(1).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated for this reservation - 0 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 1.0 B for r2 with 2000.0 B already allocated for this reservation - 0.0 B remain available for the total pool"); r1.shrink(1990); r2.shrink(2000); @@ -468,12 +501,12 @@ mod tests { .register(&pool); let err = r3.try_grow(70).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated for this reservation - 40 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 70.0 B for r3 with 0.0 B already allocated for this reservation - 40.0 B remain available for the total pool"); //Shrinking r2 to zero doesn't allow a3 to allocate more than 45 
r2.free(); let err = r3.try_grow(70).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated for this reservation - 40 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 70.0 B for r3 with 0.0 B already allocated for this reservation - 40.0 B remain available for the total pool"); // But dropping r2 does drop(r2); @@ -486,11 +519,13 @@ mod tests { let mut r4 = MemoryConsumer::new("s4").register(&pool); let err = r4.try_grow(30).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 30 bytes for s4 with 0 bytes already allocated for this reservation - 20 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 30.0 B for s4 with 0.0 B already allocated for this reservation - 20.0 B remain available for the total pool"); } #[test] fn test_tracked_consumers_pool() { + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let pool: Arc = Arc::new(TrackConsumersPool::new( GreedyMemoryPool::new(100), NonZeroUsize::new(3).unwrap(), @@ -523,20 +558,22 @@ mod tests { // Test: reports if new reservation causes error // using the previously set sizes for other consumers let mut r5 = MemoryConsumer::new("r5").register(&pool); - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: r1#{}(can spill: false) consumed 50 bytes, r3#{}(can spill: false) consumed 20 bytes, r2#{}(can spill: false) consumed 15 bytes. Error: Failed to allocate additional 150 bytes for r5 with 0 bytes already allocated for this reservation - 5 bytes remain available for the total pool", r1.consumer().id(), r3.consumer().id(), r2.consumer().id()); let res = r5.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide list of top memory consumers, instead found {:?}", - res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + r1#[ID](can spill: false) consumed 50.0 B, + r3#[ID](can spill: false) consumed 20.0 B, + r2#[ID](can spill: false) consumed 15.0 B. + Error: Failed to allocate additional 150.0 B for r5 with 0.0 B already allocated for this reservation - 5.0 B remain available for the total pool + "); } #[test] fn test_tracked_consumers_pool_register() { + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let pool: Arc = Arc::new(TrackConsumersPool::new( GreedyMemoryPool::new(100), NonZeroUsize::new(3).unwrap(), @@ -546,15 +583,14 @@ mod tests { // Test: see error message when no consumers recorded yet let mut r0 = MemoryConsumer::new(same_name).register(&pool); - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: foo#{}(can spill: false) consumed 0 bytes. 
Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 100 bytes remain available for the total pool", r0.consumer().id()); let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide proper error when no reservations have been made yet, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 0.0 B. + Error: Failed to allocate additional 150.0 B for foo with 0.0 B already allocated for this reservation - 100.0 B remain available for the total pool + "); // API: multiple registrations using the same hashed consumer, // will be recognized *differently* in the TrackConsumersPool. @@ -564,102 +600,101 @@ mod tests { let mut r1 = new_consumer_same_name.register(&pool); // TODO: the insufficient_capacity_err() message is per reservation, not per consumer. // a followup PR will clarify this message "0 bytes already allocated for this reservation" - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: foo#{}(can spill: false) consumed 10 bytes, foo#{}(can spill: false) consumed 0 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 90 bytes remain available for the total pool", r0.consumer().id(), r1.consumer().id()); let res = r1.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide proper error for 2 consumers, instead found {:?}", - res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 10.0 B, + foo#[ID](can spill: false) consumed 0.0 B. + Error: Failed to allocate additional 150.0 B for foo with 0.0 B already allocated for this reservation - 90.0 B remain available for the total pool + "); // Test: will accumulate size changes per consumer, not per reservation r1.grow(20); - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: foo#{}(can spill: false) consumed 20 bytes, foo#{}(can spill: false) consumed 10 bytes. Error: Failed to allocate additional 150 bytes for foo with 20 bytes already allocated for this reservation - 70 bytes remain available for the total pool", r1.consumer().id(), r0.consumer().id()); + let res = r1.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide proper error for 2 consumers(one foo=20 bytes, another foo=10 bytes, available=70), instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 20.0 B, + foo#[ID](can spill: false) consumed 10.0 B. 
+ Error: Failed to allocate additional 150.0 B for foo with 20.0 B already allocated for this reservation - 70.0 B remain available for the total pool + "); // Test: different hashed consumer, (even with the same name), // will be recognized as different in the TrackConsumersPool let consumer_with_same_name_but_different_hash = MemoryConsumer::new(same_name).with_can_spill(true); let mut r2 = consumer_with_same_name_but_different_hash.register(&pool); - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: foo#{}(can spill: false) consumed 20 bytes, foo#{}(can spill: false) consumed 10 bytes, foo#{}(can spill: true) consumed 0 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 70 bytes remain available for the total pool", r1.consumer().id(), r0.consumer().id(), r2.consumer().id()); let res = r2.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide proper error with 3 separate consumers(1 = 20 bytes, 2 = 10 bytes, 3 = 0 bytes), instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 20.0 B, + foo#[ID](can spill: false) consumed 10.0 B, + foo#[ID](can spill: true) consumed 0.0 B. + Error: Failed to allocate additional 150.0 B for foo with 0.0 B already allocated for this reservation - 70.0 B remain available for the total pool + "); } #[test] fn test_tracked_consumers_pool_deregister() { fn test_per_pool_type(pool: Arc) { // Baseline: see the 2 memory consumers + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let mut r0 = MemoryConsumer::new("r0").register(&pool); r0.grow(10); let r1_consumer = MemoryConsumer::new("r1"); let mut r1 = r1_consumer.register(&pool); r1.grow(20); - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: r1#{}(can spill: false) consumed 20 bytes, r0#{}(can spill: false) consumed 10 bytes. Error: Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 70 bytes remain available for the total pool", r1.consumer().id(), r0.consumer().id()); let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide proper error with both consumers, instead found {:?}", - res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + r1#[ID](can spill: false) consumed 20.0 B, + r0#[ID](can spill: false) consumed 10.0 B. 
+ Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 70.0 B remain available for the total pool + ")); // Test: unregister one // only the remaining one should be listed drop(r1); - let expected_consumers = format!("Additional allocation failed with top memory consumers (across reservations) as: r0#{}(can spill: false) consumed 10 bytes", r0.consumer().id()); let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected_consumers) - ), - "should provide proper error with only 1 consumer left registered, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + r0#[ID](can spill: false) consumed 10.0 B. + Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 90.0 B remain available for the total pool + ")); // Test: actual message we see is the `available is 70`. When it should be `available is 90`. // This is because the pool.shrink() does not automatically occur within the inner_pool.deregister(). - let expected_90_available = "Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 90 bytes remain available for the total pool"; let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_90_available) - ), - "should find that the inner pool will still count all bytes for the deregistered consumer until the reservation is dropped, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + r0#[ID](can spill: false) consumed 10.0 B. + Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 90.0 B remain available for the total pool + ")); // Test: the registration needs to free itself (or be dropped), // for the proper error message - let expected_90_available = "Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 90 bytes remain available for the total pool"; let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_90_available) - ), - "should correctly account the total bytes after reservation is free, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + r0#[ID](can spill: false) consumed 10.0 B. 
+ Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 90.0 B remain available for the total pool + ")); } let tracked_spill_pool: Arc = Arc::new(TrackConsumersPool::new( @@ -677,6 +712,8 @@ mod tests { #[test] fn test_tracked_consumers_pool_use_beyond_errors() { + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let upcasted: Arc = Arc::new(TrackConsumersPool::new( GreedyMemoryPool::new(100), @@ -700,12 +737,10 @@ mod tests { .unwrap(); // Test: can get runtime metrics, even without an error thrown - let expected = format!("r3#{}(can spill: false) consumed 45 bytes, r1#{}(can spill: false) consumed 20 bytes", r3.consumer().id(), r1.consumer().id()); let res = downcasted.report_top(2); - assert_eq!( - res, expected, - "should provide list of top memory consumers, instead found {:?}", - res - ); + assert_snapshot!(res, @r" + r3#[ID](can spill: false) consumed 45.0 B, + r1#[ID](can spill: false) consumed 20.0 B. + "); } } diff --git a/datafusion/execution/src/runtime_env.rs b/datafusion/execution/src/runtime_env.rs index 95f14f485792..b086430a4ef7 100644 --- a/datafusion/execution/src/runtime_env.rs +++ b/datafusion/execution/src/runtime_env.rs @@ -18,8 +18,10 @@ //! Execution [`RuntimeEnv`] environment that manages access to object //! store, memory manager, disk manager. +#[allow(deprecated)] +use crate::disk_manager::DiskManagerConfig; use crate::{ - disk_manager::{DiskManager, DiskManagerConfig}, + disk_manager::{DiskManager, DiskManagerBuilder, DiskManagerMode}, memory_pool::{ GreedyMemoryPool, MemoryPool, TrackConsumersPool, UnboundedMemoryPool, }, @@ -27,7 +29,7 @@ use crate::{ }; use crate::cache::cache_manager::{CacheManager, CacheManagerConfig}; -use datafusion_common::Result; +use datafusion_common::{config::ConfigEntry, Result}; use object_store::ObjectStore; use std::path::PathBuf; use std::sync::Arc; @@ -170,8 +172,11 @@ pub type RuntimeConfig = RuntimeEnvBuilder; /// /// See example on [`RuntimeEnv`] pub struct RuntimeEnvBuilder { + #[allow(deprecated)] /// DiskManager to manage temporary disk file usage pub disk_manager: DiskManagerConfig, + /// DiskManager builder to manage temporary disk file usage + pub disk_manager_builder: Option, /// [`MemoryPool`] from which to allocate memory /// /// Defaults to using an [`UnboundedMemoryPool`] if `None` @@ -193,18 +198,27 @@ impl RuntimeEnvBuilder { pub fn new() -> Self { Self { disk_manager: Default::default(), + disk_manager_builder: Default::default(), memory_pool: Default::default(), cache_manager: Default::default(), object_store_registry: Arc::new(DefaultObjectStoreRegistry::default()), } } + #[allow(deprecated)] + #[deprecated(since = "48.0.0", note = "Use with_disk_manager_builder instead")] /// Customize disk manager pub fn with_disk_manager(mut self, disk_manager: DiskManagerConfig) -> Self { self.disk_manager = disk_manager; self } + /// Customize the disk manager builder + pub fn with_disk_manager_builder(mut self, disk_manager: DiskManagerBuilder) -> Self { + self.disk_manager_builder = Some(disk_manager); + self + } + /// Customize memory policy pub fn with_memory_pool(mut self, memory_pool: Arc) -> Self { self.memory_pool = Some(memory_pool); @@ -242,13 +256,17 @@ impl RuntimeEnvBuilder { /// Use the specified path to create any needed temporary files pub fn with_temp_file_path(self, path: impl Into) -> Self { - self.with_disk_manager(DiskManagerConfig::new_specified(vec![path.into()])) + self.with_disk_manager_builder( +
DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Directories(vec![path.into()])), + ) } /// Build a RuntimeEnv pub fn build(self) -> Result { let Self { disk_manager, + disk_manager_builder, memory_pool, cache_manager, object_store_registry, @@ -258,7 +276,12 @@ impl RuntimeEnvBuilder { Ok(RuntimeEnv { memory_pool, - disk_manager: DiskManager::try_new(disk_manager)?, + disk_manager: if let Some(builder) = disk_manager_builder { + Arc::new(builder.build()?) + } else { + #[allow(deprecated)] + DiskManager::try_new(disk_manager)? + }, cache_manager: CacheManager::try_new(&cache_manager)?, object_store_registry, }) @@ -268,4 +291,58 @@ impl RuntimeEnvBuilder { pub fn build_arc(self) -> Result> { self.build().map(Arc::new) } + + /// Create a new RuntimeEnvBuilder from an existing RuntimeEnv + pub fn from_runtime_env(runtime_env: &RuntimeEnv) -> Self { + let cache_config = CacheManagerConfig { + table_files_statistics_cache: runtime_env + .cache_manager + .get_file_statistic_cache(), + list_files_cache: runtime_env.cache_manager.get_list_files_cache(), + }; + + Self { + #[allow(deprecated)] + disk_manager: DiskManagerConfig::Existing(Arc::clone( + &runtime_env.disk_manager, + )), + disk_manager_builder: None, + memory_pool: Some(Arc::clone(&runtime_env.memory_pool)), + cache_manager: cache_config, + object_store_registry: Arc::clone(&runtime_env.object_store_registry), + } + } + + /// Returns a list of all available runtime configurations with their current values and descriptions + pub fn entries(&self) -> Vec { + // Memory pool configuration + vec![ConfigEntry { + key: "datafusion.runtime.memory_limit".to_string(), + value: None, // Default is system-dependent + description: "Maximum memory limit for query execution. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes.", + }] + } + + /// Generate documentation that can be included in the user guide + pub fn generate_config_markdown() -> String { + use std::fmt::Write as _; + + let s = Self::default(); + + let mut docs = "| key | default | description |\n".to_string(); + docs += "|-----|---------|-------------|\n"; + let mut entries = s.entries(); + entries.sort_unstable_by(|a, b| a.key.cmp(&b.key)); + + for entry in &entries { + let _ = writeln!( + &mut docs, + "| {} | {} | {} |", + entry.key, + entry.value.as_deref().unwrap_or("NULL"), + entry.description + ); + } + docs + } } diff --git a/datafusion/expr-common/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs index 3a63c3289481..2829a9416f03 100644 --- a/datafusion/expr-common/src/accumulator.rs +++ b/datafusion/expr-common/src/accumulator.rs @@ -42,7 +42,6 @@ use std::fmt::Debug; /// [`state`] and combine the state from multiple accumulators /// via [`merge_batch`], as part of efficient multi-phase grouping. 
/// -/// [`GroupsAccumulator`]: crate::GroupsAccumulator /// [`update_batch`]: Self::update_batch /// [`retract_batch`]: Self::retract_batch /// [`state`]: Self::state diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index cb7cbdbac291..a21ad5bbbcc3 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -237,7 +237,7 @@ impl fmt::Display for ColumnarValue { }; if let Ok(formatted) = formatted { - write!(f, "{}", formatted) + write!(f, "{formatted}") } else { write!(f, "Error formatting columnar value") } diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 5ff1c1d07216..9bcc1edff882 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -21,7 +21,7 @@ use arrow::array::{ArrayRef, BooleanArray}; use datafusion_common::{not_impl_err, Result}; /// Describes how many rows should be emitted during grouping. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum EmitTo { /// Emit all groups All, diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 6af4322df29e..d656c676bd01 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -606,7 +606,7 @@ impl Interval { upper: ScalarValue::Boolean(Some(upper)), }) } - _ => internal_err!("Incompatible data types for logical conjunction"), + _ => internal_err!("Incompatible data types for logical disjunction"), } } @@ -949,6 +949,18 @@ impl Display for Interval { } } +impl From for Interval { + fn from(value: ScalarValue) -> Self { + Self::new(value.clone(), value) + } +} + +impl From<&ScalarValue> for Interval { + fn from(value: &ScalarValue) -> Self { + Self::new(value.to_owned(), value.to_owned()) + } +} + /// Applies the given binary operator the `lhs` and `rhs` arguments. pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { match *op { @@ -959,6 +971,7 @@ pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result lhs.lt(rhs), Operator::LtEq => lhs.lt_eq(rhs), Operator::And => lhs.and(rhs), + Operator::Or => lhs.or(rhs), Operator::Plus => lhs.add(rhs), Operator::Minus => lhs.sub(rhs), Operator::Multiply => lhs.mul(rhs), @@ -1683,9 +1696,9 @@ impl Display for NullableInterval { match self { Self::Null { .. 
} => write!(f, "NullableInterval: {{NULL}}"), Self::MaybeNull { values } => { - write!(f, "NullableInterval: {} U {{NULL}}", values) + write!(f, "NullableInterval: {values} U {{NULL}}") } - Self::NotNull { values } => write!(f, "NullableInterval: {}", values), + Self::NotNull { values } => write!(f, "NullableInterval: {values}"), } } } @@ -2706,8 +2719,8 @@ mod tests { ), ]; for (first, second, expected) in possible_cases { - println!("{}", first); - println!("{}", second); + println!("{first}"); + println!("{second}"); assert_eq!(first.union(second)?, expected) } @@ -3704,14 +3717,14 @@ mod tests { #[test] fn test_interval_display() { let interval = Interval::make(Some(0.25_f32), Some(0.50_f32)).unwrap(); - assert_eq!(format!("{}", interval), "[0.25, 0.5]"); + assert_eq!(format!("{interval}"), "[0.25, 0.5]"); let interval = Interval::try_new( ScalarValue::Float32(Some(f32::NEG_INFINITY)), ScalarValue::Float32(Some(f32::INFINITY)), ) .unwrap(); - assert_eq!(format!("{}", interval), "[NULL, NULL]"); + assert_eq!(format!("{interval}"), "[NULL, NULL]"); } macro_rules! capture_mode_change { diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index a7c9330201bc..5e1705d8ff61 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -843,6 +843,7 @@ impl Signature { volatility, } } + /// Any one of a list of [TypeSignature]s. pub fn one_of(type_signatures: Vec, volatility: Volatility) -> Self { Signature { @@ -850,7 +851,8 @@ impl Signature { volatility, } } - /// Specialized Signature for ArrayAppend and similar functions + + /// Specialized [Signature] for ArrayAppend and similar functions. pub fn array_and_element(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( @@ -865,7 +867,41 @@ impl Signature { volatility, } } - /// Specialized Signature for Array functions with an optional index + + /// Specialized [Signature] for ArrayPrepend and similar functions. + pub fn element_and_array(volatility: Volatility) -> Self { + Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Array, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility, + } + } + + /// Specialized [Signature] for functions that take a fixed number of arrays. + pub fn arrays( + n: usize, + coercion: Option, + volatility: Volatility, + ) -> Self { + Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array; n], + array_coercion: coercion, + }, + ), + volatility, + } + } + + /// Specialized [Signature] for Array functions with an optional index. pub fn array_and_element_and_optional_index(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::OneOf(vec![ @@ -889,7 +925,7 @@ impl Signature { } } - /// Specialized Signature for ArrayElement and similar functions + /// Specialized [Signature] for ArrayElement and similar functions. 
pub fn array_and_index(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( @@ -898,23 +934,16 @@ impl Signature { ArrayFunctionArgument::Array, ArrayFunctionArgument::Index, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }, ), volatility, } } - /// Specialized Signature for ArrayEmpty and similar functions + + /// Specialized [Signature] for ArrayEmpty and similar functions. pub fn array(volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array], - array_coercion: None, - }, - ), - volatility, - } + Signature::arrays(1, Some(ListCoercion::FixedSizedListToList), volatility) } } @@ -940,8 +969,7 @@ mod tests { for case in positive_cases { assert!( case.supports_zero_argument(), - "Expected {:?} to support zero arguments", - case + "Expected {case:?} to support zero arguments" ); } @@ -960,8 +988,7 @@ mod tests { for case in negative_cases { assert!( !case.supports_zero_argument(), - "Expected {:?} not to support zero arguments", - case + "Expected {case:?} not to support zero arguments" ); } } diff --git a/datafusion/expr-common/src/statistics.rs b/datafusion/expr-common/src/statistics.rs index 7e0bc88087ef..14f2f331ef5b 100644 --- a/datafusion/expr-common/src/statistics.rs +++ b/datafusion/expr-common/src/statistics.rs @@ -1559,18 +1559,14 @@ mod tests { assert_eq!( new_generic_from_binary_op(&op, &dist_a, &dist_b)?.range()?, apply_operator(&op, a, b)?, - "Failed for {:?} {op} {:?}", - dist_a, - dist_b + "Failed for {dist_a:?} {op} {dist_b:?}" ); } for op in [Gt, GtEq, Lt, LtEq, Eq, NotEq] { assert_eq!( create_bernoulli_from_comparison(&op, &dist_a, &dist_b)?.range()?, apply_operator(&op, a, b)?, - "Failed for {:?} {op} {:?}", - dist_a, - dist_b + "Failed for {dist_a:?} {op} {dist_b:?}" ); } } diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 13d52959aba6..e9377ce7de5a 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -17,7 +17,7 @@ use crate::signature::TypeSignature; use arrow::datatypes::{ - DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; @@ -82,48 +82,48 @@ pub static TIMES: &[DataType] = &[ DataType::Time64(TimeUnit::Nanosecond), ]; -/// Validate the length of `input_types` matches the `signature` for `agg_fun`. +/// Validate the length of `input_fields` matches the `signature` for `agg_fun`. /// -/// This method DOES NOT validate the argument types - only that (at least one, +/// This method DOES NOT validate the argument fields - only that (at least one, /// in the case of [`TypeSignature::OneOf`]) signature matches the desired /// number of input types. 
pub fn check_arg_count( func_name: &str, - input_types: &[DataType], + input_fields: &[FieldRef], signature: &TypeSignature, ) -> Result<()> { match signature { TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { - if input_types.len() != *agg_count { + if input_fields.len() != *agg_count { return plan_err!( "The function {func_name} expects {:?} arguments, but {:?} were provided", agg_count, - input_types.len() + input_fields.len() ); } } TypeSignature::Exact(types) => { - if types.len() != input_types.len() { + if types.len() != input_fields.len() { return plan_err!( "The function {func_name} expects {:?} arguments, but {:?} were provided", types.len(), - input_types.len() + input_fields.len() ); } } TypeSignature::OneOf(variants) => { let ok = variants .iter() - .any(|v| check_arg_count(func_name, input_types, v).is_ok()); + .any(|v| check_arg_count(func_name, input_fields, v).is_ok()); if !ok { return plan_err!( "The function {func_name} does not accept {:?} function arguments.", - input_types.len() + input_fields.len() ); } } TypeSignature::VariadicAny => { - if input_types.is_empty() { + if input_fields.is_empty() { return plan_err!( "The function {func_name} expects at least one argument" ); @@ -210,6 +210,7 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal256(new_precision, new_scale)) } + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_return_type(func_name, dict_value_type.as_ref()) @@ -231,6 +232,7 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result { let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal256(new_precision, *scale)) } + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_sum_type(dict_value_type.as_ref()) @@ -298,6 +300,7 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result Ok(DataType::Decimal128(*p, *s)), DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), d if d.is_numeric() => Ok(DataType::Float64), + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()), _ => { plan_err!( diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index fdc61cd665ef..d0fcda973381 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -931,6 +931,7 @@ fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { Int32 | UInt32 => Some(Decimal128(10, 0)), Int64 | UInt64 => Some(Decimal128(20, 0)), // TODO if we convert the floating-point data to the decimal type, it maybe overflow. + Float16 => Some(Decimal128(6, 3)), Float32 => Some(Decimal128(14, 7)), Float64 => Some(Decimal128(30, 15)), _ => None, @@ -949,6 +950,7 @@ fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option Some(Decimal256(10, 0)), Int64 | UInt64 => Some(Decimal256(20, 0)), // TODO if we convert the floating-point data to the decimal type, it maybe overflow. 
+ Float16 => Some(Decimal256(6, 3)), Float32 => Some(Decimal256(14, 7)), Float64 => Some(Decimal256(30, 15)), _ => None, @@ -1044,6 +1046,7 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(Float64), (_, Float32) | (Float32, _) => Some(Float32), + (_, Float16) | (Float16, _) => Some(Float16), // The following match arms encode the following logic: Given the two // integral types, we choose the narrowest possible integral type that // accommodates all values of both types. Note that to avoid information @@ -1138,7 +1141,7 @@ fn dictionary_comparison_coercion( /// 2. Data type of the other side should be able to cast to string type fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; - string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) { + string_coercion(lhs_type, rhs_type).or_else(|| match (lhs_type, rhs_type) { (Utf8View, from_type) | (from_type, Utf8View) => { string_concat_internal_coercion(from_type, &Utf8View) } @@ -1297,6 +1300,13 @@ fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(LargeBinary) } (Binary, Utf8) | (Utf8, Binary) => Some(Binary), + + // Cast FixedSizeBinary to Binary + (FixedSizeBinary(_), Binary) | (Binary, FixedSizeBinary(_)) => Some(Binary), + (FixedSizeBinary(_), BinaryView) | (BinaryView, FixedSizeBinary(_)) => { + Some(BinaryView) + } + _ => None, } } @@ -1574,6 +1584,10 @@ mod tests { coerce_numeric_type_to_decimal(&DataType::Int64).unwrap(), DataType::Decimal128(20, 0) ); + assert_eq!( + coerce_numeric_type_to_decimal(&DataType::Float16).unwrap(), + DataType::Decimal128(6, 3) + ); assert_eq!( coerce_numeric_type_to_decimal(&DataType::Float32).unwrap(), DataType::Decimal128(14, 7) @@ -2052,6 +2066,13 @@ mod tests { Operator::Plus, Float32 ); + // (_, Float16) | (Float16, _) => Some(Float16), + test_coercion_binary_rule_multiple!( + Float16, + [Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8], + Operator::Plus, + Float16 + ); // (UInt64, Int64 | Int32 | Int16 | Int8) | (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)), test_coercion_binary_rule_multiple!( UInt64, @@ -2190,6 +2211,18 @@ mod tests { DataType::Boolean ); // float + test_coercion_binary_rule!( + DataType::Float16, + DataType::Int64, + Operator::Eq, + DataType::Float16 + ); + test_coercion_binary_rule!( + DataType::Float16, + DataType::Float64, + Operator::Eq, + DataType::Float64 + ); test_coercion_binary_rule!( DataType::Float32, DataType::Int64, diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 37e1ed1936fb..d77c59ff64e1 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -58,3 +58,4 @@ sqlparser = { workspace = true } [dev-dependencies] ctor = { workspace = true } env_logger = { workspace = true } +insta = { workspace = true } diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 9cb51612d0ca..69525ea52137 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -72,7 +72,7 @@ impl CaseBuilder { let then_types: Vec = then_expr .iter() .map(|e| match e { - Expr::Literal(_) => e.get_type(&DFSchema::empty()), + Expr::Literal(_, _) => e.get_type(&DFSchema::empty()), _ => Ok(DataType::Null), }) .collect::>>()?; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 35eed504c883..30932c210489 100644 --- a/datafusion/expr/src/expr.rs +++ 
b/datafusion/expr/src/expr.rs @@ -17,13 +17,15 @@ //! Logical Expressions: [`Expr`] -use std::collections::HashSet; +use std::cmp::Ordering; +use std::collections::{BTreeMap, HashSet}; use std::fmt::{self, Display, Formatter, Write}; use std::hash::{Hash, Hasher}; use std::mem; use std::sync::Arc; use crate::expr_fn::binary_expr; +use crate::function::WindowFunctionSimplification; use crate::logical_plan::Subquery; use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; @@ -50,7 +52,7 @@ use sqlparser::ast::{ /// BinaryExpr { /// left: Expr::Column("A"), /// op: Operator::Plus, -/// right: Expr::Literal(ScalarValue::Int32(Some(1))) +/// right: Expr::Literal(ScalarValue::Int32(Some(1)), None) /// } /// ``` /// @@ -112,10 +114,10 @@ use sqlparser::ast::{ /// # use datafusion_expr::{lit, col, Expr}; /// // All literals are strongly typed in DataFusion. To make an `i64` 42: /// let expr = lit(42i64); -/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)))); -/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)))); +/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)), None)); +/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)), None)); /// // To make a (typed) NULL: -/// let expr = Expr::Literal(ScalarValue::Int64(None)); +/// let expr = Expr::Literal(ScalarValue::Int64(None), None); /// // to make an (untyped) NULL (the optimizer will coerce this to the correct type): /// let expr = lit(ScalarValue::Null); /// ``` @@ -149,7 +151,7 @@ use sqlparser::ast::{ /// if let Expr::BinaryExpr(binary_expr) = expr { /// assert_eq!(*binary_expr.left, col("c1")); /// let scalar = ScalarValue::Int32(Some(42)); -/// assert_eq!(*binary_expr.right, Expr::Literal(scalar)); +/// assert_eq!(*binary_expr.right, Expr::Literal(scalar, None)); /// assert_eq!(binary_expr.op, Operator::Eq); /// } /// ``` @@ -193,7 +195,7 @@ use sqlparser::ast::{ /// ``` /// # use datafusion_expr::{lit, col}; /// let expr = col("c1") + lit(42); -/// assert_eq!(format!("{expr:?}"), "BinaryExpr(BinaryExpr { left: Column(Column { relation: None, name: \"c1\" }), op: Plus, right: Literal(Int32(42)) })"); +/// assert_eq!(format!("{expr:?}"), "BinaryExpr(BinaryExpr { left: Column(Column { relation: None, name: \"c1\" }), op: Plus, right: Literal(Int32(42), None) })"); /// ``` /// /// ## Use the `Display` trait (detailed expression) @@ -239,7 +241,7 @@ use sqlparser::ast::{ /// let mut scalars = HashSet::new(); /// // apply recursively visits all nodes in the expression tree /// expr.apply(|e| { -/// if let Expr::Literal(scalar) = e { +/// if let Expr::Literal(scalar, _) = e { /// scalars.insert(scalar); /// } /// // The return value controls whether to continue visiting the tree @@ -274,7 +276,7 @@ use sqlparser::ast::{ /// assert!(rewritten.transformed); /// // to 42 = 5 AND b = 6 /// assert_eq!(rewritten.data, lit(42).eq(lit(5)).and(col("b").eq(lit(6)))); -#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +#[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub enum Expr { /// An expression with a specific name. Alias(Alias), @@ -282,8 +284,8 @@ pub enum Expr { Column(Column), /// A named reference to a variable in a registry. ScalarVariable(DataType, Vec), - /// A constant value. 
- Literal(ScalarValue), + /// A constant value along with associated metadata + Literal(ScalarValue, Option>), /// A binary expression such as "age > 21" BinaryExpr(BinaryExpr), /// LIKE expression @@ -312,27 +314,7 @@ pub enum Expr { Negative(Box), /// Whether an expression is between a given range. Between(Between), - /// The CASE expression is similar to a series of nested if/else and there are two forms that - /// can be used. The first form consists of a series of boolean "when" expressions with - /// corresponding "then" expressions, and an optional "else" expression. - /// - /// ```text - /// CASE WHEN condition THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - /// ``` - /// - /// The second form uses a base expression and then a series of "when" clauses that match on a - /// literal value. - /// - /// ```text - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - /// ``` + /// A CASE expression (see docs on [`Case`]) Case(Case), /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. @@ -340,7 +322,7 @@ pub enum Expr { /// Casts the expression to a given type and will return a null value if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. TryCast(TryCast), - /// Represents the call of a scalar function with a set of arguments. + /// Call a scalar function with a set of arguments. ScalarFunction(ScalarFunction), /// Calls an aggregate function with arguments, and optional /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. @@ -349,8 +331,8 @@ pub enum Expr { /// /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt AggregateFunction(AggregateFunction), - /// Represents the call of a window function with arguments. - WindowFunction(WindowFunction), + /// Call a window function with a set of arguments. + WindowFunction(Box), /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -378,7 +360,7 @@ pub enum Expr { /// A place holder for parameters in a prepared statement /// (e.g. `$foo` or `$1`) Placeholder(Placeholder), - /// A place holder which hold a reference to a qualified field + /// A placeholder which holds a reference to a qualified field /// in the outer query, used for correlated sub queries. OuterReferenceColumn(DataType, Column), /// Unnest expression @@ -387,7 +369,7 @@ pub enum Expr { impl Default for Expr { fn default() -> Self { - Expr::Literal(ScalarValue::Null) + Expr::Literal(ScalarValue::Null, None) } } @@ -398,6 +380,13 @@ impl From for Expr { } } +/// Create an [`Expr`] from a [`WindowFunction`] +impl From for Expr { + fn from(value: WindowFunction) -> Self { + Expr::WindowFunction(Box::new(value)) + } +} + /// Create an [`Expr`] from an optional qualifier and a [`FieldRef`]. This is /// useful for creating [`Expr`] from a [`DFSchema`]. 
/// @@ -462,13 +451,13 @@ impl Hash for Alias { } impl PartialOrd for Alias { - fn partial_cmp(&self, other: &Self) -> Option { + fn partial_cmp(&self, other: &Self) -> Option { let cmp = self.expr.partial_cmp(&other.expr); - let Some(std::cmp::Ordering::Equal) = cmp else { + let Some(Ordering::Equal) = cmp else { return cmp; }; let cmp = self.relation.partial_cmp(&other.relation); - let Some(std::cmp::Ordering::Equal) = cmp else { + let Some(Ordering::Equal) = cmp else { return cmp; }; self.name.partial_cmp(&other.name) @@ -564,6 +553,28 @@ impl Display for BinaryExpr { } /// CASE expression +/// +/// The CASE expression is similar to a series of nested if/else and there are two forms that +/// can be used. The first form consists of a series of boolean "when" expressions with +/// corresponding "then" expressions, and an optional "else" expression. +/// +/// ```text +/// CASE WHEN condition THEN result +/// [WHEN ...] +/// [ELSE result] +/// END +/// ``` +/// +/// The second form uses a base expression and then a series of "when" clauses that match on a +/// literal value. +/// +/// ```text +/// CASE expression +/// WHEN value THEN result +/// [WHEN ...] +/// [ELSE result] +/// END +/// ``` #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)] pub struct Case { /// Optional base expression that can be compared to literal values in the "when" expressions @@ -644,7 +655,9 @@ impl Between { } } -/// ScalarFunction expression invokes a built-in scalar function +/// Invoke a [`ScalarUDF`] with a set of arguments +/// +/// [`ScalarUDF`]: crate::ScalarUDF #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct ScalarFunction { /// The function @@ -661,7 +674,9 @@ impl ScalarFunction { } impl ScalarFunction { - /// Create a new ScalarFunction expression with a user-defined function (UDF) + /// Create a new `ScalarFunction` from a [`ScalarUDF`] + /// + /// [`ScalarUDF`]: crate::ScalarUDF pub fn new_udf(udf: Arc, args: Vec) -> Self { Self { func: udf, args } } @@ -851,19 +866,19 @@ pub enum WindowFunctionDefinition { impl WindowFunctionDefinition { /// Returns the datatype of the window function - pub fn return_type( + pub fn return_field( &self, - input_expr_types: &[DataType], + input_expr_fields: &[FieldRef], _input_expr_nullable: &[bool], display_name: &str, - ) -> Result { + ) -> Result { match self { WindowFunctionDefinition::AggregateUDF(fun) => { - fun.return_type(input_expr_types) + fun.return_field(input_expr_fields) + } + WindowFunctionDefinition::WindowUDF(fun) => { + fun.field(WindowUDFFieldArgs::new(input_expr_fields, display_name)) } - WindowFunctionDefinition::WindowUDF(fun) => fun - .field(WindowUDFFieldArgs::new(input_expr_types, display_name)) - .map(|field| field.data_type().clone()), } } @@ -882,6 +897,16 @@ impl WindowFunctionDefinition { WindowFunctionDefinition::AggregateUDF(fun) => fun.name(), } } + + /// Return the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + match self { + WindowFunctionDefinition::AggregateUDF(_) => None, + WindowFunctionDefinition::WindowUDF(udwf) => udwf.simplify(), + } + } } impl Display for WindowFunctionDefinition { @@ -953,6 +978,13 @@ impl WindowFunction { }, } } + + /// Return the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + self.fun.simplify() + } } /// EXISTS expression @@ -1549,8 +1581,16 @@ impl Expr {
|expr| { // f_up: unalias on up so we can remove nested aliases like // `(x as foo) as bar` - if let Expr::Alias(Alias { expr, .. }) = expr { - Ok(Transformed::yes(*expr)) + if let Expr::Alias(alias) = expr { + match alias + .metadata + .as_ref() + .map(|h| h.is_empty()) + .unwrap_or(true) + { + true => Ok(Transformed::yes(*alias.expr)), + false => Ok(Transformed::no(Expr::Alias(alias))), + } } else { Ok(Transformed::no(expr)) } @@ -1790,18 +1830,11 @@ impl Expr { pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { let mut has_placeholder = false; self.transform(|mut expr| { - let expr_type = expr.get_type(schema).ok(); - #[expect(deprecated)] match &mut expr { + // Default to assuming the arguments are the same type Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { - let binary_expr_type = find_first_non_null_data_type_expr_placeholder( - [left.as_ref(), right.as_ref()].into_iter(), - schema, - ); - if let Some(binary_expr_type) = binary_expr_type { - rewrite_placeholder_type(left.as_mut(), &binary_expr_type)?; - rewrite_placeholder_type(right.as_mut(), &binary_expr_type)?; - } + rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; + rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; } Expr::Between(Between { expr, @@ -1809,139 +1842,26 @@ impl Expr { low, high, }) => { - let between_type = find_first_non_null_data_type_expr_placeholder( - [low.as_ref(), high.as_ref(), expr.as_ref()].into_iter(), - schema, - ); - if let Some(between_type) = between_type { - rewrite_placeholder_type(expr.as_mut(), &between_type)?; - rewrite_placeholder_type(low.as_mut(), &between_type)?; - rewrite_placeholder_type(high.as_mut(), &between_type)?; - } + rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; + rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; } Expr::InList(InList { expr, list, negated: _, }) => { - let in_list_type = find_first_non_null_data_type_expr_placeholder( - [expr.as_ref()].into_iter().chain(list.iter()), - schema, - ); - if let Some(in_list_type) = in_list_type { - rewrite_placeholder_type(expr.as_mut(), &in_list_type)?; - for item in list.iter_mut() { - rewrite_placeholder_type(item, &in_list_type)?; - } + for item in list.iter_mut() { + rewrite_placeholder(item, expr.as_ref(), schema)?; } } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - let like_type = find_first_non_null_data_type_expr_placeholder( - [expr.as_ref(), pattern.as_ref()].into_iter(), - schema, - ); - if let Some(like_type) = like_type { - rewrite_placeholder_type(expr.as_mut(), &like_type)?; - rewrite_placeholder_type(pattern.as_mut(), &like_type)?; - } - } - Expr::InSubquery(InSubquery { - expr, - subquery, - negated: _, - }) => { - let subquery_schema = subquery.subquery.schema(); - let fields = subquery_schema.fields(); - - // Subqueries used in IN expressions must have exactly 1 field - // i.e. `SELECT * FROM foo WHERE 'some_val' IN (SELECT val FROM bar)` - if let [first_field] = &fields[..] { - rewrite_placeholder_type(expr.as_mut(), first_field.data_type())?; - } - } - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => { - // If `expr` is present, then it must match the types of each WHEN expression. - // If `expr` is not present, then the type of each WHEN expression must evaluate to a boolean. - // The types of the THEN and ELSE expressions must match `expr_type` (which is the final type of the CASE expression). 
- let when_type = match expr { - Some(expr) => { - let mut when_type = expr.get_type(schema)?; - if when_type == DataType::Null { - for (when_expr, _) in when_then_expr.iter() { - let when_expr_type = when_expr.get_type(schema)?; - if when_expr_type != DataType::Null { - when_type = when_expr_type; - break; - } - } - } - when_type - } - None => DataType::Boolean, - }; - if let Some(expr) = expr { - rewrite_placeholder_type(expr.as_mut(), &when_type)?; - } - for (when_expr, then_expr) in when_then_expr.iter_mut() { - rewrite_placeholder_type(when_expr.as_mut(), &when_type)?; - if let Some(expr_type) = &expr_type { - rewrite_placeholder_type(then_expr.as_mut(), expr_type)?; - } - } - if let Some(else_expr) = else_expr { - if let Some(expr_type) = &expr_type { - rewrite_placeholder_type(else_expr.as_mut(), expr_type)?; - } - } - } - // These expressions constrain any immediate placeholders to Boolean. - Expr::Not(expr) - | Expr::IsTrue(expr) - | Expr::IsFalse(expr) - | Expr::IsNotTrue(expr) - | Expr::IsNotFalse(expr) - | Expr::IsNotUnknown(expr) => { - rewrite_placeholder_type(expr.as_mut(), &DataType::Boolean)? - } - // Note that the inner cast expression can technically be any data type - // that is coercible to the data type of the cast expression. - // However, returning `data_type` is preferable to returning DataType::Null - // for placeholder inference. - Expr::Cast(Cast { expr, data_type }) - | Expr::TryCast(TryCast { expr, data_type }) => { - rewrite_placeholder_type(expr.as_mut(), data_type)?; - } - // Negative expressions can technically be any numeric data type, but if we have - // an immediate placeholder, let's infer it as Int64. - Expr::Negative(expr) => { - rewrite_placeholder_type(expr.as_mut(), &DataType::Int64)? + rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?; } Expr::Placeholder(_) => { has_placeholder = true; } - // These expressions either cannot contain placeholders or - // do not constrain the type of the placeholder. - Expr::Alias(_) - | Expr::Column(_) - | Expr::ScalarVariable(_, _) - | Expr::Literal(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsUnknown(_) - | Expr::ScalarFunction(_) - | Expr::AggregateFunction(_) - | Expr::WindowFunction(_) - | Expr::Exists(_) - | Expr::ScalarSubquery(_) - | Expr::Wildcard { .. 
} - | Expr::GroupingSet(_) - | Expr::OuterReferenceColumn(_, _) - | Expr::Unnest(_) => {} + _ => {} } Ok(Transformed::yes(expr)) }) @@ -2243,32 +2163,29 @@ impl NormalizeEq for Expr { _ => false, } } - ( - Expr::WindowFunction(WindowFunction { + (Expr::WindowFunction(left), Expr::WindowFunction(other)) => { + let WindowFunction { fun: self_fun, - params: self_params, - }), - Expr::WindowFunction(WindowFunction { + params: + WindowFunctionParams { + args: self_args, + window_frame: self_window_frame, + partition_by: self_partition_by, + order_by: self_order_by, + null_treatment: self_null_treatment, + }, + } = left.as_ref(); + let WindowFunction { fun: other_fun, - params: other_params, - }), - ) => { - let ( - WindowFunctionParams { - args: self_args, - window_frame: self_window_frame, - partition_by: self_partition_by, - order_by: self_order_by, - null_treatment: self_null_treatment, - }, - WindowFunctionParams { - args: other_args, - window_frame: other_window_frame, - partition_by: other_partition_by, - order_by: other_order_by, - null_treatment: other_null_treatment, - }, - ) = (self_params, other_params); + params: + WindowFunctionParams { + args: other_args, + window_frame: other_window_frame, + partition_by: other_partition_by, + order_by: other_order_by, + null_treatment: other_null_treatment, + }, + } = other.as_ref(); self_fun.name() == other_fun.name() && self_window_frame == other_window_frame @@ -2434,7 +2351,7 @@ impl HashNode for Expr { data_type.hash(state); name.hash(state); } - Expr::Literal(scalar_value) => { + Expr::Literal(scalar_value, _) => { scalar_value.hash(state); } Expr::BinaryExpr(BinaryExpr { @@ -2513,14 +2430,18 @@ impl HashNode for Expr { distinct.hash(state); null_treatment.hash(state); } - Expr::WindowFunction(WindowFunction { fun, params }) => { - let WindowFunctionParams { - args: _args, - partition_by: _, - order_by: _, - window_frame, - null_treatment, - } = params; + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args: _args, + partition_by: _, + order_by: _, + window_frame, + null_treatment, + }, + } = window_fun.as_ref(); fun.hash(state); window_frame.hash(state); null_treatment.hash(state); @@ -2571,10 +2492,20 @@ impl HashNode for Expr { } } -fn rewrite_placeholder_type(expr: &mut Expr, dt: &DataType) -> Result<()> { +fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { if data_type.is_none() { - *data_type = Some(dt.clone()); + let other_dt = other.get_type(schema); + match other_dt { + Err(e) => { + Err(e.context(format!( + "Can not find type of {other} needed to infer type of {expr}" + )))?; + } + Ok(dt) => { + *data_type = Some(dt); + } + } }; } Ok(()) @@ -2613,7 +2544,7 @@ impl Display for SchemaDisplay<'_> { // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::Column(_) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::ScalarVariable(..) | Expr::OuterReferenceColumn(..) 
| Expr::Placeholder(_) @@ -2624,7 +2555,7 @@ impl Display for SchemaDisplay<'_> { write!(f, "{name}") } Err(e) => { - write!(f, "got error from schema_name {}", e) + write!(f, "got error from schema_name {e}") } } } @@ -2775,7 +2706,7 @@ impl Display for SchemaDisplay<'_> { write!(f, "{name}") } Err(e) => { - write!(f, "got error from schema_name {}", e) + write!(f, "got error from schema_name {e}") } } } @@ -2806,52 +2737,62 @@ impl Display for SchemaDisplay<'_> { Ok(()) } - Expr::WindowFunction(WindowFunction { fun, params }) => match fun { - WindowFunctionDefinition::AggregateUDF(fun) => { - match fun.window_function_schema_name(params) { - Ok(name) => { - write!(f, "{name}") - } - Err(e) => { - write!(f, "got error from window_function_schema_name {}", e) + Expr::WindowFunction(window_fun) => { + let WindowFunction { fun, params } = window_fun.as_ref(); + match fun { + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_schema_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!( + f, + "got error from window_function_schema_name {e}" + ) + } } } - } - _ => { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; + _ => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; - write!( - f, - "{}({})", - fun, - schema_name_from_exprs_comma_separated_without_space(args)? - )?; - - if let Some(null_treatment) = null_treatment { - write!(f, " {}", null_treatment)?; - } - - if !partition_by.is_empty() { write!( f, - " PARTITION BY [{}]", - schema_name_from_exprs(partition_by)? + "{}({})", + fun, + schema_name_from_exprs_comma_separated_without_space(args)? )?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; - }; + if let Some(null_treatment) = null_treatment { + write!(f, " {null_treatment}")?; + } - write!(f, " {window_frame}") + if !partition_by.is_empty() { + write!( + f, + " PARTITION BY [{}]", + schema_name_from_exprs(partition_by)? + )?; + } + + if !order_by.is_empty() { + write!( + f, + " ORDER BY [{}]", + schema_name_from_sorts(order_by)? + )?; + }; + + write!(f, " {window_frame}") + } } - }, + } } } } @@ -2862,7 +2803,7 @@ struct SqlDisplay<'a>(&'a Expr); impl Display for SqlDisplay<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self.0 { - Expr::Literal(scalar) => scalar.fmt(f), + Expr::Literal(scalar, _) => scalar.fmt(f), Expr::Alias(Alias { name, .. 
}) => write!(f, "{name}"), Expr::Between(Between { expr, @@ -3028,7 +2969,7 @@ impl Display for SqlDisplay<'_> { write!(f, "{name}") } Err(e) => { - write!(f, "got error from schema_name {}", e) + write!(f, "got error from schema_name {e}") } } } @@ -3129,7 +3070,12 @@ impl Display for Expr { write!(f, "{OUTER_REFERENCE_COLUMN_PREFIX}({c})") } Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")), - Expr::Literal(v) => write!(f, "{v:?}"), + Expr::Literal(v, metadata) => { + match metadata.as_ref().map(|m| m.is_empty()).unwrap_or(true) { + false => write!(f, "{v:?} {:?}", metadata.as_ref().unwrap()), + true => write!(f, "{v:?}"), + } + } Expr::Case(case) => { write!(f, "CASE ")?; if let Some(e) = &case.expr { @@ -3186,54 +3132,60 @@ impl Display for Expr { // Expr::ScalarFunction(ScalarFunction { func, args }) => { // write!(f, "{}", func.display_name(args).unwrap()) // } - Expr::WindowFunction(WindowFunction { fun, params }) => match fun { - WindowFunctionDefinition::AggregateUDF(fun) => { - match fun.window_function_display_name(params) { - Ok(name) => { - write!(f, "{}", name) - } - Err(e) => { - write!(f, "got error from window_function_display_name {}", e) + Expr::WindowFunction(window_fun) => { + let WindowFunction { fun, params } = window_fun.as_ref(); + match fun { + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_display_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!( + f, + "got error from window_function_display_name {e}" + ) + } } } - } - WindowFunctionDefinition::WindowUDF(fun) => { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; - - fmt_function(f, &fun.to_string(), false, args, true)?; - - if let Some(nt) = null_treatment { - write!(f, "{}", nt)?; - } + WindowFunctionDefinition::WindowUDF(fun) => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; + + fmt_function(f, &fun.to_string(), false, args, true)?; + + if let Some(nt) = null_treatment { + write!(f, "{nt}")?; + } - if !partition_by.is_empty() { - write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + if !partition_by.is_empty() { + write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; + } + if !order_by.is_empty() { + write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + } + write!( + f, + " {} BETWEEN {} AND {}", + window_frame.units, + window_frame.start_bound, + window_frame.end_bound + ) } - write!( - f, - " {} BETWEEN {} AND {}", - window_frame.units, - window_frame.start_bound, - window_frame.end_bound - ) } - }, + } Expr::AggregateFunction(AggregateFunction { func, params }) => { match func.display_name(params) { Ok(name) => { - write!(f, "{}", name) + write!(f, "{name}") } Err(e) => { - write!(f, "got error from display_name {}", e) + write!(f, "got error from display_name {e}") } } } @@ -3372,61 +3324,6 @@ mod test { use sqlparser::ast::{Ident, IdentWithAlias}; use std::any::Any; - #[test] - fn infer_placeholder_in_clause() { - // SELECT * FROM employees WHERE department_id IN ($1, $2, $3); - let column = col("department_id"); - let param_placeholders = vec![ - Expr::Placeholder(Placeholder { - id: "$1".to_string(), - data_type: None, - }), - Expr::Placeholder(Placeholder { - id: "$2".to_string(), - data_type: None, - }), - Expr::Placeholder(Placeholder { - id: "$3".to_string(), - data_type: 
None, - }), - ]; - let in_list = Expr::InList(InList { - expr: Box::new(column), - list: param_placeholders, - negated: false, - }); - - let schema = Arc::new(Schema::new(vec![ - Field::new("name", DataType::Utf8, true), - Field::new("department_id", DataType::Int32, true), - ])); - let df_schema = DFSchema::try_from(schema).unwrap(); - - let (inferred_expr, contains_placeholder) = - in_list.infer_placeholder_types(&df_schema).unwrap(); - - assert!(contains_placeholder); - - match inferred_expr { - Expr::InList(in_list) => { - for expr in in_list.list { - match expr { - Expr::Placeholder(placeholder) => { - assert_eq!( - placeholder.data_type, - Some(DataType::Int32), - "Placeholder {} should infer Int32", - placeholder.id - ); - } - _ => panic!("Expected Placeholder expression"), - } - } - } - _ => panic!("Expected InList expression"), - } - } - #[test] fn infer_placeholder_in_clause_with_placeholder_expr() { // SELECT * FROM employees WHERE $1 IN (1, 2, 3); @@ -3437,9 +3334,9 @@ mod test { let in_list = Expr::InList(InList { expr: Box::new(placeholder_expr), list: vec![ - Expr::Literal(ScalarValue::Int32(Some(1))), - Expr::Literal(ScalarValue::Int32(Some(2))), - Expr::Literal(ScalarValue::Int32(Some(3))), + Expr::Literal(ScalarValue::Int32(Some(1)), None), + Expr::Literal(ScalarValue::Int32(Some(2)), None), + Expr::Literal(ScalarValue::Int32(Some(3)), None), ], negated: false, }); @@ -3491,7 +3388,7 @@ mod test { let subquery_filter = Expr::BinaryExpr(BinaryExpr { left: Box::new(col("B")), op: Operator::Gt, - right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)))), + right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)), None)), }); let subquery_scan = LogicalPlan::TableScan(TableScan { @@ -3553,18 +3450,18 @@ mod test { } #[test] - fn infer_placeholder_like_and_similar_to() { - // name LIKE $1 + fn infer_placeholder_like_and_similar_to_expr_position() { + // $1 LIKE 'pattern%' let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); let df_schema = DFSchema::try_from(schema).unwrap(); let like = Like { - expr: Box::new(col("name")), - pattern: Box::new(Expr::Placeholder(Placeholder { + expr: Box::new(Expr::Placeholder(Placeholder { id: "$1".to_string(), data_type: None, })), + pattern: Box::new(lit("pattern%")), negated: false, case_insensitive: false, escape_char: None, @@ -3574,21 +3471,26 @@ mod test { let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); match inferred_expr { - Expr::Like(like) => match *like.pattern { + Expr::Like(like) => match *like.expr { Expr::Placeholder(placeholder) => { - assert_eq!(placeholder.data_type, Some(DataType::Utf8)); + assert_eq!( + placeholder.data_type, + Some(DataType::Utf8), + "Placeholder {} should infer Utf8", + placeholder.id + ); } _ => panic!("Expected Placeholder"), }, _ => panic!("Expected Like"), } - // name SIMILAR TO $1 + // $1 SIMILAR TO 'pattern.*' let expr = Expr::SimilarTo(like); let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); match inferred_expr { - Expr::SimilarTo(like) => match *like.pattern { + Expr::SimilarTo(like) => match *like.expr { Expr::Placeholder(placeholder) => { assert_eq!( placeholder.data_type, @@ -3604,18 +3506,73 @@ mod test { } #[test] - fn infer_placeholder_like_and_similar_to_expr_position() { - // $1 LIKE 'pattern%' + fn infer_placeholder_in_clause() { + // SELECT * FROM employees WHERE department_id IN ($1, $2, $3); + let column = col("department_id"); + let param_placeholders = vec![ + Expr::Placeholder(Placeholder { + 
id: "$1".to_string(), + data_type: None, + }), + Expr::Placeholder(Placeholder { + id: "$2".to_string(), + data_type: None, + }), + Expr::Placeholder(Placeholder { + id: "$3".to_string(), + data_type: None, + }), + ]; + let in_list = Expr::InList(InList { + expr: Box::new(column), + list: param_placeholders, + negated: false, + }); + + let schema = Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("department_id", DataType::Int32, true), + ])); + let df_schema = DFSchema::try_from(schema).unwrap(); + + let (inferred_expr, contains_placeholder) = + in_list.infer_placeholder_types(&df_schema).unwrap(); + + assert!(contains_placeholder); + + match inferred_expr { + Expr::InList(in_list) => { + for expr in in_list.list { + match expr { + Expr::Placeholder(placeholder) => { + assert_eq!( + placeholder.data_type, + Some(DataType::Int32), + "Placeholder {} should infer Int32", + placeholder.id + ); + } + _ => panic!("Expected Placeholder expression"), + } + } + } + _ => panic!("Expected InList expression"), + } + } + + #[test] + fn infer_placeholder_like_and_similar_to() { + // name LIKE $1 let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); let df_schema = DFSchema::try_from(schema).unwrap(); let like = Like { - expr: Box::new(Expr::Placeholder(Placeholder { + expr: Box::new(col("name")), + pattern: Box::new(Expr::Placeholder(Placeholder { id: "$1".to_string(), data_type: None, })), - pattern: Box::new(lit("pattern%")), negated: false, case_insensitive: false, escape_char: None, @@ -3625,26 +3582,21 @@ mod test { let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); match inferred_expr { - Expr::Like(like) => match *like.expr { + Expr::Like(like) => match *like.pattern { Expr::Placeholder(placeholder) => { - assert_eq!( - placeholder.data_type, - Some(DataType::Utf8), - "Placeholder {} should infer Utf8", - placeholder.id - ); + assert_eq!(placeholder.data_type, Some(DataType::Utf8)); } _ => panic!("Expected Placeholder"), }, _ => panic!("Expected Like"), } - // $1 SIMILAR TO 'pattern.*' + // name SIMILAR TO $1 let expr = Expr::SimilarTo(like); let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); match inferred_expr { - Expr::SimilarTo(like) => match *like.expr { + Expr::SimilarTo(like) => match *like.pattern { Expr::Placeholder(placeholder) => { assert_eq!( placeholder.data_type, @@ -3676,7 +3628,7 @@ mod test { #[allow(deprecated)] fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), + expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)), data_type: DataType::Utf8, }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; @@ -4913,4 +4865,19 @@ mod test { _ => panic!("Expected Negative expression"), } } + + #[test] + fn test_size_of_expr() { + // because Expr is such a widely used struct in DataFusion + // it is important to keep its size as small as possible + // + // If this test fails when you change `Expr`, please try + // `Box`ing the fields to make `Expr` smaller + // See https://github.com/apache/datafusion/issues/16199 for details + assert_eq!(size_of::(), 144); + assert_eq!(size_of::(), 64); + assert_eq!(size_of::(), 24); // 3 ptrs + assert_eq!(size_of::>(), 24); + assert_eq!(size_of::>(), 8); + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 966aba7d1195..e8885ed6b724 100644 --- a/datafusion/expr/src/expr_fn.rs +++ 
b/datafusion/expr/src/expr_fn.rs @@ -37,7 +37,7 @@ use crate::{ use arrow::compute::kernels::cast_utils::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, }; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; @@ -492,6 +492,7 @@ pub fn create_udaf( .into_iter() .enumerate() .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .map(Arc::new) .collect::>(); AggregateUDF::from(SimpleAggregateUDF::new( name, @@ -510,7 +511,7 @@ pub struct SimpleAggregateUDF { signature: Signature, return_type: DataType, accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + state_fields: Vec, } impl Debug for SimpleAggregateUDF { @@ -533,7 +534,7 @@ impl SimpleAggregateUDF { return_type: DataType, volatility: Volatility, accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -553,7 +554,7 @@ impl SimpleAggregateUDF { signature: Signature, return_type: DataType, accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); Self { @@ -590,7 +591,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { (self.accumulator)(acc_args) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(self.state_fields.clone()) } } @@ -678,28 +679,28 @@ impl WindowUDFImpl for SimpleWindowUDF { (self.partition_evaluator_factory)() } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new( + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Arc::new(Field::new( field_args.name(), self.return_type.clone(), true, - )) + ))) } } pub fn interval_year_month_lit(value: &str) -> Expr { let interval = parse_interval_year_month(value).ok(); - Expr::Literal(ScalarValue::IntervalYearMonth(interval)) + Expr::Literal(ScalarValue::IntervalYearMonth(interval), None) } pub fn interval_datetime_lit(value: &str) -> Expr { let interval = parse_interval_day_time(value).ok(); - Expr::Literal(ScalarValue::IntervalDayTime(interval)) + Expr::Literal(ScalarValue::IntervalDayTime(interval), None) } pub fn interval_month_day_nano_lit(value: &str) -> Expr { let interval = parse_interval_month_day_nano(value).ok(); - Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) + Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None) } /// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] @@ -831,14 +832,14 @@ impl ExprFuncBuilder { params: WindowFunctionParams { args, .. 
}, }) => { let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun, params: WindowFunctionParams { args, partition_by: partition_by.unwrap_or_default(), order_by: order_by.unwrap_or_default(), window_frame: window_frame - .unwrap_or(WindowFrame::new(has_order_by)), + .unwrap_or_else(|| WindowFrame::new(has_order_by)), null_treatment, }, }) @@ -895,7 +896,7 @@ impl ExprFunctionExt for Expr { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))) } _ => ExprFuncBuilder::new(None), }; @@ -935,7 +936,7 @@ impl ExprFunctionExt for Expr { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))) } _ => ExprFuncBuilder::new(None), }; @@ -948,7 +949,7 @@ impl ExprFunctionExt for Expr { fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))); builder.partition_by = Some(partition_by); builder } @@ -959,7 +960,7 @@ impl ExprFunctionExt for Expr { fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))); builder.window_frame = Some(window_frame); builder } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 90dcbce46b01..f80b8e5a7759 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -354,6 +354,7 @@ mod test { use std::ops::Add; use super::*; + use crate::literal::lit_with_metadata; use crate::{col, lit, Cast}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeRewriter; @@ -383,13 +384,17 @@ mod test { // rewrites all "foo" string literals to "bar" let transformer = |expr: Expr| -> Result> { match expr { - Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { + Expr::Literal(ScalarValue::Utf8(Some(utf8_val)), metadata) => { let utf8_val = if utf8_val == "foo" { "bar".to_string() } else { utf8_val }; - Ok(Transformed::yes(lit(utf8_val))) + Ok(Transformed::yes(lit_with_metadata( + utf8_val, + metadata + .map(|m| m.into_iter().collect::>()), + ))) } // otherwise, return None _ => Ok(Transformed::no(expr)), diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index a349c83a4934..1973a00a67df 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -22,12 +22,12 @@ use crate::expr::{ WindowFunctionParams, }; use crate::type_coercion::functions::{ - data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, + data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, }; -use crate::udf::ReturnTypeArgs; +use crate::udf::ReturnFieldArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{ 
not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, Spans, TableReference, @@ -115,7 +115,7 @@ impl ExprSchemable for Expr { Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), - Expr::Literal(l) => Ok(l.data_type()), + Expr::Literal(l, _) => Ok(l.data_type()), Expr::Case(case) => { for (_, then_expr) in &case.when_then_expr { let then_type = then_expr.get_type(schema)?; @@ -158,12 +158,16 @@ impl ExprSchemable for Expr { func, params: AggregateFunctionParams { args, .. }, }) => { - let data_types = args + let fields = args .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - let new_types = data_types_with_aggregate_udf(&data_types, func) + let new_fields = fields_with_aggregate_udf(&fields, func) .map_err(|err| { + let data_types = fields + .iter() + .map(|f| f.data_type().clone()) + .collect::>(); plan_datafusion_err!( "{} {}", match err { @@ -176,8 +180,10 @@ impl ExprSchemable for Expr { &data_types ) ) - })?; - Ok(func.return_type(&new_types)?) + })? + .into_iter() + .collect::>(); + Ok(func.return_field(&new_fields)?.data_type().clone()) } Expr::Not(_) | Expr::IsNull(_) @@ -272,7 +278,7 @@ impl ExprSchemable for Expr { Expr::Column(c) => input_schema.nullable(c), Expr::OuterReferenceColumn(_, _) => Ok(true), - Expr::Literal(value) => Ok(value.is_null()), + Expr::Literal(value, _) => Ok(value.is_null()), Expr::Case(case) => { // This expression is nullable if any of the input expressions are nullable let then_nullable = case @@ -341,21 +347,8 @@ impl ExprSchemable for Expr { } fn metadata(&self, schema: &dyn ExprSchema) -> Result> { - match self { - Expr::Column(c) => Ok(schema.metadata(c)?.clone()), - Expr::Alias(Alias { expr, metadata, .. }) => { - let mut ret = expr.metadata(schema)?; - if let Some(metadata) = metadata { - if !metadata.is_empty() { - ret.extend(metadata.clone()); - return Ok(ret); - } - } - Ok(ret) - } - Expr::Cast(Cast { expr, .. }) => expr.metadata(schema), - _ => Ok(HashMap::new()), - } + self.to_field(schema) + .map(|(_, field)| field.metadata().clone()) } /// Returns the datatype and nullability of the expression based on [ExprSchema]. @@ -372,23 +365,73 @@ impl ExprSchemable for Expr { &self, schema: &dyn ExprSchema, ) -> Result<(DataType, bool)> { - match self { - Expr::Alias(Alias { expr, name, .. }) => match &**expr { - Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { - None => schema - .data_type_and_nullable(&Column::from_name(name)) - .map(|(d, n)| (d.clone(), n)), - Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)), - }, - _ => expr.data_type_and_nullable(schema), - }, - Expr::Negative(expr) => expr.data_type_and_nullable(schema), - Expr::Column(c) => schema - .data_type_and_nullable(c) - .map(|(d, n)| (d.clone(), n)), - Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)), - Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)), - Expr::Literal(l) => Ok((l.data_type(), l.is_null())), + let field = self.to_field(schema)?.1; + + Ok((field.data_type().clone(), field.is_nullable())) + } + + /// Returns a [arrow::datatypes::Field] compatible with this expression. 
+ /// + /// So for example, a projected expression `col(c1) + col(c2)` is + /// placed in an output field **named** col("c1 + c2") + fn to_field( + &self, + schema: &dyn ExprSchema, + ) -> Result<(Option, Arc)> { + let (relation, schema_name) = self.qualified_name(); + #[allow(deprecated)] + let field = match self { + Expr::Alias(Alias { + expr, + name, + metadata, + .. + }) => { + let field = match &**expr { + Expr::Placeholder(Placeholder { data_type, .. }) => { + match &data_type { + None => schema + .data_type_and_nullable(&Column::from_name(name)) + .map(|(d, n)| Field::new(&schema_name, d.clone(), n)), + Some(dt) => Ok(Field::new( + &schema_name, + dt.clone(), + expr.nullable(schema)?, + )), + } + } + _ => expr.to_field(schema).map(|(_, f)| f.as_ref().clone()), + }?; + + let mut combined_metadata = expr.metadata(schema)?; + if let Some(metadata) = metadata { + if !metadata.is_empty() { + combined_metadata.extend(metadata.clone()); + } + } + + Ok(Arc::new(field.with_metadata(combined_metadata))) + } + Expr::Negative(expr) => expr.to_field(schema).map(|(_, f)| f), + Expr::Column(c) => schema.field_from_column(c).map(|f| Arc::new(f.clone())), + Expr::OuterReferenceColumn(ty, _) => { + Ok(Arc::new(Field::new(&schema_name, ty.clone(), true))) + } + Expr::ScalarVariable(ty, _) => { + Ok(Arc::new(Field::new(&schema_name, ty.clone(), true))) + } + Expr::Literal(l, metadata) => { + let mut field = Field::new(&schema_name, l.data_type(), l.is_null()); + if let Some(metadata) = metadata { + field = field.with_metadata( + metadata + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(), + ); + } + Ok(Arc::new(field)) + } Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -397,11 +440,12 @@ impl ExprSchemable for Expr { | Expr::IsNotTrue(_) | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) - | Expr::Exists { .. } => Ok((DataType::Boolean, false)), - Expr::ScalarSubquery(subquery) => Ok(( - subquery.subquery.schema().field(0).data_type().clone(), - subquery.subquery.schema().field(0).is_nullable(), - )), + | Expr::Exists { .. } => { + Ok(Arc::new(Field::new(&schema_name, DataType::Boolean, false))) + } + Expr::ScalarSubquery(subquery) => { + Ok(Arc::new(subquery.subquery.schema().field(0).clone())) + } Expr::BinaryExpr(BinaryExpr { ref left, ref right, @@ -412,17 +456,63 @@ impl ExprSchemable for Expr { let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type); coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default()); coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default()); - Ok((coercer.get_result_type()?, lhs_nullable || rhs_nullable)) + Ok(Arc::new(Field::new( + &schema_name, + coercer.get_result_type()?, + lhs_nullable || rhs_nullable, + ))) } Expr::WindowFunction(window_function) => { - self.data_type_and_nullable_with_window_function(schema, window_function) + let (dt, nullable) = self.data_type_and_nullable_with_window_function( + schema, + window_function, + )?; + Ok(Arc::new(Field::new(&schema_name, dt, nullable))) + } + Expr::AggregateFunction(aggregate_function) => { + let AggregateFunction { + func, + params: AggregateFunctionParams { args, .. }, + .. 
+ } = aggregate_function; + + let fields = args + .iter() + .map(|e| e.to_field(schema).map(|(_, f)| f)) + .collect::>>()?; + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + let new_fields = fields_with_aggregate_udf(&fields, func) + .map_err(|err| { + let arg_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + plan_datafusion_err!( + "{} {}", + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, + utils::generate_signature_error_msg( + func.name(), + func.signature().clone(), + &arg_types, + ) + ) + })? + .into_iter() + .collect::>(); + + func.return_field(&new_fields) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, nullables): (Vec, Vec) = args + let (arg_types, fields): (Vec, Vec>) = args .iter() - .map(|e| e.data_type_and_nullable(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()? .into_iter() + .map(|f| (f.data_type().clone(), f)) .unzip(); // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) @@ -440,42 +530,54 @@ impl ExprSchemable for Expr { ) ) })?; + let new_fields = fields + .into_iter() + .zip(new_data_types) + .map(|(f, d)| f.as_ref().clone().with_data_type(d)) + .map(Arc::new) + .collect::>(); let arguments = args .iter() .map(|e| match e { - Expr::Literal(sv) => Some(sv), + Expr::Literal(sv, _) => Some(sv), _ => None, }) .collect::>(); - let args = ReturnTypeArgs { - arg_types: &new_data_types, + let args = ReturnFieldArgs { + arg_fields: &new_fields, scalar_arguments: &arguments, - nullables: &nullables, }; - let (return_type, nullable) = - func.return_type_from_args(args)?.into_parts(); - Ok((return_type, nullable)) + func.return_field_from_args(args) } - _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - } - } - - /// Returns a [arrow::datatypes::Field] compatible with this expression. - /// - /// So for example, a projected expression `col(c1) + col(c2)` is - /// placed in an output field **named** col("c1 + c2") - fn to_field( - &self, - input_schema: &dyn ExprSchema, - ) -> Result<(Option, Arc)> { - let (relation, schema_name) = self.qualified_name(); - let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; - let field = Field::new(schema_name, data_type, nullable) - .with_metadata(self.metadata(input_schema)?) - .into(); - Ok((relation, field)) + // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), + Expr::Cast(Cast { expr, data_type }) => expr + .to_field(schema) + .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())) + .map(Arc::new), + Expr::Like(_) + | Expr::SimilarTo(_) + | Expr::Not(_) + | Expr::Between(_) + | Expr::Case(_) + | Expr::TryCast(_) + | Expr::InList(_) + | Expr::InSubquery(_) + | Expr::Wildcard { .. } + | Expr::GroupingSet(_) + | Expr::Placeholder(_) + | Expr::Unnest(_) => Ok(Arc::new(Field::new( + &schema_name, + self.get_type(schema)?, + self.nullable(schema)?, + ))), + }?; + + Ok(( + relation, + Arc::new(field.as_ref().clone().with_name(schema_name)), + )) } /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. @@ -528,13 +630,18 @@ impl Expr { .. 
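[Editor's aside: illustrative sketch, not part of the patch.] The scalar-function branch above shows the new `ReturnFieldArgs` contract: instead of bare `DataType`s and a separate nullables slice, the UDF now receives the coerced argument fields (as `FieldRef`s, carrying nullability and metadata) plus any literal arguments, and returns a complete output field. A call-site fragment mirroring the code above, with `new_fields` and `arguments` as computed there:

    // new_fields: Vec<FieldRef>             - coerced argument fields
    // arguments:  Vec<Option<&ScalarValue>> - literal arguments, where present
    let args = ReturnFieldArgs {
        arg_fields: &new_fields,
        scalar_arguments: &arguments,
    };
    // The UDF decides output data type, nullability and metadata in one place.
    let return_field = func.return_field_from_args(args)?;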
} = window_function; - let data_types = args + let fields = args .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; match fun { WindowFunctionDefinition::AggregateUDF(udaf) => { - let new_types = data_types_with_aggregate_udf(&data_types, udaf) + let data_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let new_fields = fields_with_aggregate_udf(&fields, udaf) .map_err(|err| { plan_datafusion_err!( "{} {}", @@ -548,16 +655,22 @@ impl Expr { &data_types ) ) - })?; + })? + .into_iter() + .collect::>(); - let return_type = udaf.return_type(&new_types)?; - let nullable = udaf.is_nullable(); + let return_field = udaf.return_field(&new_fields)?; - Ok((return_type, nullable)) + Ok((return_field.data_type().clone(), return_field.is_nullable())) } WindowFunctionDefinition::WindowUDF(udwf) => { - let new_types = - data_types_with_window_udf(&data_types, udwf).map_err(|err| { + let data_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let new_fields = fields_with_window_udf(&fields, udwf) + .map_err(|err| { plan_datafusion_err!( "{} {}", match err { @@ -570,9 +683,11 @@ impl Expr { &data_types ) ) - })?; + })? + .into_iter() + .collect::>(); let (_, function_name) = self.qualified_name(); - let field_args = WindowUDFFieldArgs::new(&new_types, &function_name); + let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name); udwf.field(field_args) .map(|field| (field.data_type().clone(), field.is_nullable())) @@ -762,29 +877,25 @@ mod tests { #[derive(Debug)] struct MockExprSchema { - nullable: bool, - data_type: DataType, + field: Field, error_on_nullable: bool, - metadata: HashMap, } impl MockExprSchema { fn new() -> Self { Self { - nullable: false, - data_type: DataType::Null, + field: Field::new("mock_field", DataType::Null, false), error_on_nullable: false, - metadata: HashMap::new(), } } fn with_nullable(mut self, nullable: bool) -> Self { - self.nullable = nullable; + self.field = self.field.with_nullable(nullable); self } fn with_data_type(mut self, data_type: DataType) -> Self { - self.data_type = data_type; + self.field = self.field.with_data_type(data_type); self } @@ -794,7 +905,7 @@ mod tests { } fn with_metadata(mut self, metadata: HashMap) -> Self { - self.metadata = metadata; + self.field = self.field.with_metadata(metadata); self } } @@ -804,20 +915,12 @@ mod tests { if self.error_on_nullable { internal_err!("nullable error") } else { - Ok(self.nullable) + Ok(self.field.is_nullable()) } } - fn data_type(&self, _col: &Column) -> Result<&DataType> { - Ok(&self.data_type) - } - - fn metadata(&self, _col: &Column) -> Result<&HashMap> { - Ok(&self.metadata) - } - - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - Ok((self.data_type(col)?, self.nullable(col)?)) + fn field_from_column(&self, _col: &Column) -> Result<&Field> { + Ok(&self.field) } } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index d3cc881af361..1f44f755b214 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -94,7 +94,9 @@ pub use function::{ AccumulatorFactoryFunction, PartitionEvaluatorFactory, ReturnTypeFunction, ScalarFunctionImplementation, StateTypeFunction, }; -pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; +pub use literal::{ + lit, lit_timestamp_nano, lit_with_metadata, Literal, TimestampLiteral, +}; pub use logical_plan::*; pub use partition_evaluator::PartitionEvaluator; pub use 
sqlparser; @@ -104,8 +106,7 @@ pub use udaf::{ SetMonotonicity, StatisticsArgs, }; pub use udf::{ - scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, + scalar_doc_sections, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, }; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index 90ba5a9a693c..48e058b8b7b1 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -19,12 +19,37 @@ use crate::Expr; use datafusion_common::ScalarValue; +use std::collections::HashMap; /// Create a literal expression pub fn lit(n: T) -> Expr { n.lit() } +pub fn lit_with_metadata( + n: T, + metadata: impl Into>>, +) -> Expr { + let metadata = metadata.into(); + let Some(metadata) = metadata else { + return n.lit(); + }; + + let Expr::Literal(sv, prior_metadata) = n.lit() else { + unreachable!(); + }; + + let new_metadata = match prior_metadata { + Some(mut prior) => { + prior.extend(metadata); + prior + } + None => metadata.into_iter().collect(), + }; + + Expr::Literal(sv, Some(new_metadata)) +} + /// Create a literal timestamp expression pub fn lit_timestamp_nano(n: T) -> Expr { n.lit_timestamp_nano() @@ -43,37 +68,37 @@ pub trait TimestampLiteral { impl Literal for &str { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(*self)) + Expr::Literal(ScalarValue::from(*self), None) } } impl Literal for String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::Literal(ScalarValue::from(self.as_ref()), None) } } impl Literal for &String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::Literal(ScalarValue::from(self.as_ref()), None) } } impl Literal for Vec { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())), None) } } impl Literal for &[u8] { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())), None) } } impl Literal for ScalarValue { fn lit(&self) -> Expr { - Expr::Literal(self.clone()) + Expr::Literal(self.clone(), None) } } @@ -82,7 +107,7 @@ macro_rules! make_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) + Expr::Literal(ScalarValue::$SCALAR(Some(self.clone())), None) } } }; @@ -93,7 +118,7 @@ macro_rules! make_nonzero_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.get()))) + Expr::Literal(ScalarValue::$SCALAR(Some(self.get())), None) } } }; @@ -104,10 +129,10 @@ macro_rules! make_timestamp_literal { #[doc = $DOC] impl TimestampLiteral for $TYPE { fn lit_timestamp_nano(&self) -> Expr { - Expr::Literal(ScalarValue::TimestampNanosecond( - Some((self.clone()).into()), + Expr::Literal( + ScalarValue::TimestampNanosecond(Some((self.clone()).into()), None), None, - )) + ) } } }; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 64931df5a83f..533e81e64f29 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -341,8 +341,11 @@ impl LogicalPlanBuilder { // wrap cast if data type is not same as common type. 
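[Editor's aside: illustrative sketch, not part of the patch.] `Expr::Literal` is now a two-field variant that carries optional key/value metadata, and `lit_with_metadata` (newly exported above) is the convenience constructor for it. A small usage sketch; the metadata key "origin" is invented for illustration:

    use std::collections::HashMap;
    use datafusion_expr::{lit, lit_with_metadata, Expr};

    // Plain literals keep working and simply carry no metadata.
    let plain = lit("foo");

    // Attach metadata; it is propagated into the literal's output field.
    let tagged = lit_with_metadata(
        "foo",
        Some(HashMap::from([("origin".to_string(), "example".to_string())])),
    );

    // Pattern matches must now bind (or ignore) the metadata slot.
    if let Expr::Literal(value, metadata) = &tagged {
        assert!(metadata.is_some());
        let _ = value;
    }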
for row in &mut values { for (j, field_type) in fields.iter().map(|f| f.data_type()).enumerate() { - if let Expr::Literal(ScalarValue::Null) = row[j] { - row[j] = Expr::Literal(ScalarValue::try_from(field_type)?); + if let Expr::Literal(ScalarValue::Null, metadata) = &row[j] { + row[j] = Expr::Literal( + ScalarValue::try_from(field_type)?, + metadata.clone(), + ); } else { row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?; } @@ -501,6 +504,21 @@ impl LogicalPlanBuilder { if table_scan.filters.is_empty() { if let Some(p) = table_scan.source.get_logical_plan() { let sub_plan = p.into_owned(); + + if let Some(proj) = table_scan.projection { + let projection_exprs = proj + .into_iter() + .map(|i| { + Expr::Column(Column::from( + sub_plan.schema().qualified_field(i), + )) + }) + .collect::>(); + return Self::new(sub_plan) + .project(projection_exprs)? + .alias(table_scan.table_name); + } + // Ensures that the reference to the inlined table remains the // same, meaning we don't have to change any of the parent nodes // that reference this table. @@ -586,7 +604,7 @@ impl LogicalPlanBuilder { /// Apply a filter which is used for a having clause pub fn having(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; - Filter::try_new_with_having(expr, self.plan) + Filter::try_new(expr, self.plan) .map(LogicalPlan::Filter) .map(Self::from) } @@ -1117,19 +1135,29 @@ impl LogicalPlanBuilder { .collect::>()?; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys).collect(); - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &join_type)?; let mut join_on: Vec<(Expr, Expr)> = vec![]; let mut filters: Option = None; for (l, r) in &on { if self.plan.schema().has_column(l) && right.schema().has_column(r) - && can_hash(self.plan.schema().field_from_column(l)?.data_type()) + && can_hash( + datafusion_common::ExprSchema::field_from_column( + self.plan.schema(), + l, + )? + .data_type(), + ) { join_on.push((Expr::Column(l.clone()), Expr::Column(r.clone()))); } else if self.plan.schema().has_column(l) && right.schema().has_column(r) - && can_hash(self.plan.schema().field_from_column(r)?.data_type()) + && can_hash( + datafusion_common::ExprSchema::field_from_column( + self.plan.schema(), + r, + )? + .data_type(), + ) { join_on.push((Expr::Column(r.clone()), Expr::Column(l.clone()))); } else { @@ -1151,33 +1179,33 @@ impl LogicalPlanBuilder { DataFusionError::Internal("filters should not be None here".to_string()) })?) 
} else { - Ok(Self::new(LogicalPlan::Join(Join { - left: self.plan, - right: Arc::new(right), - on: join_on, - filter: filters, + let join = Join::try_new( + self.plan, + Arc::new(right), + join_on, + filters, join_type, - join_constraint: JoinConstraint::Using, - schema: DFSchemaRef::new(join_schema), - null_equals_null: false, - }))) + JoinConstraint::Using, + false, + )?; + + Ok(Self::new(LogicalPlan::Join(join))) } } /// Apply a cross join pub fn cross_join(self, right: LogicalPlan) -> Result { - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; - Ok(Self::new(LogicalPlan::Join(Join { - left: self.plan, - right: Arc::new(right), - on: vec![], - filter: None, - join_type: JoinType::Inner, - join_constraint: JoinConstraint::On, - null_equals_null: false, - schema: DFSchemaRef::new(join_schema), - }))) + let join = Join::try_new( + self.plan, + Arc::new(right), + vec![], + None, + JoinType::Inner, + JoinConstraint::On, + false, + )?; + + Ok(Self::new(LogicalPlan::Join(join))) } /// Repartition @@ -1338,7 +1366,7 @@ impl LogicalPlanBuilder { /// to columns from the existing input. `r`, the second element of the tuple, /// must only refer to columns from the right input. /// - /// `filter` contains any other other filter expression to apply during the + /// `filter` contains any other filter expression to apply during the /// join. Note that `equi_exprs` predicates are evaluated more efficiently /// than the filter expressions, so they are preferred. pub fn join_with_expr_keys( @@ -1388,19 +1416,17 @@ impl LogicalPlanBuilder { }) .collect::>>()?; - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - - Ok(Self::new(LogicalPlan::Join(Join { - left: self.plan, - right: Arc::new(right), - on: join_key_pairs, + let join = Join::try_new( + self.plan, + Arc::new(right), + join_key_pairs, filter, join_type, - join_constraint: JoinConstraint::On, - schema: DFSchemaRef::new(join_schema), - null_equals_null: false, - }))) + JoinConstraint::On, + false, + )?; + + Ok(Self::new(LogicalPlan::Join(join))) } /// Unnest the given column. @@ -1490,7 +1516,7 @@ pub fn change_redundant_column(fields: &Fields) -> Vec { // Loop until we find a name that hasn't been used while seen.contains(&new_name) { *count += 1; - new_name = format!("{}:{}", base_name, count); + new_name = format!("{base_name}:{count}"); } seen.insert(new_name.clone()); @@ -1603,12 +1629,19 @@ pub fn build_join_schema( join_type, left.fields().len(), ); - let metadata = left + + let (schema1, schema2) = match join_type { + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (left, right), + _ => (right, left), + }; + + let metadata = schema1 .metadata() .clone() .into_iter() - .chain(right.metadata().clone()) + .chain(schema2.metadata().clone()) .collect(); + let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?; dfschema.with_functional_dependencies(func_dependencies) } @@ -2237,6 +2270,7 @@ mod tests { use crate::test::function_stub::sum; use datafusion_common::{Constraint, RecursionUnnestOption, SchemaError}; + use insta::assert_snapshot; #[test] fn plan_builder_simple() -> Result<()> { @@ -2246,11 +2280,11 @@ mod tests { .project(vec![col("id")])? 
.build()?; - let expected = "Projection: employee_csv.id\ - \n Filter: employee_csv.state = Utf8(\"CO\")\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r#" + Projection: employee_csv.id + Filter: employee_csv.state = Utf8("CO") + TableScan: employee_csv projection=[id, state] + "#); Ok(()) } @@ -2262,12 +2296,7 @@ mod tests { let plan = LogicalPlanBuilder::scan("employee_csv", table_source(&schema), projection) .unwrap(); - let expected = DFSchema::try_from_qualified_schema( - TableReference::bare("employee_csv"), - &schema, - ) - .unwrap(); - assert_eq!(&expected, plan.schema().as_ref()); + assert_snapshot!(plan.schema().as_ref(), @"fields:[employee_csv.id, employee_csv.first_name, employee_csv.last_name, employee_csv.state, employee_csv.salary], metadata:{}"); // Note scan of "EMPLOYEE_CSV" is treated as a SQL identifier // (and thus normalized to "employee"csv") as well @@ -2275,7 +2304,7 @@ mod tests { let plan = LogicalPlanBuilder::scan("EMPLOYEE_CSV", table_source(&schema), projection) .unwrap(); - assert_eq!(&expected, plan.schema().as_ref()); + assert_snapshot!(plan.schema().as_ref(), @"fields:[employee_csv.id, employee_csv.first_name, employee_csv.last_name, employee_csv.state, employee_csv.salary], metadata:{}"); } #[test] @@ -2284,9 +2313,9 @@ mod tests { let projection = None; let err = LogicalPlanBuilder::scan("", table_source(&schema), projection).unwrap_err(); - assert_eq!( + assert_snapshot!( err.strip_backtrace(), - "Error during planning: table_name cannot be empty" + @"Error during planning: table_name cannot be empty" ); } @@ -2300,10 +2329,10 @@ mod tests { ])? .build()?; - let expected = "Sort: employee_csv.state ASC NULLS FIRST, employee_csv.salary DESC NULLS LAST\ - \n TableScan: employee_csv projection=[state, salary]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Sort: employee_csv.state ASC NULLS FIRST, employee_csv.salary DESC NULLS LAST + TableScan: employee_csv projection=[state, salary] + "); Ok(()) } @@ -2320,15 +2349,15 @@ mod tests { .union(plan.build()?)? .build()?; - let expected = "Union\ - \n Union\ - \n Union\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Union + Union + Union + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + "); Ok(()) } @@ -2345,19 +2374,18 @@ mod tests { .union_distinct(plan.build()?)? 
.build()?; - let expected = "\ - Distinct:\ - \n Union\ - \n Distinct:\ - \n Union\ - \n Distinct:\ - \n Union\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Distinct: + Union + Distinct: + Union + Distinct: + Union + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + "); Ok(()) } @@ -2371,13 +2399,12 @@ mod tests { .distinct()? .build()?; - let expected = "\ - Distinct:\ - \n Projection: employee_csv.id\ - \n Filter: employee_csv.state = Utf8(\"CO\")\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r#" + Distinct: + Projection: employee_csv.id + Filter: employee_csv.state = Utf8("CO") + TableScan: employee_csv projection=[id, state] + "#); Ok(()) } @@ -2397,14 +2424,15 @@ mod tests { .filter(exists(Arc::new(subquery)))? .build()?; - let expected = "Filter: EXISTS ()\ - \n Subquery:\ - \n Filter: foo.a = bar.a\ - \n Projection: foo.a\ - \n TableScan: foo\ - \n Projection: bar.a\ - \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query}")); + assert_snapshot!(outer_query, @r" + Filter: EXISTS () + Subquery: + Filter: foo.a = bar.a + Projection: foo.a + TableScan: foo + Projection: bar.a + TableScan: bar + "); Ok(()) } @@ -2425,14 +2453,15 @@ mod tests { .filter(in_subquery(col("a"), Arc::new(subquery)))? .build()?; - let expected = "Filter: bar.a IN ()\ - \n Subquery:\ - \n Filter: foo.a = bar.a\ - \n Projection: foo.a\ - \n TableScan: foo\ - \n Projection: bar.a\ - \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query}")); + assert_snapshot!(outer_query, @r" + Filter: bar.a IN () + Subquery: + Filter: foo.a = bar.a + Projection: foo.a + TableScan: foo + Projection: bar.a + TableScan: bar + "); Ok(()) } @@ -2452,13 +2481,14 @@ mod tests { .project(vec![scalar_subquery(Arc::new(subquery))])? .build()?; - let expected = "Projection: ()\ - \n Subquery:\ - \n Filter: foo.a = bar.a\ - \n Projection: foo.b\ - \n TableScan: foo\ - \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query}")); + assert_snapshot!(outer_query, @r" + Projection: () + Subquery: + Filter: foo.a = bar.a + Projection: foo.b + TableScan: foo + TableScan: bar + "); Ok(()) } @@ -2552,13 +2582,11 @@ mod tests { let plan2 = table_scan(TableReference::none(), &employee_schema(), Some(vec![3, 4]))?; - let expected = "Error during planning: INTERSECT/EXCEPT query must have the same number of columns. \ - Left is 1 and right is 2."; let err_msg1 = LogicalPlanBuilder::intersect(plan1.build()?, plan2.build()?, true) .unwrap_err(); - assert_eq!(err_msg1.strip_backtrace(), expected); + assert_snapshot!(err_msg1.strip_backtrace(), @"Error during planning: INTERSECT/EXCEPT query must have the same number of columns. Left is 1 and right is 2."); Ok(()) } @@ -2569,19 +2597,29 @@ mod tests { let err = nested_table_scan("test_table")? 
.unnest_column("scalar") .unwrap_err(); - assert!(err - .to_string() - .starts_with("Internal error: trying to unnest on invalid data type UInt32")); + + let DataFusionError::Internal(desc) = err else { + return plan_err!("Plan should have returned an DataFusionError::Internal"); + }; + + let desc = desc + .split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .first() + .unwrap_or(&"") + .to_string(); + + assert_snapshot!(desc, @"trying to unnest on invalid data type UInt32"); // Unnesting the strings list. let plan = nested_table_scan("test_table")? .unnest_column("strings")? .build()?; - let expected = "\ - Unnest: lists[test_table.strings|depth=1] structs[]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[test_table.strings|depth=1] structs[] + TableScan: test_table + "); // Check unnested field is a scalar let field = plan.schema().field_with_name(None, "strings").unwrap(); @@ -2592,16 +2630,16 @@ mod tests { .unnest_column("struct_singular")? .build()?; - let expected = "\ - Unnest: lists[] structs[test_table.struct_singular]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[] structs[test_table.struct_singular] + TableScan: test_table + "); for field_name in &["a", "b"] { // Check unnested struct field is a scalar let field = plan .schema() - .field_with_name(None, &format!("struct_singular.{}", field_name)) + .field_with_name(None, &format!("struct_singular.{field_name}")) .unwrap(); assert_eq!(&DataType::UInt32, field.data_type()); } @@ -2613,12 +2651,12 @@ mod tests { .unnest_column("struct_singular")? .build()?; - let expected = "\ - Unnest: lists[] structs[test_table.struct_singular]\ - \n Unnest: lists[test_table.structs|depth=1] structs[]\ - \n Unnest: lists[test_table.strings|depth=1] structs[]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[] structs[test_table.struct_singular] + Unnest: lists[test_table.structs|depth=1] structs[] + Unnest: lists[test_table.strings|depth=1] structs[] + TableScan: test_table + "); // Check unnested struct list field should be a struct. let field = plan.schema().field_with_name(None, "structs").unwrap(); @@ -2634,10 +2672,10 @@ mod tests { .unnest_columns_with_options(cols, UnnestOptions::default())? .build()?; - let expected = "\ - Unnest: lists[test_table.strings|depth=1, test_table.structs|depth=1] structs[test_table.struct_singular]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[test_table.strings|depth=1, test_table.structs|depth=1] structs[test_table.struct_singular] + TableScan: test_table + "); // Unnesting missing column should fail. let plan = nested_table_scan("test_table")?.unnest_column("missing"); @@ -2661,10 +2699,10 @@ mod tests { )? 
.build()?; - let expected = "\ - Unnest: lists[test_table.stringss|depth=1, test_table.stringss|depth=2] structs[test_table.struct_singular]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[test_table.stringss|depth=1, test_table.stringss|depth=2] structs[test_table.struct_singular] + TableScan: test_table + "); // Check output columns has correct type let field = plan @@ -2684,7 +2722,7 @@ mod tests { for field_name in &["a", "b"] { let field = plan .schema() - .field_with_name(None, &format!("struct_singular.{}", field_name)) + .field_with_name(None, &format!("struct_singular.{field_name}")) .unwrap(); assert_eq!(&DataType::UInt32, field.data_type()); } @@ -2736,10 +2774,24 @@ mod tests { let join = LogicalPlanBuilder::from(left).cross_join(right)?.build()?; - let _ = LogicalPlanBuilder::from(join.clone()) + let plan = LogicalPlanBuilder::from(join.clone()) .union(join)? .build()?; + assert_snapshot!(plan, @r" + Union + Cross Join: + SubqueryAlias: left + Values: (Int32(1)) + SubqueryAlias: right + Values: (Int32(1)) + Cross Join: + SubqueryAlias: left + Values: (Int32(1)) + SubqueryAlias: right + Values: (Int32(1)) + "); + Ok(()) } @@ -2799,10 +2851,10 @@ mod tests { .aggregate(vec![col("id")], vec![sum(col("salary"))])? .build()?; - let expected = - "Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]]\ - \n TableScan: employee_csv projection=[id, state, salary]"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]] + TableScan: employee_csv projection=[id, state, salary] + "); Ok(()) } @@ -2821,10 +2873,37 @@ mod tests { .aggregate(vec![col("id")], vec![sum(col("salary"))])? .build()?; - let expected = - "Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]]\ - \n TableScan: employee_csv projection=[id, state, salary]"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]] + TableScan: employee_csv projection=[id, state, salary] + "); + + Ok(()) + } + + #[test] + fn test_join_metadata() -> Result<()> { + let left_schema = DFSchema::new_with_metadata( + vec![(None, Arc::new(Field::new("a", DataType::Int32, false)))], + HashMap::from([("key".to_string(), "left".to_string())]), + )?; + let right_schema = DFSchema::new_with_metadata( + vec![(None, Arc::new(Field::new("b", DataType::Int32, false)))], + HashMap::from([("key".to_string(), "right".to_string())]), + )?; + + let join_schema = + build_join_schema(&left_schema, &right_schema, &JoinType::Left)?; + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "left".to_string())]) + ); + let join_schema = + build_join_schema(&left_schema, &right_schema, &JoinType::Right)?; + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "right".to_string())]) + ); Ok(()) } diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 14758b61e859..f1e455f46db3 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -341,7 +341,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { let eclipse = if values.len() > 5 { "..." 
} else { "" }; - let values_str = format!("{}{}", str_values, eclipse); + let values_str = format!("{str_values}{eclipse}"); json!({ "Node Type": "Values", "Values": values_str @@ -429,7 +429,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { }) => { let op_str = options .iter() - .map(|(k, v)| format!("{}={}", k, v)) + .map(|(k, v)| format!("{k}={v}")) .collect::>() .join(", "); json!({ @@ -722,13 +722,14 @@ impl<'n> TreeNodeVisitor<'n> for PgJsonVisitor<'_, '_> { #[cfg(test)] mod tests { use arrow::datatypes::{DataType, Field}; + use insta::assert_snapshot; use super::*; #[test] fn test_display_empty_schema() { let schema = Schema::empty(); - assert_eq!("[]", format!("{}", display_schema(&schema))); + assert_snapshot!(display_schema(&schema), @"[]"); } #[test] @@ -738,9 +739,6 @@ mod tests { Field::new("first_name", DataType::Utf8, true), ]); - assert_eq!( - "[id:Int32, first_name:Utf8;N]", - format!("{}", display_schema(&schema)) - ); + assert_snapshot!(display_schema(&schema), @"[id:Int32, first_name:Utf8;N]"); } } diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index d4d50ac4eae4..f3c95e696b4b 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -89,8 +89,28 @@ impl Hash for CopyTo { } } -/// The operator that modifies the content of a database (adapted from -/// substrait WriteRel) +/// Modifies the content of a database +/// +/// This operator is used to perform DML operations such as INSERT, DELETE, +/// UPDATE, and CTAS (CREATE TABLE AS SELECT). +/// +/// * `INSERT` - Appends new rows to the existing table. Calls +/// [`TableProvider::insert_into`] +/// +/// * `DELETE` - Removes rows from the table. Currently NOT supported by the +/// [`TableProvider`] trait or builtin sources. +/// +/// * `UPDATE` - Modifies existing rows in the table. Currently NOT supported by +/// the [`TableProvider`] trait or builtin sources. +/// +/// * `CREATE TABLE AS SELECT` - Creates a new table and populates it with data +/// from a query. This is similar to the `INSERT` operation, but it creates a new +/// table instead of modifying an existing one. +/// +/// Note that the structure is adapted from substrait WriteRel) +/// +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html +/// [`TableProvider::insert_into`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#method.insert_into #[derive(Clone)] pub struct DmlStatement { /// The table name @@ -177,11 +197,18 @@ impl PartialOrd for DmlStatement { } } +/// The type of DML operation to perform. +/// +/// See [`DmlStatement`] for more details. 
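[Editor's aside: illustrative sketch, not part of the patch.] The expanded documentation above spells out which DML operations `DmlStatement` models and which ones the built-in `TableProvider`s can actually execute; the `WriteOp` enum that follows enumerates them. A dispatch sketch (the function name `describe_write_op` is invented, and the import path is abbreviated and may differ slightly between versions):

    use datafusion_expr::logical_plan::dml::WriteOp;

    fn describe_write_op(op: &WriteOp) -> &'static str {
        match op {
            // Appends rows via TableProvider::insert_into
            WriteOp::Insert(_) => "INSERT INTO",
            // Not yet supported by the TableProvider trait or built-in sources
            WriteOp::Delete => "DELETE",
            WriteOp::Update => "UPDATE",
            // Creates a new table populated from a query
            WriteOp::Ctas => "CREATE TABLE AS SELECT",
        }
    }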
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum WriteOp { + /// `INSERT INTO` operation Insert(InsertOp), + /// `DELETE` operation Delete, + /// `UPDATE` operation Update, + /// `CREATE TABLE AS SELECT` operation Ctas, } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index af3deca9b075..321ae103167a 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -630,12 +630,9 @@ impl LogicalPlan { // todo it isn't clear why the schema is not recomputed here Ok(LogicalPlan::Values(Values { schema, values })) } - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => Filter::try_new_internal(predicate, input, having) - .map(LogicalPlan::Filter), + LogicalPlan::Filter(Filter { predicate, input }) => { + Filter::try_new(predicate, input).map(LogicalPlan::Filter) + } LogicalPlan::Repartition(_) => Ok(self), LogicalPlan::Window(Window { input, @@ -1308,7 +1305,7 @@ impl LogicalPlan { // Empty group_expr will return Some(1) if group_expr .iter() - .all(|expr| matches!(expr, Expr::Literal(_))) + .all(|expr| matches!(expr, Expr::Literal(_, _))) { Some(1) } else { @@ -1458,7 +1455,7 @@ impl LogicalPlan { let transformed_expr = e.transform_up(|e| { if let Expr::Placeholder(Placeholder { id, .. }) = e { let value = param_values.get_placeholders_with_values(&id)?; - Ok(Transformed::yes(Expr::Literal(value))) + Ok(Transformed::yes(Expr::Literal(value, None))) } else { Ok(Transformed::no(e)) } @@ -1732,7 +1729,7 @@ impl LogicalPlan { LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. }) => { - write!(f, "RecursiveQuery: is_distinct={}", is_distinct) + write!(f, "RecursiveQuery: is_distinct={is_distinct}") } LogicalPlan::Values(Values { ref values, .. }) => { let str_values: Vec<_> = values @@ -1829,12 +1826,12 @@ impl LogicalPlan { Ok(()) } LogicalPlan::Projection(Projection { ref expr, .. }) => { - write!(f, "Projection: ")?; + write!(f, "Projection:")?; for (i, expr_item) in expr.iter().enumerate() { if i > 0 { - write!(f, ", ")?; + write!(f, ",")?; } - write!(f, "{expr_item}")?; + write!(f, " {expr_item}")?; } Ok(()) } @@ -1975,7 +1972,7 @@ impl LogicalPlan { }; write!( f, - "Limit: skip={}, fetch={}", skip_str,fetch_str, + "Limit: skip={skip_str}, fetch={fetch_str}", ) } LogicalPlan::Subquery(Subquery { .. }) => { @@ -2270,8 +2267,6 @@ pub struct Filter { pub predicate: Expr, /// The incoming logical plan pub input: Arc, - /// The flag to indicate if the filter is a having clause - pub having: bool, } impl Filter { @@ -2280,13 +2275,14 @@ impl Filter { /// Notes: as Aliases have no effect on the output of a filter operator, /// they are removed from the predicate expression. pub fn try_new(predicate: Expr, input: Arc) -> Result { - Self::try_new_internal(predicate, input, false) + Self::try_new_internal(predicate, input) } /// Create a new filter operator for a having clause. /// This is similar to a filter, but its having flag is set to true. 
+ #[deprecated(since = "48.0.0", note = "Use `try_new` instead")] pub fn try_new_with_having(predicate: Expr, input: Arc) -> Result { - Self::try_new_internal(predicate, input, true) + Self::try_new_internal(predicate, input) } fn is_allowed_filter_type(data_type: &DataType) -> bool { @@ -2300,11 +2296,7 @@ impl Filter { } } - fn try_new_internal( - predicate: Expr, - input: Arc, - having: bool, - ) -> Result { + fn try_new_internal(predicate: Expr, input: Arc) -> Result { // Filter predicates must return a boolean value so we try and validate that here. // Note that it is not always possible to resolve the predicate expression during plan // construction (such as with correlated subqueries) so we make a best effort here and @@ -2320,7 +2312,6 @@ impl Filter { Ok(Self { predicate: predicate.unalias_nested().data, input, - having, }) } @@ -2442,18 +2433,23 @@ impl Window { .iter() .enumerate() .filter_map(|(idx, expr)| { - if let Expr::WindowFunction(WindowFunction { + let Expr::WindowFunction(window_fun) = expr else { + return None; + }; + let WindowFunction { fun: WindowFunctionDefinition::WindowUDF(udwf), params: WindowFunctionParams { partition_by, .. }, - }) = expr - { - // When there is no PARTITION BY, row number will be unique - // across the entire table. - if udwf.name() == "row_number" && partition_by.is_empty() { - return Some(idx + input_len); - } + } = window_fun.as_ref() + else { + return None; + }; + // When there is no PARTITION BY, row number will be unique + // across the entire table. + if udwf.name() == "row_number" && partition_by.is_empty() { + Some(idx + input_len) + } else { + None } - None }) .map(|idx| { FunctionalDependence::new(vec![idx], vec![], false) @@ -2713,7 +2709,9 @@ impl Union { { expr.push(Expr::Column(column)); } else { - expr.push(Expr::Literal(ScalarValue::Null).alias(column.name())); + expr.push( + Expr::Literal(ScalarValue::Null, None).alias(column.name()), + ); } } wrapped_inputs.push(Arc::new(LogicalPlan::Projection( @@ -2871,7 +2869,7 @@ impl Union { // Generate unique field name let name = if let Some(count) = name_counts.get_mut(&base_name) { *count += 1; - format!("{}_{}", base_name, count) + format!("{base_name}_{count}") } else { name_counts.insert(base_name.clone(), 0); base_name @@ -3239,7 +3237,7 @@ impl Limit { pub fn get_skip_type(&self) -> Result { match self.skip.as_deref() { Some(expr) => match *expr { - Expr::Literal(ScalarValue::Int64(s)) => { + Expr::Literal(ScalarValue::Int64(s), _) => { // `skip = NULL` is equivalent to `skip = 0` let s = s.unwrap_or(0); if s >= 0 { @@ -3259,14 +3257,16 @@ impl Limit { pub fn get_fetch_type(&self) -> Result { match self.fetch.as_deref() { Some(expr) => match *expr { - Expr::Literal(ScalarValue::Int64(Some(s))) => { + Expr::Literal(ScalarValue::Int64(Some(s)), _) => { if s >= 0 { Ok(FetchType::Literal(Some(s as usize))) } else { plan_err!("LIMIT must be >= 0, '{}' was provided", s) } } - Expr::Literal(ScalarValue::Int64(None)) => Ok(FetchType::Literal(None)), + Expr::Literal(ScalarValue::Int64(None), _) => { + Ok(FetchType::Literal(None)) + } _ => Ok(FetchType::UnsupportedExpr), }, None => Ok(FetchType::Literal(None)), @@ -3668,7 +3668,7 @@ fn calc_func_dependencies_for_project( .unwrap_or(vec![])) } _ => { - let name = format!("{}", expr); + let name = format!("{expr}"); Ok(input_fields .iter() .position(|item| *item == name) @@ -3720,6 +3720,47 @@ pub struct Join { } impl Join { + /// Creates a new Join operator with automatically computed schema. 
+ /// + /// This constructor computes the schema based on the join type and inputs, + /// removing the need to manually specify the schema or call `recompute_schema`. + /// + /// # Arguments + /// + /// * `left` - Left input plan + /// * `right` - Right input plan + /// * `on` - Join condition as a vector of (left_expr, right_expr) pairs + /// * `filter` - Optional filter expression (for non-equijoin conditions) + /// * `join_type` - Type of join (Inner, Left, Right, etc.) + /// * `join_constraint` - Join constraint (On, Using) + /// * `null_equals_null` - Whether NULL = NULL in join comparisons + /// + /// # Returns + /// + /// A new Join operator with the computed schema + pub fn try_new( + left: Arc, + right: Arc, + on: Vec<(Expr, Expr)>, + filter: Option, + join_type: JoinType, + join_constraint: JoinConstraint, + null_equals_null: bool, + ) -> Result { + let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?; + + Ok(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema: Arc::new(join_schema), + null_equals_null, + }) + } + /// Create Join with input which wrapped with projection, this method is used to help create physical join. pub fn try_new_with_project_input( original: &LogicalPlan, @@ -3976,6 +4017,7 @@ mod tests { TransformedResult, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{not_impl_err, Constraint, ScalarValue}; + use insta::{assert_debug_snapshot, assert_snapshot}; use crate::test::function_stub::count; @@ -4003,13 +4045,13 @@ mod tests { fn test_display_indent() -> Result<()> { let plan = display_plan()?; - let expected = "Projection: employee_csv.id\ - \n Filter: employee_csv.state IN ()\ - \n Subquery:\ - \n TableScan: employee_csv projection=[state]\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{}", plan.display_indent())); + assert_snapshot!(plan.display_indent(), @r" + Projection: employee_csv.id + Filter: employee_csv.state IN () + Subquery: + TableScan: employee_csv projection=[state] + TableScan: employee_csv projection=[id, state] + "); Ok(()) } @@ -4017,13 +4059,13 @@ mod tests { fn test_display_indent_schema() -> Result<()> { let plan = display_plan()?; - let expected = "Projection: employee_csv.id [id:Int32]\ - \n Filter: employee_csv.state IN () [id:Int32, state:Utf8]\ - \n Subquery: [state:Utf8]\ - \n TableScan: employee_csv projection=[state] [state:Utf8]\ - \n TableScan: employee_csv projection=[id, state] [id:Int32, state:Utf8]"; - - assert_eq!(expected, format!("{}", plan.display_indent_schema())); + assert_snapshot!(plan.display_indent_schema(), @r" + Projection: employee_csv.id [id:Int32] + Filter: employee_csv.state IN () [id:Int32, state:Utf8] + Subquery: [state:Utf8] + TableScan: employee_csv projection=[state] [state:Utf8] + TableScan: employee_csv projection=[id, state] [id:Int32, state:Utf8] + "); Ok(()) } @@ -4038,12 +4080,12 @@ mod tests { .project(vec![col("id"), exists(plan1).alias("exists")])? 
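[Editor's aside: illustrative sketch, not part of the patch.] `Join::try_new`, introduced above, computes the join schema from the inputs and join type, which is what lets the builder methods earlier in this diff drop their manual `build_join_schema` calls. A fragment mirroring the new `test_join_try_new` test below (imports elided; `left_scan` and `right_scan` are any two logical plans with columns `a` and `b`):

    let join = Join::try_new(
        Arc::new(left_scan),
        Arc::new(right_scan),
        vec![(col("t1.a"), col("t2.a"))],  // equijoin keys
        Some(col("t1.b").gt(col("t2.b"))), // optional non-equi filter
        JoinType::Inner,
        JoinConstraint::On,
        false,                             // null_equals_null
    )?;
    // The schema (including per-side nullability) is already computed;
    // no recompute_schema call is needed.
    let plan = LogicalPlan::Join(join);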
.build(); - let expected = "Projection: employee_csv.id, EXISTS () AS exists\ - \n Subquery:\ - \n TableScan: employee_csv projection=[state]\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{}", plan?.display_indent())); + assert_snapshot!(plan?.display_indent(), @r" + Projection: employee_csv.id, EXISTS () AS exists + Subquery: + TableScan: employee_csv projection=[state] + TableScan: employee_csv projection=[id, state] + "); Ok(()) } @@ -4051,46 +4093,42 @@ mod tests { fn test_display_graphviz() -> Result<()> { let plan = display_plan()?; - let expected_graphviz = r#" -// Begin DataFusion GraphViz Plan, -// display it online here: https://dreampuf.github.io/GraphvizOnline - -digraph { - subgraph cluster_1 - { - graph[label="LogicalPlan"] - 2[shape=box label="Projection: employee_csv.id"] - 3[shape=box label="Filter: employee_csv.state IN ()"] - 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] - 4[shape=box label="Subquery:"] - 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] - 5[shape=box label="TableScan: employee_csv projection=[state]"] - 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] - 6[shape=box label="TableScan: employee_csv projection=[id, state]"] - 3 -> 6 [arrowhead=none, arrowtail=normal, dir=back] - } - subgraph cluster_7 - { - graph[label="Detailed LogicalPlan"] - 8[shape=box label="Projection: employee_csv.id\nSchema: [id:Int32]"] - 9[shape=box label="Filter: employee_csv.state IN ()\nSchema: [id:Int32, state:Utf8]"] - 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] - 10[shape=box label="Subquery:\nSchema: [state:Utf8]"] - 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] - 11[shape=box label="TableScan: employee_csv projection=[state]\nSchema: [state:Utf8]"] - 10 -> 11 [arrowhead=none, arrowtail=normal, dir=back] - 12[shape=box label="TableScan: employee_csv projection=[id, state]\nSchema: [id:Int32, state:Utf8]"] - 9 -> 12 [arrowhead=none, arrowtail=normal, dir=back] - } -} -// End DataFusion GraphViz Plan -"#; - // just test for a few key lines in the output rather than the // whole thing to make test maintenance easier. 
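[Editor's aside: illustrative sketch, not part of the patch.] The hunks below continue the migration from hand-maintained `assert_eq!` strings to insta inline snapshots. The general pattern, shown with a made-up display value, is:

    use insta::assert_snapshot;

    let rendered = "Projection: t.a\n  TableScan: t";
    // The text after `@` is the pinned snapshot; `cargo insta review`
    // can (re)fill it, after which it is compared verbatim on every run.
    assert_snapshot!(rendered, @r"
    Projection: t.a
      TableScan: t
    ");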
- let graphviz = format!("{}", plan.display_graphviz()); - - assert_eq!(expected_graphviz, graphviz); + assert_snapshot!(plan.display_graphviz(), @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Projection: employee_csv.id"] + 3[shape=box label="Filter: employee_csv.state IN ()"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Subquery:"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: employee_csv projection=[state]"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + 6[shape=box label="TableScan: employee_csv projection=[id, state]"] + 3 -> 6 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_7 + { + graph[label="Detailed LogicalPlan"] + 8[shape=box label="Projection: employee_csv.id\nSchema: [id:Int32]"] + 9[shape=box label="Filter: employee_csv.state IN ()\nSchema: [id:Int32, state:Utf8]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="Subquery:\nSchema: [state:Utf8]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + 11[shape=box label="TableScan: employee_csv projection=[state]\nSchema: [state:Utf8]"] + 10 -> 11 [arrowhead=none, arrowtail=normal, dir=back] + 12[shape=box label="TableScan: employee_csv projection=[id, state]\nSchema: [id:Int32, state:Utf8]"] + 9 -> 12 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "#); Ok(()) } @@ -4098,60 +4136,58 @@ digraph { fn test_display_pg_json() -> Result<()> { let plan = display_plan()?; - let expected_pg_json = r#"[ - { - "Plan": { - "Expressions": [ - "employee_csv.id" - ], - "Node Type": "Projection", - "Output": [ - "id" - ], - "Plans": [ - { - "Condition": "employee_csv.state IN ()", - "Node Type": "Filter", - "Output": [ - "id", - "state" - ], - "Plans": [ - { - "Node Type": "Subquery", + assert_snapshot!(plan.display_pg_json(), @r#" + [ + { + "Plan": { + "Expressions": [ + "employee_csv.id" + ], + "Node Type": "Projection", "Output": [ - "state" + "id" ], "Plans": [ { - "Node Type": "TableScan", + "Condition": "employee_csv.state IN ()", + "Node Type": "Filter", "Output": [ + "id", "state" ], - "Plans": [], - "Relation Name": "employee_csv" + "Plans": [ + { + "Node Type": "Subquery", + "Output": [ + "state" + ], + "Plans": [ + { + "Node Type": "TableScan", + "Output": [ + "state" + ], + "Plans": [], + "Relation Name": "employee_csv" + } + ] + }, + { + "Node Type": "TableScan", + "Output": [ + "id", + "state" + ], + "Plans": [], + "Relation Name": "employee_csv" + } + ] } ] - }, - { - "Node Type": "TableScan", - "Output": [ - "id", - "state" - ], - "Plans": [], - "Relation Name": "employee_csv" } - ] - } - ] - } - } -]"#; - - let pg_json = format!("{}", plan.display_pg_json()); - - assert_eq!(expected_pg_json, pg_json); + } + ] + "#); Ok(()) } @@ -4200,17 +4236,16 @@ digraph { let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); - assert_eq!( - visitor.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - "post_visit Filter", - "post_visit Projection", - ] - ); + assert_debug_snapshot!(visitor.strings, @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + "post_visit Filter", + "post_visit Projection", + ] + "#); } #[derive(Debug, Default)] @@ -4276,9 +4311,14 @@ digraph { let res = 
plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec!["pre_visit Projection", "pre_visit Filter"] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + ] + "# ); } @@ -4292,14 +4332,16 @@ digraph { let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - ] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + ] + "# ); } @@ -4341,13 +4383,18 @@ digraph { }; let plan = test_plan(); let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); - assert_eq!( - "This feature is not implemented: Error in pre_visit", - res.strip_backtrace() + assert_snapshot!( + res.strip_backtrace(), + @"This feature is not implemented: Error in pre_visit" ); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec!["pre_visit Projection", "pre_visit Filter"] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + ] + "# ); } @@ -4359,18 +4406,20 @@ digraph { }; let plan = test_plan(); let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); - assert_eq!( - "This feature is not implemented: Error in post_visit", - res.strip_backtrace() + assert_snapshot!( + res.strip_backtrace(), + @"This feature is not implemented: Error in post_visit" ); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - ] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + ] + "# ); } @@ -4385,7 +4434,7 @@ digraph { })), empty_schema, ); - assert_eq!(p.err().unwrap().strip_backtrace(), "Error during planning: Projection has mismatch between number of expressions (1) and number of fields in schema (0)"); + assert_snapshot!(p.unwrap_err().strip_backtrace(), @"Error during planning: Projection has mismatch between number of expressions (1) and number of fields in schema (0)"); Ok(()) } @@ -4505,7 +4554,7 @@ digraph { let col = schema.field_names()[0].clone(); let filter = Filter::try_new( - Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)), None)), scan, ) .unwrap(); @@ -4572,11 +4621,12 @@ digraph { .data() .unwrap(); - let expected = "Explain\ - \n Filter: foo = Boolean(true)\ - \n TableScan: ?table?"; let actual = format!("{}", plan.display_indent()); - assert_eq!(expected.to_string(), actual) + assert_snapshot!(actual, @r" + Explain + Filter: foo = Boolean(true) + TableScan: ?table? 
+ ") } #[test] @@ -4631,12 +4681,14 @@ digraph { skip: None, fetch: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), input: Arc::clone(&input), }), LogicalPlan::Limit(Limit { skip: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), fetch: None, input: Arc::clone(&input), @@ -4644,9 +4696,11 @@ digraph { LogicalPlan::Limit(Limit { skip: Some(Box::new(Expr::Literal( ScalarValue::new_one(&DataType::UInt32).unwrap(), + None, ))), fetch: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), input, }), @@ -4964,4 +5018,374 @@ digraph { Ok(()) } + + #[test] + fn test_join_try_new() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let left_scan = table_scan(Some("t1"), &schema, None)?.build()?; + + let right_scan = table_scan(Some("t2"), &schema, None)?.build()?; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightSemi, + JoinType::RightAnti, + JoinType::LeftMark, + ]; + + for join_type in join_types { + let join = Join::try_new( + Arc::new(left_scan.clone()), + Arc::new(right_scan.clone()), + vec![(col("t1.a"), col("t2.a"))], + Some(col("t1.b").gt(col("t2.b"))), + join_type, + JoinConstraint::On, + false, + )?; + + match join_type { + JoinType::LeftSemi | JoinType::LeftAnti => { + assert_eq!(join.schema.fields().len(), 2); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from left table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from left table" + ); + } + JoinType::RightSemi | JoinType::RightAnti => { + assert_eq!(join.schema.fields().len(), 2); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from right table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from right table" + ); + } + JoinType::LeftMark => { + assert_eq!(join.schema.fields().len(), 3); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from left table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from left table" + ); + assert_eq!( + fields[2].name(), + "mark", + "Third field should be the mark column" + ); + + assert!(!fields[0].is_nullable()); + assert!(!fields[1].is_nullable()); + assert!(!fields[2].is_nullable()); + } + _ => { + assert_eq!(join.schema.fields().len(), 4); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from left table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from left table" + ); + assert_eq!( + fields[2].name(), + "a", + "Third field should be 'a' from right table" + ); + assert_eq!( + fields[3].name(), + "b", + "Fourth field should be 'b' from right table" + ); + + if join_type == JoinType::Left { + // Left side fields (first two) shouldn't be nullable + assert!(!fields[0].is_nullable()); + assert!(!fields[1].is_nullable()); + // Right side fields (third and fourth) should be nullable + assert!(fields[2].is_nullable()); + assert!(fields[3].is_nullable()); + } else if join_type == JoinType::Right { + // Left side fields (first two) should be nullable + assert!(fields[0].is_nullable()); + assert!(fields[1].is_nullable()); + 
// Right side fields (third and fourth) shouldn't be nullable + assert!(!fields[2].is_nullable()); + assert!(!fields[3].is_nullable()); + } else if join_type == JoinType::Full { + assert!(fields[0].is_nullable()); + assert!(fields[1].is_nullable()); + assert!(fields[2].is_nullable()); + assert!(fields[3].is_nullable()); + } + } + } + + assert_eq!(join.on, vec![(col("t1.a"), col("t2.a"))]); + assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b")))); + assert_eq!(join.join_type, join_type); + assert_eq!(join.join_constraint, JoinConstraint::On); + assert!(!join.null_equals_null); + } + + Ok(()) + } + + #[test] + fn test_join_try_new_with_using_constraint_and_overlapping_columns() -> Result<()> { + let left_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), // Common column in both tables + Field::new("name", DataType::Utf8, false), // Unique to left + Field::new("value", DataType::Int32, false), // Common column, different meaning + ]); + + let right_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), // Common column in both tables + Field::new("category", DataType::Utf8, false), // Unique to right + Field::new("value", DataType::Float64, true), // Common column, different meaning + ]); + + let left_plan = table_scan(Some("t1"), &left_schema, None)?.build()?; + + let right_plan = table_scan(Some("t2"), &right_schema, None)?.build()?; + + // Test 1: USING constraint with a common column + { + // In the logical plan, both copies of the `id` column are preserved + // The USING constraint is handled later during physical execution, where the common column appears once + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + None, + JoinType::Inner, + JoinConstraint::Using, + false, + )?; + + let fields = join.schema.fields(); + + assert_eq!(fields.len(), 6); + + assert_eq!( + fields[0].name(), + "id", + "First field should be 'id' from left table" + ); + assert_eq!( + fields[1].name(), + "name", + "Second field should be 'name' from left table" + ); + assert_eq!( + fields[2].name(), + "value", + "Third field should be 'value' from left table" + ); + assert_eq!( + fields[3].name(), + "id", + "Fourth field should be 'id' from right table" + ); + assert_eq!( + fields[4].name(), + "category", + "Fifth field should be 'category' from right table" + ); + assert_eq!( + fields[5].name(), + "value", + "Sixth field should be 'value' from right table" + ); + + assert_eq!(join.join_constraint, JoinConstraint::Using); + } + + // Test 2: Complex join condition with expressions + { + // Complex condition: join on id equality AND where left.value < right.value + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], // Equijoin condition + Some(col("t1.value").lt(col("t2.value"))), // Non-equi filter condition + JoinType::Inner, + JoinConstraint::On, + false, + )?; + + let fields = join.schema.fields(); + assert_eq!(fields.len(), 6); + + assert_eq!( + fields[0].name(), + "id", + "First field should be 'id' from left table" + ); + assert_eq!( + fields[1].name(), + "name", + "Second field should be 'name' from left table" + ); + assert_eq!( + fields[2].name(), + "value", + "Third field should be 'value' from left table" + ); + assert_eq!( + fields[3].name(), + "id", + "Fourth field should be 'id' from right table" + ); + assert_eq!( + fields[4].name(), + "category", + "Fifth field should be 'category' from right table" + ); + 
assert_eq!( + fields[5].name(), + "value", + "Sixth field should be 'value' from right table" + ); + + assert_eq!(join.filter, Some(col("t1.value").lt(col("t2.value")))); + } + + // Test 3: Join with null equality behavior set to true + { + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + None, + JoinType::Inner, + JoinConstraint::On, + true, + )?; + + assert!(join.null_equals_null); + } + + Ok(()) + } + + #[test] + fn test_join_try_new_schema_validation() -> Result<()> { + let left_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("value", DataType::Float64, true), + ]); + + let right_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("category", DataType::Utf8, true), + Field::new("code", DataType::Int16, false), + ]); + + let left_plan = table_scan(Some("t1"), &left_schema, None)?.build()?; + + let right_plan = table_scan(Some("t2"), &right_schema, None)?.build()?; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + ]; + + for join_type in join_types { + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + Some(col("t1.value").gt(lit(5.0))), + join_type, + JoinConstraint::On, + false, + )?; + + let fields = join.schema.fields(); + assert_eq!(fields.len(), 6, "Expected 6 fields for {join_type:?} join"); + + for (i, field) in fields.iter().enumerate() { + let expected_nullable = match (i, &join_type) { + // Left table fields (indices 0, 1, 2) + (0, JoinType::Right | JoinType::Full) => true, // id becomes nullable in RIGHT/FULL + (1, JoinType::Right | JoinType::Full) => true, // name becomes nullable in RIGHT/FULL + (2, _) => true, // value is already nullable + + // Right table fields (indices 3, 4, 5) + (3, JoinType::Left | JoinType::Full) => true, // id becomes nullable in LEFT/FULL + (4, _) => true, // category is already nullable + (5, JoinType::Left | JoinType::Full) => true, // code becomes nullable in LEFT/FULL + + _ => false, + }; + + assert_eq!( + field.is_nullable(), + expected_nullable, + "Field {} ({}) nullability incorrect for {:?} join", + i, + field.name(), + join_type + ); + } + } + + let using_join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + None, + JoinType::Inner, + JoinConstraint::Using, + false, + )?; + + assert_eq!( + using_join.schema.fields().len(), + 6, + "USING join should have all fields" + ); + assert_eq!(using_join.join_constraint, JoinConstraint::Using); + + Ok(()) + } } diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index 82acebee3de6..72eb6b39bb47 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -110,7 +110,7 @@ impl Statement { Statement::Prepare(Prepare { name, data_types, .. }) => { - write!(f, "Prepare: {name:?} {data_types:?} ") + write!(f, "Prepare: {name:?} {data_types:?}") } Statement::Execute(Execute { name, parameters, .. 
@@ -123,7 +123,7 @@ impl Statement { ) } Statement::Deallocate(Deallocate { name }) => { - write!(f, "Deallocate: {}", name) + write!(f, "Deallocate: {name}") } } } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 7f6e1e025387..2a290e692a7b 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -85,17 +85,9 @@ impl TreeNode for LogicalPlan { schema, }) }), - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) - }), + LogicalPlan::Filter(Filter { predicate, input }) => input + .map_elements(f)? + .update_data(|input| LogicalPlan::Filter(Filter { predicate, input })), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -509,17 +501,10 @@ impl LogicalPlan { LogicalPlan::Values(Values { schema, values }) => values .map_elements(f)? .update_data(|values| LogicalPlan::Values(Values { schema, values })), - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => f(predicate)?.update_data(|predicate| { - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) - }), + LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? + .update_data(|predicate| { + LogicalPlan::Filter(Filter { predicate, input }) + }), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index a2ed0592efdb..4c03f919312e 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -312,6 +312,7 @@ pub struct RawWindowExpr { /// Result of planning a raw expr with [`ExprPlanner`] #[derive(Debug, Clone)] +#[allow(clippy::large_enum_variant)] pub enum PlannerResult { /// The raw expression was successfully planned as a new [`Expr`] Planned(Expr), diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index 467ce8bf53e2..411dbbdc4034 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -110,6 +110,7 @@ impl SimplifyInfo for SimplifyContext<'_> { /// Was the expression simplified? 
#[derive(Debug)] +#[allow(clippy::large_enum_variant)] pub enum ExprSimplifyResult { /// The function call was simplified to an entirely new Expr Simplified(Expr), diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index a753f4c376c6..673908a4d7e7 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -22,7 +22,7 @@ use std::any::Any; use arrow::datatypes::{ - DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result}; @@ -175,7 +175,7 @@ impl AggregateUDFImpl for Sum { unreachable!("stub should not have accumulate()") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unreachable!("stub should not have state_fields()") } @@ -254,7 +254,7 @@ impl AggregateUDFImpl for Count { false } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -336,7 +336,7 @@ impl AggregateUDFImpl for Min { Ok(DataType::Int64) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -421,7 +421,7 @@ impl AggregateUDFImpl for Max { Ok(DataType::Int64) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -491,7 +491,7 @@ impl AggregateUDFImpl for Avg { not_impl_err!("no impl for stub") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } fn aliases(&self) -> &[String] { diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f20dab7e165f..f953aec5a1e3 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -73,7 +73,7 @@ impl TreeNode for Expr { // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } @@ -92,14 +92,16 @@ impl TreeNode for Expr { (expr, when_then_expr, else_expr).apply_ref_elements(f), Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) => (args, filter, order_by).apply_ref_elements(f), - Expr::WindowFunction(WindowFunction { - params : WindowFunctionParams { + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { args, partition_by, order_by, - ..}, ..}) => { + .. + } = &window_fun.as_ref().params; (args, partition_by, order_by).apply_ref_elements(f) } + Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } @@ -124,7 +126,7 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) => Transformed::no(self), + | Expr::Literal(_, _) => Transformed::no(self), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), @@ -230,27 +232,30 @@ impl TreeNode for Expr { ))) })? 
} - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = *window_fun; + (args, partition_by, order_by).map_elements(f)?.update_data( + |(new_args, new_partition_by, new_order_by)| { + Expr::from(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() }, - }) => (args, partition_by, order_by).map_elements(f)?.update_data( - |(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) - .partition_by(new_partition_by) - .order_by(new_order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() - }, - ), + ) + } Expr::AggregateFunction(AggregateFunction { func, params: diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 3b34718062eb..763a4e6539fd 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -15,19 +15,22 @@ // specific language governing permissions and limitations // under the License. -use super::binary::{binary_numeric_coercion, comparison_coercion}; +use super::binary::binary_numeric_coercion; use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; +use arrow::datatypes::FieldRef; use arrow::{ compute::can_cast_types, - datatypes::{DataType, Field, TimeUnit}, + datatypes::{DataType, TimeUnit}, }; use datafusion_common::types::LogicalType; -use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion}; +use datafusion_common::utils::{ + base_type, coerced_fixed_size_list_to_list, ListCoercion, +}; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType, - utils::list_ndims, Result, + exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims, Result, }; use datafusion_expr_common::signature::ArrayFunctionArgument; +use datafusion_expr_common::type_coercion::binary::type_union_resolution; use datafusion_expr_common::{ signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, type_coercion::binary::comparison_coercion_numeric, @@ -75,19 +78,19 @@ pub fn data_types_with_scalar_udf( /// Performs type coercion for aggregate function arguments. /// -/// Returns the data types to which each argument must be coerced to +/// Returns the fields to which each argument must be coerced to /// match `signature`. /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. 
-pub fn data_types_with_aggregate_udf( - current_types: &[DataType], +pub fn fields_with_aggregate_udf( + current_fields: &[FieldRef], func: &AggregateUDF, -) -> Result> { +) -> Result> { let signature = func.signature(); let type_signature = &signature.type_signature; - if current_types.is_empty() && type_signature != &TypeSignature::UserDefined { + if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined { if type_signature.supports_zero_argument() { return Ok(vec![]); } else if type_signature.used_to_support_zero_arguments() { @@ -97,17 +100,32 @@ pub fn data_types_with_aggregate_udf( return plan_err!("'{}' does not support zero arguments", func.name()); } } + let current_types = current_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); let valid_types = - get_valid_types_with_aggregate_udf(type_signature, current_types, func)?; + get_valid_types_with_aggregate_udf(type_signature, ¤t_types, func)?; if valid_types .iter() - .any(|data_type| data_type == current_types) + .any(|data_type| data_type == ¤t_types) { - return Ok(current_types.to_vec()); + return Ok(current_fields.to_vec()); } - try_coerce_types(func.name(), valid_types, current_types, type_signature) + let updated_types = + try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?; + + Ok(current_fields + .iter() + .zip(updated_types) + .map(|(current_field, new_type)| { + current_field.as_ref().clone().with_data_type(new_type) + }) + .map(Arc::new) + .collect()) } /// Performs type coercion for window function arguments. @@ -117,14 +135,14 @@ pub fn data_types_with_aggregate_udf( /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. -pub fn data_types_with_window_udf( - current_types: &[DataType], +pub fn fields_with_window_udf( + current_fields: &[FieldRef], func: &WindowUDF, -) -> Result> { +) -> Result> { let signature = func.signature(); let type_signature = &signature.type_signature; - if current_types.is_empty() && type_signature != &TypeSignature::UserDefined { + if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined { if type_signature.supports_zero_argument() { return Ok(vec![]); } else if type_signature.used_to_support_zero_arguments() { @@ -135,16 +153,31 @@ pub fn data_types_with_window_udf( } } + let current_types = current_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); let valid_types = - get_valid_types_with_window_udf(type_signature, current_types, func)?; + get_valid_types_with_window_udf(type_signature, ¤t_types, func)?; if valid_types .iter() - .any(|data_type| data_type == current_types) + .any(|data_type| data_type == ¤t_types) { - return Ok(current_types.to_vec()); + return Ok(current_fields.to_vec()); } - try_coerce_types(func.name(), valid_types, current_types, type_signature) + let updated_types = + try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?; + + Ok(current_fields + .iter() + .zip(updated_types) + .map(|(current_field, new_type)| { + current_field.as_ref().clone().with_data_type(new_type) + }) + .map(Arc::new) + .collect()) } /// Performs type coercion for function arguments. 
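For reviewers migrating call sites of the removed `data_types_with_aggregate_udf` / `data_types_with_window_udf`, here is a minimal sketch (not part of this patch; the module path and the helper name `coerce_arg_types` are assumptions) of how the new field-based entry points could be driven from code that still only has `DataType`s on hand:

```rust
// Hypothetical adapter around the new field-based coercion API shown above: wrap each
// DataType in a FieldRef, coerce against the AggregateUDF signature, then read the
// coerced DataTypes back out. Real call sites that already hold FieldRefs keep their
// nullability and metadata, which is the motivation for the signature change.
// (fields_with_window_udf has the same shape for WindowUDFs.)
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::Result;
use datafusion_expr::type_coercion::functions::fields_with_aggregate_udf;
use datafusion_expr::AggregateUDF;

fn coerce_arg_types(udf: &AggregateUDF, arg_types: &[DataType]) -> Result<Vec<DataType>> {
    let arg_fields: Vec<FieldRef> = arg_types
        .iter()
        .map(|dt| Arc::new(Field::new("arg", dt.clone(), true)))
        .collect();
    let coerced = fields_with_aggregate_udf(&arg_fields, udf)?;
    Ok(coerced.iter().map(|f| f.data_type().clone()).collect())
}
```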
@@ -364,98 +397,67 @@ fn get_valid_types( return Ok(vec![vec![]]); } - let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| { - if *arg == ArrayFunctionArgument::Array { - Some(idx) - } else { - None - } - }); - let Some(array_idx) = array_idx else { - return Err(internal_datafusion_err!("Function '{function_name}' expected at least one argument array argument")); - }; - let Some(array_type) = array(¤t_types[array_idx]) else { - return Ok(vec![vec![]]); - }; - - // We need to find the coerced base type, mainly for cases like: - // `array_append(List(null), i64)` -> `List(i64)` - let mut new_base_type = datafusion_common::utils::base_type(&array_type); - for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { - match argument_type { - ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => { - new_base_type = - coerce_array_types(function_name, current_type, &new_base_type)?; + let mut large_list = false; + let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList); + let mut list_sizes = Vec::with_capacity(arguments.len()); + let mut element_types = Vec::with_capacity(arguments.len()); + for (argument, current_type) in arguments.iter().zip(current_types.iter()) { + match argument { + ArrayFunctionArgument::Index | ArrayFunctionArgument::String => (), + ArrayFunctionArgument::Element => { + element_types.push(current_type.clone()) } - ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {} + ArrayFunctionArgument::Array => match current_type { + DataType::Null => element_types.push(DataType::Null), + DataType::List(field) => { + element_types.push(field.data_type().clone()); + fixed_size = false; + } + DataType::LargeList(field) => { + element_types.push(field.data_type().clone()); + large_list = true; + fixed_size = false; + } + DataType::FixedSizeList(field, size) => { + element_types.push(field.data_type().clone()); + list_sizes.push(*size) + } + arg_type => { + plan_err!("{function_name} does not support type {arg_type}")? 
+ } + }, } } - let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only( - &array_type, - &new_base_type, - array_coercion, - ); - let new_elem_type = match new_array_type { - DataType::List(ref field) - | DataType::LargeList(ref field) - | DataType::FixedSizeList(ref field, _) => field.data_type(), - _ => return Ok(vec![vec![]]), + let Some(element_type) = type_union_resolution(&element_types) else { + return Ok(vec![vec![]]); }; - let mut valid_types = Vec::with_capacity(arguments.len()); - for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { - let valid_type = match argument_type { - ArrayFunctionArgument::Element => new_elem_type.clone(), + if !fixed_size { + list_sizes.clear() + } + + let mut list_sizes = list_sizes.into_iter(); + let valid_types = arguments.iter().zip(current_types.iter()).map( + |(argument_type, current_type)| match argument_type { ArrayFunctionArgument::Index => DataType::Int64, ArrayFunctionArgument::String => DataType::Utf8, + ArrayFunctionArgument::Element => element_type.clone(), ArrayFunctionArgument::Array => { - let Some(current_type) = array(current_type) else { - return Ok(vec![vec![]]); - }; - let new_type = - datafusion_common::utils::coerced_type_with_base_type_only( - ¤t_type, - &new_base_type, - array_coercion, - ); - // All array arguments must be coercible to the same type - if new_type != new_array_type { - return Ok(vec![vec![]]); + if current_type.is_null() { + DataType::Null + } else if large_list { + DataType::new_large_list(element_type.clone(), true) + } else if let Some(size) = list_sizes.next() { + DataType::new_fixed_size_list(element_type.clone(), size, true) + } else { + DataType::new_list(element_type.clone(), true) } - new_type } - }; - valid_types.push(valid_type); - } - - Ok(vec![valid_types]) - } - - fn array(array_type: &DataType) -> Option { - match array_type { - DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()), - DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))), - DataType::Null => Some(DataType::List(Arc::new(Field::new_list_field( - DataType::Int64, - true, - )))), - _ => None, - } - } + }, + ); - fn coerce_array_types( - function_name: &str, - current_type: &DataType, - base_type: &DataType, - ) -> Result { - let current_base_type = datafusion_common::utils::base_type(current_type); - let new_base_type = comparison_coercion(base_type, ¤t_base_type); - new_base_type.ok_or_else(|| { - internal_datafusion_err!( - "Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}" - ) - }) + Ok(vec![valid_types.collect()]) } fn recursive_array(array_type: &DataType) -> Option { @@ -800,7 +802,7 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { /// /// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32. /// -/// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion. +/// Unlike [crate::binary::comparison_coercion], the coerced type is usually `wider` for lossless conversion. fn coerced_from<'a>( type_into: &'a DataType, type_from: &'a DataType, @@ -867,7 +869,7 @@ fn coerced_from<'a>( // Only accept list and largelist with the same number of dimensions unless the type is Null. 
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this (List(_) | LargeList(_), _) - if datafusion_common::utils::base_type(type_from).eq(&Null) + if base_type(type_from).is_null() || list_ndims(type_from) == list_ndims(type_into) => { Some(type_into.clone()) @@ -906,7 +908,6 @@ fn coerced_from<'a>( #[cfg(test)] mod tests { - use crate::Volatility; use super::*; @@ -1193,4 +1194,155 @@ mod tests { Some(type_into.clone()) ); } + + #[test] + fn test_get_valid_types_array_and_array() -> Result<()> { + let function = "array_and_array"; + let signature = Signature::arrays( + 2, + Some(ListCoercion::FixedSizedListToList), + Volatility::Immutable, + ); + + let data_types = vec![ + DataType::new_list(DataType::Int32, true), + DataType::new_large_list(DataType::Float64, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Float64, true), + DataType::new_large_list(DataType::Float64, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_fixed_size_list(DataType::Int32, 5, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Int64, true), + DataType::new_list(DataType::Int64, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Null, 3, true), + DataType::new_large_list(DataType::Utf8, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Utf8, true), + DataType::new_large_list(DataType::Utf8, true), + ]] + ); + + Ok(()) + } + + #[test] + fn test_get_valid_types_array_and_element() -> Result<()> { + let function = "array_and_element"; + let signature = Signature::array_and_element(Volatility::Immutable); + + let data_types = + vec![DataType::new_list(DataType::Int32, true), DataType::Float64]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Float64, true), + DataType::Float64, + ]] + ); + + let data_types = vec![ + DataType::new_large_list(DataType::Int32, true), + DataType::Null, + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Int32, true), + DataType::Int32, + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Null, 3, true), + DataType::Utf8, + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Utf8, true), + DataType::Utf8, + ]] + ); + + Ok(()) + } + + #[test] + fn test_get_valid_types_element_and_array() -> Result<()> { + let function = "element_and_array"; + let signature = Signature::element_and_array(Volatility::Immutable); + + let data_types = vec![ + DataType::new_large_list(DataType::Null, false), + DataType::new_list(DataType::new_list(DataType::Int64, true), true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Int64, true), + DataType::new_list(DataType::new_large_list(DataType::Int64, true), true), + ]] + ); + + Ok(()) + } + + #[test] + fn test_get_valid_types_fixed_size_arrays() -> Result<()> { + let function = "fixed_size_arrays"; + let signature = Signature::arrays(2, None, Volatility::Immutable); + + 
let data_types = vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_fixed_size_list(DataType::Int32, 5, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_fixed_size_list(DataType::Int64, 5, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_list(DataType::Int32, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Int64, true), + DataType::new_list(DataType::Int64, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Utf8, 3, true), + DataType::new_list(DataType::new_list(DataType::Int32, true), true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![]] + ); + + Ok(()) + } } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b75e8fd3cd3c..d1bf45ce2fe8 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -24,7 +24,7 @@ use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use std::vec; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -224,6 +224,13 @@ impl AggregateUDF { self.inner.return_type(args) } + /// Return the field of the function given its input fields + /// + /// See [`AggregateUDFImpl::return_field`] for more details. + pub fn return_field(&self, args: &[FieldRef]) -> Result { + self.inner.return_field(args) + } + /// Return an accumulator the given aggregate, given its return datatype pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { self.inner.accumulator(acc_args) @@ -234,7 +241,7 @@ impl AggregateUDF { /// for more details. /// /// This is used to support multi-phase aggregations - pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { + pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.inner.state_fields(args) } @@ -315,6 +322,16 @@ impl AggregateUDF { self.inner.default_value(data_type) } + /// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details. + pub fn supports_null_handling_clause(&self) -> bool { + self.inner.supports_null_handling_clause() + } + + /// See [`AggregateUDFImpl::is_ordered_set_aggregate`] for more details. + pub fn is_ordered_set_aggregate(&self) -> bool { + self.inner.is_ordered_set_aggregate() + } + /// Returns the documentation for this Aggregate UDF. /// /// Documentation can be accessed programmatically as well as @@ -346,8 +363,8 @@ where /// # Basic Example /// ``` /// # use std::any::Any; -/// # use std::sync::LazyLock; -/// # use arrow::datatypes::DataType; +/// # use std::sync::{Arc, LazyLock}; +/// # use arrow::datatypes::{DataType, FieldRef}; /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; @@ -391,10 +408,10 @@ where /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. 
/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } -/// fn state_fields(&self, args: StateFieldsArgs) -> Result> { +/// fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// Ok(vec![ -/// Field::new("value", args.return_type.clone(), true), -/// Field::new("ordering", DataType::UInt32, true) +/// Arc::new(args.return_field.as_ref().clone().with_name("value")), +/// Arc::new(Field::new("ordering", DataType::UInt32, true)) /// ]) /// } /// fn documentation(&self) -> Option<&Documentation> { @@ -432,6 +449,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { null_treatment, } = params; + // exclude the first function argument(= column) in ordered set aggregate function, + // because it is duplicated with the WITHIN GROUP clause in schema name. + let args = if self.is_ordered_set_aggregate() { + &args[1..] + } else { + &args[..] + }; + let mut schema_name = String::new(); schema_name.write_fmt(format_args!( @@ -442,7 +467,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ))?; if let Some(null_treatment) = null_treatment { - schema_name.write_fmt(format_args!(" {}", null_treatment))?; + schema_name.write_fmt(format_args!(" {null_treatment}"))?; } if let Some(filter) = filter { @@ -450,8 +475,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { }; if let Some(order_by) = order_by { + let clause = match self.is_ordered_set_aggregate() { + true => "WITHIN GROUP", + false => "ORDER BY", + }; + schema_name.write_fmt(format_args!( - " ORDER BY [{}]", + " {} [{}]", + clause, schema_name_from_sorts(order_by)? ))?; }; @@ -481,7 +512,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ))?; if let Some(null_treatment) = null_treatment { - schema_name.write_fmt(format_args!(" {}", null_treatment))?; + schema_name.write_fmt(format_args!(" {null_treatment}"))?; } if let Some(filter) = filter { @@ -525,7 +556,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ))?; if let Some(null_treatment) = null_treatment { - schema_name.write_fmt(format_args!(" {}", null_treatment))?; + schema_name.write_fmt(format_args!(" {null_treatment}"))?; } if !partition_by.is_empty() { @@ -572,7 +603,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ))?; if let Some(nt) = null_treatment { - display_name.write_fmt(format_args!(" {}", nt))?; + display_name.write_fmt(format_args!(" {nt}"))?; } if let Some(fe) = filter { display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?; @@ -619,7 +650,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ))?; if let Some(null_treatment) = null_treatment { - display_name.write_fmt(format_args!(" {}", null_treatment))?; + display_name.write_fmt(format_args!(" {null_treatment}"))?; } if !partition_by.is_empty() { @@ -650,6 +681,35 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// the arguments fn return_type(&self, arg_types: &[DataType]) -> Result; + /// What type will be returned by this function, given the arguments? + /// + /// By default, this function calls [`Self::return_type`] with the + /// types of each argument. + /// + /// # Notes + /// + /// Most UDFs should implement [`Self::return_type`] and not this + /// function as the output type for most functions only depends on the types + /// of their inputs (e.g. `sum(f64)` is always `f64`). + /// + /// This function can be used for more advanced cases such as: + /// + /// 1. specifying nullability + /// 2. return types based on the **values** of the arguments (rather than + /// their **types**. + /// 3. 
return types based on metadata within the fields of the inputs + fn return_field(&self, arg_fields: &[FieldRef]) -> Result { + let arg_types: Vec<_> = + arg_fields.iter().map(|f| f.data_type()).cloned().collect(); + let data_type = self.return_type(&arg_types)?; + + Ok(Arc::new(Field::new( + self.name(), + data_type, + self.is_nullable(), + ))) + } + /// Whether the aggregate function is nullable. /// /// Nullable means that the function could return `null` for any inputs. @@ -688,15 +748,16 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// The name of the fields must be unique within the query and thus should /// be derived from `name`. See [`format_state_name`] for a utility function /// to generate a unique name. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let fields = vec![Field::new( - format_state_name(args.name, "value"), - args.return_type.clone(), - true, - )]; + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let fields = vec![args + .return_field + .as_ref() + .clone() + .with_name(format_state_name(args.name, "value"))]; Ok(fields .into_iter() + .map(Arc::new) .chain(args.ordering_fields.to_vec()) .collect()) } @@ -891,6 +952,18 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ScalarValue::try_from(data_type) } + /// If this function supports `[IGNORE NULLS | RESPECT NULLS]` clause, return true + /// If the function does not, return false + fn supports_null_handling_clause(&self) -> bool { + true + } + + /// If this function is ordered-set aggregate function, return true + /// If the function is not, return false + fn is_ordered_set_aggregate(&self) -> bool { + false + } + /// Returns the documentation for this Aggregate UDF. /// /// Documentation can be accessed programmatically as well as @@ -978,7 +1051,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { &self.aliases } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.inner.state_fields(args) } @@ -1111,7 +1184,7 @@ pub enum SetMonotonicity { #[cfg(test)] mod test { use crate::{AggregateUDF, AggregateUDFImpl}; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::signature::{Signature, Volatility}; @@ -1157,7 +1230,7 @@ mod test { ) -> Result> { unimplemented!() } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!() } } @@ -1197,7 +1270,7 @@ mod test { ) -> Result> { unimplemented!() } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!() } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 9b2400774a3d..816929a1fba1 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -21,7 +21,7 @@ use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; @@ -34,7 +34,7 @@ use std::sync::Arc; /// /// A scalar function produces a 
single row output for each row of input. This /// struct contains the information DataFusion needs to plan and invoke -/// functions you supply such name, type signature, return type, and actual +/// functions you supply such as name, type signature, return type, and actual /// implementation. /// /// 1. For simple use cases, use [`create_udf`] (examples in [`simple_udf.rs`]). @@ -42,11 +42,11 @@ use std::sync::Arc; /// 2. For advanced use cases, use [`ScalarUDFImpl`] which provides full API /// access (examples in [`advanced_udf.rs`]). /// -/// See [`Self::call`] to invoke a `ScalarUDF` with arguments. +/// See [`Self::call`] to create an `Expr` which invokes a `ScalarUDF` with arguments. /// /// # API Note /// -/// This is a separate struct from `ScalarUDFImpl` to maintain backwards +/// This is a separate struct from [`ScalarUDFImpl`] to maintain backwards /// compatibility with the older API. /// /// [`create_udf`]: crate::expr_fn::create_udf @@ -170,7 +170,7 @@ impl ScalarUDF { /// /// # Notes /// - /// If a function implement [`ScalarUDFImpl::return_type_from_args`], + /// If a function implement [`ScalarUDFImpl::return_field_from_args`], /// its [`ScalarUDFImpl::return_type`] should raise an error. /// /// See [`ScalarUDFImpl::return_type`] for more details. @@ -180,9 +180,9 @@ impl ScalarUDF { /// Return the datatype this function returns given the input argument types. /// - /// See [`ScalarUDFImpl::return_type_from_args`] for more details. - pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - self.inner.return_type_from_args(args) + /// See [`ScalarUDFImpl::return_field_from_args`] for more details. + pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) } /// Do the function rewrite @@ -293,14 +293,25 @@ where /// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a /// scalar function. -pub struct ScalarFunctionArgs<'a> { +pub struct ScalarFunctionArgs { /// The evaluated arguments to the function pub args: Vec, + /// Field associated with each arg, if it exists + pub arg_fields: Vec, /// The number of rows in record batch being evaluated pub number_rows: usize, - /// The return type of the scalar function returned (from `return_type` or `return_type_from_args`) - /// when creating the physical expression from the logical expression - pub return_type: &'a DataType, + /// The return field of the scalar function returned (from `return_type` + /// or `return_field_from_args`) when creating the physical expression + /// from the logical expression + pub return_field: FieldRef, +} + +impl ScalarFunctionArgs { + /// The return type of the function. See [`Self::return_field`] for more + /// details. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } } /// Information about arguments passed to the function @@ -309,64 +320,18 @@ pub struct ScalarFunctionArgs<'a> { /// such as the type of the arguments, any scalar arguments and if the /// arguments can (ever) be null /// -/// See [`ScalarUDFImpl::return_type_from_args`] for more information +/// See [`ScalarUDFImpl::return_field_from_args`] for more information #[derive(Debug)] -pub struct ReturnTypeArgs<'a> { +pub struct ReturnFieldArgs<'a> { /// The data types of the arguments to the function - pub arg_types: &'a [DataType], - /// Is argument `i` to the function a scalar (constant) + pub arg_fields: &'a [FieldRef], + /// Is argument `i` to the function a scalar (constant)? 
/// - /// If argument `i` is not a scalar, it will be None + /// If the argument `i` is not a scalar, it will be None /// /// For example, if a function is called like `my_function(column_a, 5)` /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], - /// Can argument `i` (ever) null? - pub nullables: &'a [bool], -} - -/// Return metadata for this function. -/// -/// See [`ScalarUDFImpl::return_type_from_args`] for more information -#[derive(Debug)] -pub struct ReturnInfo { - return_type: DataType, - nullable: bool, -} - -impl ReturnInfo { - pub fn new(return_type: DataType, nullable: bool) -> Self { - Self { - return_type, - nullable, - } - } - - pub fn new_nullable(return_type: DataType) -> Self { - Self { - return_type, - nullable: true, - } - } - - pub fn new_non_nullable(return_type: DataType) -> Self { - Self { - return_type, - nullable: false, - } - } - - pub fn return_type(&self) -> &DataType { - &self.return_type - } - - pub fn nullable(&self) -> bool { - self.nullable - } - - pub fn into_parts(self) -> (DataType, bool) { - (self.return_type, self.nullable) - } } /// Trait for implementing user defined scalar functions. @@ -480,7 +445,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// # Notes /// - /// If you provide an implementation for [`Self::return_type_from_args`], + /// If you provide an implementation for [`Self::return_field_from_args`], /// DataFusion will not call `return_type` (this function). In such cases /// is recommended to return [`DataFusionError::Internal`]. /// @@ -494,9 +459,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// # Notes /// - /// Most UDFs should implement [`Self::return_type`] and not this - /// function as the output type for most functions only depends on the types - /// of their inputs (e.g. `sqrt(f32)` is always `f32`). + /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient, + /// as the result type is typically a deterministic function of the input types + /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly + /// is generally unnecessary unless the return type depends on runtime values. /// /// This function can be used for more advanced cases such as: /// @@ -504,6 +470,27 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// 2. return types based on the **values** of the arguments (rather than /// their **types**. /// + /// # Example creating `Field` + /// + /// Note the name of the [`Field`] is ignored, except for structured types such as + /// `DataType::Struct`. + /// + /// ```rust + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field, FieldRef}; + /// # use datafusion_common::Result; + /// # use datafusion_expr::ReturnFieldArgs; + /// # struct Example{} + /// # impl Example { + /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + /// // report output is only nullable if any one of the arguments are nullable + /// let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true)); + /// Ok(field) + /// } + /// # } + /// ``` + /// /// # Output Type based on Values /// /// For example, the following two function calls get the same argument @@ -518,14 +505,20 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// This function **must** consistently return the same type for the same /// logical input even if the input is simplified (e.g. 
it must return the same /// value for `('foo' | 'bar')` as it does for ('foobar'). - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let return_type = self.return_type(args.arg_types)?; - Ok(ReturnInfo::new_nullable(return_type)) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let data_types = args + .arg_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let return_type = self.return_type(&data_types)?; + Ok(Arc::new(Field::new(self.name(), return_type, true))) } #[deprecated( since = "45.0.0", - note = "Use `return_type_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_type_from_args`, you might have error" + note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error" )] fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { true @@ -584,13 +577,15 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { } /// Returns true if some of this `exprs` subexpressions may not be evaluated - /// and thus any side effects (like divide by zero) may not be encountered - /// Setting this to true prevents certain optimizations such as common subexpression elimination + /// and thus any side effects (like divide by zero) may not be encountered. + /// + /// Setting this to true prevents certain optimizations such as common + /// subexpression elimination fn short_circuits(&self) -> bool { false } - /// Computes the output interval for a [`ScalarUDFImpl`], given the input + /// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input /// intervals. /// /// # Parameters @@ -606,9 +601,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { Interval::make_unbounded(&DataType::Null) } - /// Updates bounds for child expressions, given a known interval for this - /// function. This is used to propagate constraints down through an expression - /// tree. + /// Updates bounds for child expressions, given a known [`Interval`]s for this + /// function. + /// + /// This function is used to propagate constraints down through an + /// expression tree. /// /// # Parameters /// @@ -657,20 +654,25 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { } } - /// Whether the function preserves lexicographical ordering based on the input ordering + /// Returns true if the function preserves lexicographical ordering based on + /// the input ordering. + /// + /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not. fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result { Ok(false) } /// Coerce arguments of a function call to types that the function can evaluate. /// - /// This function is only called if [`ScalarUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most - /// UDFs should return one of the other variants of `TypeSignature` which handle common - /// cases + /// This function is only called if [`ScalarUDFImpl::signature`] returns + /// [`crate::TypeSignature::UserDefined`]. Most UDFs should return one of + /// the other variants of [`TypeSignature`] which handle common cases. 
/// /// See the [type coercion module](crate::type_coercion) /// documentation for more details on type coercion /// + /// [`TypeSignature`]: crate::TypeSignature + /// /// For example, if your function requires a floating point arguments, but the user calls /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]` /// to ensure the argument is converted to `1::double` @@ -714,8 +716,8 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// Returns the documentation for this Scalar UDF. /// - /// Documentation can be accessed programmatically as well as - /// generating publicly facing documentation. + /// Documentation can be accessed programmatically as well as generating + /// publicly facing documentation. fn documentation(&self) -> Option<&Documentation> { None } @@ -765,18 +767,18 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type(arg_types) } - fn aliases(&self) -> &[String] { - &self.aliases - } - - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - self.inner.return_type_from_args(args) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { self.inner.invoke_with_args(args) } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn simplify( &self, args: Vec, diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 4da63d7955f5..155de232285e 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -26,7 +26,7 @@ use std::{ sync::Arc, }; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, FieldRef}; use crate::expr::WindowFunction; use crate::{ @@ -133,7 +133,7 @@ impl WindowUDF { pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::from(WindowFunction::new(fun, args)) } /// Returns this function's name @@ -179,7 +179,7 @@ impl WindowUDF { /// Returns the field of the final result of evaluating this window function. /// /// See [`WindowUDFImpl::field`] for more details. - pub fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + pub fn field(&self, field_args: WindowUDFFieldArgs) -> Result { self.inner.field(field_args) } @@ -236,7 +236,7 @@ where /// ``` /// # use std::any::Any; /// # use std::sync::LazyLock; -/// # use arrow::datatypes::{DataType, Field}; +/// # use arrow::datatypes::{DataType, Field, FieldRef}; /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt, Documentation}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; @@ -279,9 +279,9 @@ where /// ) -> Result> { /// unimplemented!() /// } -/// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { -/// if let Some(DataType::Int32) = field_args.get_input_type(0) { -/// Ok(Field::new(field_args.name(), DataType::Int32, false)) +/// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { +/// if let Some(DataType::Int32) = field_args.get_input_field(0).map(|f| f.data_type().clone()) { +/// Ok(Field::new(field_args.name(), DataType::Int32, false).into()) /// } else { /// plan_err!("smooth_it only accepts Int32 arguments") /// } @@ -386,12 +386,12 @@ pub trait WindowUDFImpl: Debug + Send + Sync { hasher.finish() } - /// The [`Field`] of the final result of evaluating this window function. 
+ /// The [`FieldRef`] of the final result of evaluating this window function. /// /// Call `field_args.name()` to get the fully qualified name for defining - /// the [`Field`]. For a complete example see the implementation in the + /// the [`FieldRef`]. For a complete example see the implementation in the /// [Basic Example](WindowUDFImpl#basic-example) section. - fn field(&self, field_args: WindowUDFFieldArgs) -> Result; + fn field(&self, field_args: WindowUDFFieldArgs) -> Result; /// Allows the window UDF to define a custom result ordering. /// @@ -537,7 +537,7 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { hasher.finish() } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { self.inner.field(field_args) } @@ -588,7 +588,7 @@ pub mod window_doc_sections { #[cfg(test)] mod test { use crate::{PartitionEvaluator, WindowUDF, WindowUDFImpl}; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::Result; use datafusion_expr_common::signature::{Signature, Volatility}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -630,7 +630,7 @@ mod test { ) -> Result> { unimplemented!() } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { unimplemented!() } } @@ -669,7 +669,7 @@ mod test { ) -> Result> { unimplemented!() } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { unimplemented!() } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 552ce1502d46..b7851e530099 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -21,7 +21,7 @@ use std::cmp::Ordering; use std::collections::{BTreeSet, HashSet}; use std::sync::Arc; -use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction, WindowFunctionParams}; +use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams}; use crate::expr_rewriter::strip_outer_reference; use crate::{ and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, @@ -276,7 +276,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { Expr::Unnest(_) | Expr::ScalarVariable(_, _) | Expr::Alias(_) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::BinaryExpr { .. } | Expr::Like { .. } | Expr::SimilarTo { .. } @@ -579,7 +579,8 @@ pub fn group_window_expr_by_sort_keys( ) -> Result)>> { let mut result = vec![]; window_expr.into_iter().try_for_each(|expr| match &expr { - Expr::WindowFunction( WindowFunction{ params: WindowFunctionParams { partition_by, order_by, ..}, .. 
}) => { + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params; let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), @@ -784,7 +785,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( indexes.push(idx); } } - Expr::Literal(_) => { + Expr::Literal(_, _) => { indexes.push(usize::MAX); } _ => {} @@ -1263,9 +1264,11 @@ pub fn collect_subquery_cols( mod tests { use super::*; use crate::{ - col, cube, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::max_udaf, test::function_stub::min_udaf, - test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, + col, cube, + expr::WindowFunction, + expr_vec_fmt, grouping_set, lit, rollup, + test::function_stub::{max_udaf, min_udaf, sum_udaf}, + Cast, ExprFunctionExt, WindowFunctionDefinition, }; use arrow::datatypes::{UnionFields, UnionMode}; @@ -1279,19 +1282,19 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )); @@ -1309,25 +1312,25 @@ mod tests { let age_asc = Sort::new(col("age"), true, true); let name_desc = Sort::new(col("name"), false, true); let created_at_desc = Sort::new(col("created_at"), false, true); - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )) diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index f1d0ead23ab1..4c37cc6b6013 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -193,11 +193,7 @@ impl WindowFrameContext { // UNBOUNDED PRECEDING WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0, WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { - if idx >= n as usize { - idx - n as usize - } else { - 0 - } + idx.saturating_sub(n as usize) } WindowFrameBound::CurrentRow => idx, // UNBOUNDED FOLLOWING @@ -602,11 +598,7 @@ impl WindowFrameStateGroups { // Find the group index of the frame 
boundary: let group_idx = if SEARCH_SIDE { - if self.current_group_idx > delta { - self.current_group_idx - delta - } else { - 0 - } + self.current_group_idx.saturating_sub(delta) } else { self.current_group_idx + delta }; diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 29f40df51444..a8335769ec29 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -44,7 +44,9 @@ arrow-schema = { workspace = true } async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } +datafusion-functions-aggregate-common = { workspace = true } datafusion-proto = { workspace = true } +datafusion-proto-common = { workspace = true } futures = { workspace = true } log = { workspace = true } prost = { workspace = true } @@ -56,3 +58,4 @@ doc-comment = { workspace = true } [features] integration-tests = [] +tarpaulin_include = [] # Exists only to prevent warnings on stable and still have accurate coverage diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index a18e6df59bf1..7b3751dcae82 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -21,7 +21,8 @@ use abi_stable::StableAbi; use arrow::{ array::{make_array, ArrayRef}, datatypes::{Schema, SchemaRef}, - ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + error::ArrowError, + ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; use log::error; @@ -36,7 +37,7 @@ impl From for WrappedSchema { let ffi_schema = match FFI_ArrowSchema::try_from(value.as_ref()) { Ok(s) => s, Err(e) => { - error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); + error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {e}"); FFI_ArrowSchema::empty() } }; @@ -44,16 +45,19 @@ impl From for WrappedSchema { WrappedSchema(ffi_schema) } } +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. +#[cfg(not(tarpaulin_include))] +fn catch_df_schema_error(e: ArrowError) -> Schema { + error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {e}"); + Schema::empty() +} impl From for SchemaRef { fn from(value: WrappedSchema) -> Self { - let schema = match Schema::try_from(&value.0) { - Ok(s) => s, - Err(e) => { - error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {}", e); - Schema::empty() - } - }; + let schema = Schema::try_from(&value.0).unwrap_or_else(catch_df_schema_error); Arc::new(schema) } } @@ -71,7 +75,7 @@ pub struct WrappedArray { } impl TryFrom for ArrayRef { - type Error = arrow::error::ArrowError; + type Error = ArrowError; fn try_from(value: WrappedArray) -> Result { let data = unsafe { from_ffi(value.array, &value.schema.0)? 
}; @@ -79,3 +83,14 @@ impl TryFrom for ArrayRef { Ok(make_array(data)) } } + +impl TryFrom<&ArrayRef> for WrappedArray { + type Error = ArrowError; + + fn try_from(array: &ArrayRef) -> Result { + let (array, schema) = to_ffi(&array.to_data())?; + let schema = WrappedSchema(schema); + + Ok(WrappedArray { array, schema }) + } +} diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index d877e182a1d8..ff641e8315c7 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -34,8 +34,10 @@ pub mod schema_provider; pub mod session_config; pub mod table_provider; pub mod table_source; +pub mod udaf; pub mod udf; pub mod udtf; +pub mod udwf; pub mod util; pub mod volatility; diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 3592c16b8fab..587e667a4775 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -300,7 +300,10 @@ impl From for EmissionType { #[cfg(test)] mod tests { - use datafusion::physical_plan::Partitioning; + use datafusion::{ + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::Partitioning, + }; use super::*; @@ -311,8 +314,13 @@ mod tests { Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); let original_props = PlanProperties::new( - EquivalenceProperties::new(schema), - Partitioning::UnknownPartitioning(3), + EquivalenceProperties::new(Arc::clone(&schema)).with_reorder( + LexOrdering::new(vec![PhysicalSortExpr { + expr: datafusion::physical_plan::expressions::col("a", &schema)?, + options: Default::default(), + }]), + ), + Partitioning::RoundRobinBatch(3), EmissionType::Incremental, Boundedness::Bounded, ); @@ -321,7 +329,7 @@ mod tests { let foreign_props: PlanProperties = local_props_ptr.try_into()?; - assert!(format!("{:?}", foreign_props) == format!("{:?}", original_props)); + assert_eq!(format!("{foreign_props:?}"), format!("{original_props:?}")); Ok(()) } diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 939c4050028c..78d65a816fcc 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -196,3 +196,49 @@ impl Stream for FFI_RecordBatchStream { } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::{ + common::record_batch, error::Result, execution::SendableRecordBatchStream, + test_util::bounded_stream, + }; + + use super::FFI_RecordBatchStream; + use futures::StreamExt; + + #[tokio::test] + async fn test_round_trip_record_batch_stream() -> Result<()> { + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 3]), + ("b", Float64, vec![Some(4.0), None, Some(5.0)]) + )?; + let original_rbs = bounded_stream(record_batch.clone(), 1); + + let ffi_rbs: FFI_RecordBatchStream = original_rbs.into(); + let mut ffi_rbs: SendableRecordBatchStream = Box::pin(ffi_rbs); + + let schema = ffi_rbs.schema(); + assert_eq!( + schema, + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true) + ])) + ); + + let batch = ffi_rbs.next().await; + assert!(batch.is_some()); + assert!(batch.as_ref().unwrap().is_ok()); + assert_eq!(batch.unwrap().unwrap(), record_batch); + + // There should only be one batch + let no_batch = ffi_rbs.next().await; + assert!(no_batch.is_none()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/tests/async_provider.rs b/datafusion/ffi/src/tests/async_provider.rs index 
cf05d596308f..60434a7dda12 100644 --- a/datafusion/ffi/src/tests/async_provider.rs +++ b/datafusion/ffi/src/tests/async_provider.rs @@ -260,7 +260,7 @@ impl Stream for AsyncTestRecordBatchStream { if let Err(e) = this.batch_request.try_send(true) { return std::task::Poll::Ready(Some(Err(DataFusionError::Execution( - format!("Unable to send batch request, {}", e), + format!("Unable to send batch request, {e}"), )))); } @@ -270,7 +270,7 @@ impl Stream for AsyncTestRecordBatchStream { None => std::task::Poll::Ready(None), }, Err(e) => std::task::Poll::Ready(Some(Err(DataFusionError::Execution( - format!("Unable receive record batch: {}", e), + format!("Unable to receive record batch: {e}"), )))), } } diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index 7a36ee52bdb4..db596f51fcd9 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -29,6 +29,10 @@ use catalog::create_catalog_provider; use crate::{catalog_provider::FFI_CatalogProvider, udtf::FFI_TableFunction}; +use crate::udaf::FFI_AggregateUDF; + +use crate::udwf::FFI_WindowUDF; + use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; use arrow::array::RecordBatch; use async_provider::create_async_table_provider; @@ -37,7 +41,10 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_random_func, create_ffi_table_func}; +use udf_udaf_udwf::{ + create_ffi_abs_func, create_ffi_random_func, create_ffi_rank_func, + create_ffi_stddev_func, create_ffi_sum_func, create_ffi_table_func, +}; mod async_provider; pub mod catalog; @@ -65,6 +72,14 @@ pub struct ForeignLibraryModule { pub create_table_function: extern "C" fn() -> FFI_TableFunction, + /// Create an aggregate UDAF using sum + pub create_sum_udaf: extern "C" fn() -> FFI_AggregateUDF, + + /// Createa grouping UDAF using stddev + pub create_stddev_udaf: extern "C" fn() -> FFI_AggregateUDF, + + pub create_rank_udwf: extern "C" fn() -> FFI_WindowUDF, + pub version: extern "C" fn() -> u64, } @@ -112,6 +127,9 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { create_scalar_udf: create_ffi_abs_func, create_nullary_udf: create_ffi_random_func, create_table_function: create_ffi_table_func, + create_sum_udaf: create_ffi_sum_func, + create_stddev_udaf: create_ffi_stddev_func, + create_rank_udwf: create_ffi_rank_func, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index c3cb1bcc3533..55e31ef3ab77 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -15,12 +15,17 @@ // specific language governing permissions and limitations // under the License. 
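Editorial aside (not part of this patch): the new `create_sum_udaf`, `create_stddev_udaf`, and `create_rank_udwf` entries added to `ForeignLibraryModule` above are plain `extern "C"` function pointers. A minimal sketch of how a consumer that has already loaded the module might invoke one of them follows; the helper name is hypothetical and library loading is elided, assuming a `&ForeignLibraryModule` is already in hand.

// Hypothetical consumer-side helper (not part of the diff); assumes the
// `ForeignLibraryModule` and `FFI_AggregateUDF` types above are in scope.
fn load_sum_udaf(module: &ForeignLibraryModule) -> FFI_AggregateUDF {
    // Calls through to the plugin's `create_ffi_sum_func` defined below.
    (module.create_sum_udaf)()
}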
-use crate::{udf::FFI_ScalarUDF, udtf::FFI_TableFunction}; +use crate::{ + udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF, udtf::FFI_TableFunction, + udwf::FFI_WindowUDF, +}; use datafusion::{ catalog::TableFunctionImpl, functions::math::{abs::AbsFunc, random::RandomFunc}, + functions_aggregate::{stddev::Stddev, sum::Sum}, functions_table::generate_series::RangeFunc, - logical_expr::ScalarUDF, + functions_window::rank::Rank, + logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}, }; use std::sync::Arc; @@ -42,3 +47,27 @@ pub(crate) extern "C" fn create_ffi_table_func() -> FFI_TableFunction { FFI_TableFunction::new(udtf, None) } + +pub(crate) extern "C" fn create_ffi_sum_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Sum::new().into()); + + udaf.into() +} + +pub(crate) extern "C" fn create_ffi_stddev_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Stddev::new().into()); + + udaf.into() +} + +pub(crate) extern "C" fn create_ffi_rank_func() -> FFI_WindowUDF { + let udwf: Arc = Arc::new( + Rank::new( + "rank_demo".to_string(), + datafusion::functions_window::rank::RankType::Basic, + ) + .into(), + ); + + udwf.into() +} diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs new file mode 100644 index 000000000000..80b872159f48 --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -0,0 +1,366 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, ops::Deref}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::{array::ArrayRef, error::ArrowError}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::Accumulator, + scalar::ScalarValue, +}; +use prost::Message; + +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; + +/// A stable struct for sharing [`Accumulator`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`Accumulator`]. 
+#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_Accumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: + unsafe extern "C" fn(accumulator: &mut Self) -> RResult, RString>, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: + unsafe extern "C" fn(accumulator: &mut Self) -> RResult>, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + states: RVec, + ) -> RResult<(), RString>, + + pub retract_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + pub supports_retract_batch: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_Accumulator {} +unsafe impl Sync for FFI_Accumulator {} + +pub struct AccumulatorPrivateData { + pub accumulator: Box, +} + +impl FFI_Accumulator { + #[inline] + unsafe fn inner_mut(&mut self) -> &mut Box { + let private_data = self.private_data as *mut AccumulatorPrivateData; + &mut (*private_data).accumulator + } + + #[inline] + unsafe fn inner(&self) -> &dyn Accumulator { + let private_data = self.private_data as *const AccumulatorPrivateData; + (*private_data).accumulator.deref() + } +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + rresult!(accumulator.update_batch(&values_arrays)) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &mut FFI_Accumulator, +) -> RResult, RString> { + let accumulator = accumulator.inner_mut(); + + let scalar_result = rresult_return!(accumulator.evaluate()); + let proto_result: datafusion_proto::protobuf::ScalarValue = + rresult_return!((&scalar_result).try_into()); + + RResult::ROk(proto_result.encode_to_vec().into()) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { + accumulator.inner().size() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &mut FFI_Accumulator, +) -> RResult>, RString> { + let accumulator = accumulator.inner_mut(); + + let state = rresult_return!(accumulator.state()); + let state = state + .into_iter() + .map(|state_val| { + datafusion_proto::protobuf::ScalarValue::try_from(&state_val) + .map_err(DataFusionError::from) + .map(|v| RVec::from(v.encode_to_vec())) + }) + .collect::>>() + .map(|state_vec| state_vec.into()); + + rresult!(state) +} + +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + states: RVec, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + + let states = rresult_return!(states + .into_iter() + .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) + .collect::>>()); + + rresult!(accumulator.merge_batch(&states)) +} + +unsafe extern "C" fn retract_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let accumulator = 
accumulator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + rresult!(accumulator.retract_batch(&values_arrays)) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); + drop(private_data); +} + +impl From> for FFI_Accumulator { + fn from(accumulator: Box) -> Self { + let supports_retract_batch = accumulator.supports_retract_batch(); + let private_data = AccumulatorPrivateData { accumulator }; + + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + retract_batch: retract_batch_fn_wrapper, + supports_retract_batch, + release: release_fn_wrapper, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } + } +} + +impl Drop for FFI_Accumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_Accumulator. +#[derive(Debug)] +pub struct ForeignAccumulator { + accumulator: FFI_Accumulator, +} + +unsafe impl Send for ForeignAccumulator {} +unsafe impl Sync for ForeignAccumulator {} + +impl From for ForeignAccumulator { + fn from(accumulator: FFI_Accumulator) -> Self { + Self { accumulator } + } +} + +impl Accumulator for ForeignAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn evaluate(&mut self) -> Result { + unsafe { + let scalar_bytes = + df_result!((self.accumulator.evaluate)(&mut self.accumulator))?; + + let proto_scalar = + datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn state(&mut self) -> Result> { + unsafe { + let state_protos = + df_result!((self.accumulator.state)(&mut self.accumulator))?; + + state_protos + .into_iter() + .map(|proto_bytes| { + datafusion_proto::protobuf::ScalarValue::decode(proto_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e))) + .and_then(|proto_value| { + ScalarValue::try_from(&proto_value) + .map_err(DataFusionError::from) + }) + }) + .collect::>>() + } + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + unsafe { + let states = states + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + states.into() + )) + } + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.retract_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn supports_retract_batch(&self) -> 
bool { + self.accumulator.supports_retract_batch + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{make_array, Array}; + use datafusion::{ + common::create_array, error::Result, + functions_aggregate::average::AvgAccumulator, logical_expr::Accumulator, + scalar::ScalarValue, + }; + + use super::{FFI_Accumulator, ForeignAccumulator}; + + #[test] + fn test_foreign_avg_accumulator() -> Result<()> { + let original_accum = AvgAccumulator::default(); + let original_size = original_accum.size(); + let original_supports_retract = original_accum.supports_retract_batch(); + + let boxed_accum: Box = Box::new(original_accum); + let ffi_accum: FFI_Accumulator = boxed_accum.into(); + let mut foreign_accum: ForeignAccumulator = ffi_accum.into(); + + // Send in an array to average. There are 5 values and it should average to 30.0 + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + foreign_accum.update_batch(&[values])?; + + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + let state = foreign_accum.state()?; + assert_eq!(state.len(), 2); + assert_eq!(state[0], ScalarValue::UInt64(Some(5))); + assert_eq!(state[1], ScalarValue::Float64(Some(150.0))); + + // To verify merging batches works, create a second state to add in + // This should cause our average to go down to 25.0 + let second_states = vec![ + make_array(create_array!(UInt64, vec![1]).to_data()), + make_array(create_array!(Float64, vec![0.0]).to_data()), + ]; + + foreign_accum.merge_batch(&second_states)?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(25.0))); + + // If we remove a batch that is equivalent to the state we added + // we should go back to our original value of 30.0 + let values = create_array!(Float64, vec![0.0]); + foreign_accum.retract_batch(&[values])?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + assert_eq!(original_size, foreign_accum.size()); + assert_eq!( + original_supports_retract, + foreign_accum.supports_retract_batch() + ); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs new file mode 100644 index 000000000000..699af1d5c5e0 --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
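Editorial aside (not part of this patch): `FFI_Accumulator` above uses the same private-data ownership pattern as the other `FFI_*` structs in this crate. The sketch below is a minimal, self-contained illustration of that pattern with hypothetical names (`Thing`, `FfiThing`), not DataFusion types: the provider leaks a `Box` behind `private_data`, each callback recovers it from the raw pointer, and `release`/`Drop` reclaim it so the allocation is freed by the side that created it.

use std::ffi::c_void;

// Hypothetical trait standing in for `Accumulator` in this sketch.
trait Thing {
    fn value(&self) -> usize;
}

#[repr(C)]
struct FfiThing {
    // Callbacks only receive `&Self`; all state lives behind `private_data`.
    value: unsafe extern "C" fn(thing: &FfiThing) -> usize,
    release: unsafe extern "C" fn(thing: &mut FfiThing),
    private_data: *mut c_void,
}

struct ThingPrivateData {
    inner: Box<dyn Thing>,
}

unsafe extern "C" fn value_fn_wrapper(thing: &FfiThing) -> usize {
    let private_data = thing.private_data as *const ThingPrivateData;
    (*private_data).inner.value()
}

unsafe extern "C" fn release_fn_wrapper(thing: &mut FfiThing) {
    // Reclaim the Box created in `From`, freeing it with the provider's allocator.
    drop(Box::from_raw(thing.private_data as *mut ThingPrivateData));
}

impl From<Box<dyn Thing>> for FfiThing {
    fn from(inner: Box<dyn Thing>) -> Self {
        let private_data = Box::new(ThingPrivateData { inner });
        Self {
            value: value_fn_wrapper,
            release: release_fn_wrapper,
            // Leak the Box; ownership is handed back in `release_fn_wrapper`.
            private_data: Box::into_raw(private_data) as *mut c_void,
        }
    }
}

impl Drop for FfiThing {
    fn drop(&mut self) {
        unsafe { (self.release)(self) }
    }
}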
+ +use std::sync::Arc; + +use crate::arrow_wrappers::WrappedSchema; +use abi_stable::{ + std_types::{RString, RVec}, + StableAbi, +}; +use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema}; +use arrow_schema::FieldRef; +use datafusion::{ + error::DataFusionError, logical_expr::function::AccumulatorArgs, + physical_expr::LexOrdering, physical_plan::PhysicalExpr, prelude::SessionContext, +}; +use datafusion_proto::{ + physical_plan::{ + from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, + to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs}, + DefaultPhysicalExtensionCodec, + }, + protobuf::PhysicalAggregateExprNode, +}; +use prost::Message; + +/// A stable struct for sharing [`AccumulatorArgs`] across FFI boundaries. +/// For an explanation of each field, see the corresponding field +/// defined in [`AccumulatorArgs`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AccumulatorArgs { + return_field: WrappedSchema, + schema: WrappedSchema, + is_reversed: bool, + name: RString, + physical_expr_def: RVec, +} + +impl TryFrom> for FFI_AccumulatorArgs { + type Error = DataFusionError; + + fn try_from(args: AccumulatorArgs) -> Result { + let return_field = + WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); + let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); + + let codec = DefaultPhysicalExtensionCodec {}; + let ordering_req = + serialize_physical_sort_exprs(args.ordering_req.to_owned(), &codec)?; + + let expr = serialize_physical_exprs(args.exprs, &codec)?; + + let physical_expr_def = PhysicalAggregateExprNode { + expr, + ordering_req, + distinct: args.is_distinct, + ignore_nulls: args.ignore_nulls, + fun_definition: None, + aggregate_function: None, + }; + let physical_expr_def = physical_expr_def.encode_to_vec().into(); + + Ok(Self { + return_field, + schema, + is_reversed: args.is_reversed, + name: args.name.into(), + physical_expr_def, + }) + } +} + +/// This struct mirrors AccumulatorArgs except that it contains owned data. +/// It is necessary to create this struct so that we can parse the protobuf +/// data across the FFI boundary and turn it into owned data that +/// AccumulatorArgs can then reference. 
+pub struct ForeignAccumulatorArgs { + pub return_field: FieldRef, + pub schema: Schema, + pub ignore_nulls: bool, + pub ordering_req: LexOrdering, + pub is_reversed: bool, + pub name: String, + pub is_distinct: bool, + pub exprs: Vec>, +} + +impl TryFrom for ForeignAccumulatorArgs { + type Error = DataFusionError; + + fn try_from(value: FFI_AccumulatorArgs) -> Result { + let proto_def = + PhysicalAggregateExprNode::decode(value.physical_expr_def.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + + let return_field = Arc::new((&value.return_field.0).try_into()?); + let schema = Schema::try_from(&value.schema.0)?; + + let default_ctx = SessionContext::new(); + let codex = DefaultPhysicalExtensionCodec {}; + + // let proto_ordering_req = + // rresult_return!(PhysicalSortExprNodeCollection::decode(ordering_req.as_ref())); + let ordering_req = parse_physical_sort_exprs( + &proto_def.ordering_req, + &default_ctx, + &schema, + &codex, + )?; + + let exprs = parse_physical_exprs(&proto_def.expr, &default_ctx, &schema, &codex)?; + + Ok(Self { + return_field, + schema, + ignore_nulls: proto_def.ignore_nulls, + ordering_req, + is_reversed: value.is_reversed, + name: value.name.to_string(), + is_distinct: proto_def.distinct, + exprs, + }) + } +} + +impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { + fn from(value: &'a ForeignAccumulatorArgs) -> Self { + Self { + return_field: Arc::clone(&value.return_field), + schema: &value.schema, + ignore_nulls: value.ignore_nulls, + ordering_req: &value.ordering_req, + is_reversed: value.is_reversed, + name: value.name.as_str(), + is_distinct: value.is_distinct, + exprs: &value.exprs, + } + } +} + +#[cfg(test)] +mod tests { + use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::{ + error::Result, + logical_expr::function::AccumulatorArgs, + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::expressions::col, + }; + + #[test] + fn test_round_trip_accumulator_args() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let orig_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Float64, true).into(), + schema: &schema, + ignore_nulls: false, + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + let orig_str = format!("{orig_args:?}"); + + let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?; + let foreign_args: ForeignAccumulatorArgs = ffi_args.try_into()?; + let round_trip_args: AccumulatorArgs = (&foreign_args).into(); + + let round_trip_str = format!("{round_trip_args:?}"); + + // Since AccumulatorArgs doesn't implement Eq, simply compare + // the debug strings. + assert_eq!(orig_str, round_trip_str); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs new file mode 100644 index 000000000000..58a18c69db7c --- /dev/null +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -0,0 +1,513 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, ops::Deref, sync::Arc}; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, +}; +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::{ + array::{Array, ArrayRef, BooleanArray}, + error::ArrowError, + ffi::to_ffi, +}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::{EmitTo, GroupsAccumulator}, +}; + +/// A stable struct for sharing [`GroupsAccumulator`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`GroupsAccumulator`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_GroupsAccumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: unsafe extern "C" fn( + accumulator: &mut Self, + emit_to: FFI_EmitTo, + ) -> RResult, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: unsafe extern "C" fn( + accumulator: &mut Self, + emit_to: FFI_EmitTo, + ) -> RResult, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + pub convert_to_state: unsafe extern "C" fn( + accumulator: &Self, + values: RVec, + opt_filter: ROption, + ) + -> RResult, RString>, + + pub supports_convert_to_state: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignGroupsAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_GroupsAccumulator {} +unsafe impl Sync for FFI_GroupsAccumulator {} + +pub struct GroupsAccumulatorPrivateData { + pub accumulator: Box, +} + +impl FFI_GroupsAccumulator { + #[inline] + unsafe fn inner_mut(&mut self) -> &mut Box { + let private_data = self.private_data as *mut GroupsAccumulatorPrivateData; + &mut (*private_data).accumulator + } + + #[inline] + unsafe fn inner(&self) -> &dyn GroupsAccumulator { + let private_data = self.private_data as *const GroupsAccumulatorPrivateData; + (*private_data).accumulator.deref() + } +} + +fn process_values(values: RVec) -> Result>> { + values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>() +} + +/// Convert C-typed opt_filter into the internal type. 
+fn process_opt_filter(opt_filter: ROption) -> Result> { + opt_filter + .into_option() + .map(|filter| { + ArrayRef::try_from(filter) + .map_err(DataFusionError::from) + .map(|arr| BooleanArray::from(arr.into_data())) + }) + .transpose() +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + let values = rresult_return!(process_values(values)); + let group_indices: Vec = group_indices.into_iter().collect(); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); + + rresult!(accumulator.update_batch( + &values, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult { + let accumulator = accumulator.inner_mut(); + + let result = rresult_return!(accumulator.evaluate(emit_to.into())); + + rresult!(WrappedArray::try_from(&result)) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize { + let accumulator = accumulator.inner(); + accumulator.size() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult, RString> { + let accumulator = accumulator.inner_mut(); + + let state = rresult_return!(accumulator.state(emit_to.into())); + rresult!(state + .into_iter() + .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + let values = rresult_return!(process_values(values)); + let group_indices: Vec = group_indices.into_iter().collect(); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); + + rresult!(accumulator.merge_batch( + &values, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn convert_to_state_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + values: RVec, + opt_filter: ROption, +) -> RResult, RString> { + let accumulator = accumulator.inner(); + let values = rresult_return!(process_values(values)); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); + let state = + rresult_return!(accumulator.convert_to_state(&values, opt_filter.as_ref())); + + rresult!(state + .iter() + .map(|arr| WrappedArray::try_from(arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData); + drop(private_data); +} + +impl From> for FFI_GroupsAccumulator { + fn from(accumulator: Box) -> Self { + let supports_convert_to_state = accumulator.supports_convert_to_state(); + let private_data = GroupsAccumulatorPrivateData { accumulator }; + + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + convert_to_state: convert_to_state_fn_wrapper, + supports_convert_to_state, + + release: release_fn_wrapper, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } + } +} + +impl Drop for 
FFI_GroupsAccumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignGroupsAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_GroupsAccumulator. +#[derive(Debug)] +pub struct ForeignGroupsAccumulator { + accumulator: FFI_GroupsAccumulator, +} + +unsafe impl Send for ForeignGroupsAccumulator {} +unsafe impl Sync for ForeignGroupsAccumulator {} + +impl From for ForeignGroupsAccumulator { + fn from(accumulator: FFI_GroupsAccumulator) -> Self { + Self { accumulator } + } +} + +impl GroupsAccumulator for ForeignGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + unsafe { + let return_array = df_result!((self.accumulator.evaluate)( + &mut self.accumulator, + emit_to.into() + ))?; + + return_array.try_into().map_err(DataFusionError::from) + } + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + unsafe { + let returned_arrays = df_result!((self.accumulator.state)( + &mut self.accumulator, + emit_to.into() + ))?; + + returned_arrays + .into_iter() + .map(|wrapped_array| { + wrapped_array.try_into().map_err(DataFusionError::from) + }) + .collect::>>() + } + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? 
+                .map(|(array, schema)| WrappedArray {
+                    array,
+                    schema: WrappedSchema(schema),
+                })
+                .into();
+
+            let returned_array = df_result!((self.accumulator.convert_to_state)(
+                &self.accumulator,
+                values,
+                opt_filter
+            ))?;
+
+            returned_array
+                .into_iter()
+                .map(|arr| arr.try_into().map_err(DataFusionError::from))
+                .collect()
+        }
+    }
+
+    fn supports_convert_to_state(&self) -> bool {
+        self.accumulator.supports_convert_to_state
+    }
+}
+
+#[repr(C)]
+#[derive(Debug, StableAbi)]
+#[allow(non_camel_case_types)]
+pub enum FFI_EmitTo {
+    All,
+    First(usize),
+}
+
+impl From<EmitTo> for FFI_EmitTo {
+    fn from(value: EmitTo) -> Self {
+        match value {
+            EmitTo::All => Self::All,
+            EmitTo::First(v) => Self::First(v),
+        }
+    }
+}
+
+impl From<FFI_EmitTo> for EmitTo {
+    fn from(value: FFI_EmitTo) -> Self {
+        match value {
+            FFI_EmitTo::All => Self::All,
+            FFI_EmitTo::First(v) => Self::First(v),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::array::{make_array, Array, BooleanArray};
+    use datafusion::{
+        common::create_array,
+        error::Result,
+        logical_expr::{EmitTo, GroupsAccumulator},
+    };
+    use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator;
+
+    use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator};
+
+    #[test]
+    fn test_foreign_bool_accumulator() -> Result<()> {
+        let boxed_accum: Box<dyn GroupsAccumulator> =
+            Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true));
+        let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into();
+        let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into();
+
+        // Send in an array to evaluate. With this AND accumulator, group 0 should
+        // evaluate to true, group 1 to false, and group 2 (fully filtered out) to null.
+        let values = create_array!(Boolean, vec![true, true, true, false, true, true]);
+        let opt_filter =
+            create_array!(Boolean, vec![true, true, true, true, false, false]);
+        foreign_accum.update_batch(
+            &[values],
+            &[0, 0, 1, 1, 2, 2],
+            Some(opt_filter.as_ref()),
+            3,
+        )?;
+
+        let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
+        let groups_bool = groups_bool.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        assert_eq!(
+            groups_bool,
+            create_array!(Boolean, vec![Some(true), Some(false), None]).as_ref()
+        );
+
+        let state = foreign_accum.state(EmitTo::All)?;
+        assert_eq!(state.len(), 1);
+
+        // To verify merging batches works, merge in a second state containing `false`.
+        // ANDing it into group 0 should make that group evaluate to `false`.
+        let second_states =
+            vec![make_array(create_array!(Boolean, vec![false]).to_data())];
+
+        let opt_filter = create_array!(Boolean, vec![true]);
+        foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?;
+        let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
+        assert_eq!(groups_bool.len(), 1);
+        assert_eq!(
+            groups_bool.as_ref(),
+            make_array(create_array!(Boolean, vec![false]).to_data()).as_ref()
+        );
+
+        let values = create_array!(Boolean, vec![false]);
+        let opt_filter = create_array!(Boolean, vec![true]);
+        let groups_bool =
+            foreign_accum.convert_to_state(&[values], Some(opt_filter.as_ref()))?;
+
+        assert_eq!(
+            groups_bool[0].as_ref(),
+            make_array(create_array!(Boolean, vec![false]).to_data()).as_ref()
+        );
+
+        Ok(())
+    }
+
+    fn test_emit_to_round_trip(value: EmitTo) -> Result<()> {
+        let ffi_value: FFI_EmitTo = value.into();
+        let round_trip_value: EmitTo = ffi_value.into();
+
+        assert_eq!(value, round_trip_value);
+        Ok(())
+    }
+
+    /// This test ensures all enum values are properly translated
+    #[test]
+    fn test_all_emit_to_round_trip() -> Result<()> {
+        test_emit_to_round_trip(EmitTo::All)?;
+        test_emit_to_round_trip(EmitTo::First(10))?;
+ + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs new file mode 100644 index 000000000000..2529ed7a06dc --- /dev/null +++ b/datafusion/ffi/src/udaf/mod.rs @@ -0,0 +1,733 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ROption, RResult, RStr, RString, RVec}, + StableAbi, +}; +use accumulator::{FFI_Accumulator, ForeignAccumulator}; +use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; +use arrow::datatypes::{DataType, Field}; +use arrow::ffi::FFI_ArrowSchema; +use arrow_schema::FieldRef; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + type_coercion::functions::fields_with_aggregate_udf, + utils::AggregateOrderSensitivity, + Accumulator, GroupsAccumulator, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{AggregateUDF, AggregateUDFImpl, Signature}, +}; +use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; +use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; +use crate::{ + arrow_wrappers::WrappedSchema, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; +use prost::{DecodeError, Message}; + +mod accumulator; +mod accumulator_args; +mod groups_accumulator; + +/// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AggregateUDF { + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub name: RString, + + /// FFI equivalent to the `aliases` of a [`AggregateUDF`] + pub aliases: RVec, + + /// FFI equivalent to the `volatility` of a [`AggregateUDF`] + pub volatility: FFI_Volatility, + + /// Determines the return type of the underlying [`AggregateUDF`] based on the + /// argument types. 
+    pub return_type: unsafe extern "C" fn(
+        udaf: &Self,
+        arg_types: RVec<WrappedSchema>,
+    ) -> RResult<WrappedSchema, RString>,
+
+    /// FFI equivalent to the `is_nullable` of a [`AggregateUDF`]
+    pub is_nullable: bool,
+
+    /// FFI equivalent to [`AggregateUDF::groups_accumulator_supported`]
+    pub groups_accumulator_supported:
+        unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool,
+
+    /// FFI equivalent to [`AggregateUDF::accumulator`]
+    pub accumulator: unsafe extern "C" fn(
+        udaf: &FFI_AggregateUDF,
+        args: FFI_AccumulatorArgs,
+    ) -> RResult<FFI_Accumulator, RString>,
+
+    /// FFI equivalent to [`AggregateUDF::create_sliding_accumulator`]
+    pub create_sliding_accumulator:
+        unsafe extern "C" fn(
+            udaf: &FFI_AggregateUDF,
+            args: FFI_AccumulatorArgs,
+        ) -> RResult<FFI_Accumulator, RString>,
+
+    /// FFI equivalent to [`AggregateUDF::state_fields`]
+    #[allow(clippy::type_complexity)]
+    pub state_fields: unsafe extern "C" fn(
+        udaf: &FFI_AggregateUDF,
+        name: &RStr,
+        input_fields: RVec<WrappedSchema>,
+        return_field: WrappedSchema,
+        ordering_fields: RVec<RVec<u8>>,
+        is_distinct: bool,
+    ) -> RResult<RVec<RVec<u8>>, RString>,
+
+    /// FFI equivalent to [`AggregateUDF::create_groups_accumulator`]
+    pub create_groups_accumulator:
+        unsafe extern "C" fn(
+            udaf: &FFI_AggregateUDF,
+            args: FFI_AccumulatorArgs,
+        ) -> RResult<FFI_GroupsAccumulator, RString>,
+
+    /// FFI equivalent to [`AggregateUDF::with_beneficial_ordering`]
+    pub with_beneficial_ordering:
+        unsafe extern "C" fn(
+            udaf: &FFI_AggregateUDF,
+            beneficial_ordering: bool,
+        ) -> RResult<ROption<FFI_AggregateUDF>, RString>,
+
+    /// FFI equivalent to [`AggregateUDF::order_sensitivity`]
+    pub order_sensitivity:
+        unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity,
+
+    /// Performs type coercion. To simplify this interface, all UDFs are treated as
+    /// having user defined signatures, which will in turn cause coerce_types to be
+    /// called. This call should be transparent to most users as the internal function
+    /// performs the appropriate calls on the underlying [`AggregateUDF`]
+    pub coerce_types: unsafe extern "C" fn(
+        udf: &Self,
+        arg_types: RVec<WrappedSchema>,
+    ) -> RResult<RVec<WrappedSchema>, RString>,
+
+    /// Used to create a clone on the provider of the udaf. This should
+    /// only need to be called by the receiver of the udaf.
+    pub clone: unsafe extern "C" fn(udaf: &Self) -> Self,
+
+    /// Release the memory of the private data when it is no longer being used.
+    pub release: unsafe extern "C" fn(udaf: &mut Self),
+
+    /// Internal data. This is only to be accessed by the provider of the udaf.
+    /// A [`ForeignAggregateUDF`] should never attempt to access this data.
+ pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_AggregateUDF {} +unsafe impl Sync for FFI_AggregateUDF {} + +pub struct AggregateUDFPrivateData { + pub udaf: Arc, +} + +impl FFI_AggregateUDF { + unsafe fn inner(&self) -> &Arc { + let private_data = self.private_data as *const AggregateUDFPrivateData; + &(*private_data).udaf + } +} + +unsafe extern "C" fn return_type_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec, +) -> RResult { + let udaf = udaf.inner(); + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_type = udaf + .return_type(&arg_types) + .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from)) + .map(WrappedSchema); + + rresult!(return_type) +} + +unsafe extern "C" fn accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_sliding_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_sliding_accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_groups_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_groups_accumulator(accumulator_args.into()) + .map(FFI_GroupsAccumulator::from)) +} + +unsafe extern "C" fn groups_accumulator_supported_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> bool { + let udaf = udaf.inner(); + + ForeignAccumulatorArgs::try_from(args) + .map(|a| udaf.groups_accumulator_supported((&a).into())) + .unwrap_or_else(|e| { + log::warn!("Unable to parse accumulator args. 
{e}"); + false + }) +} + +unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, +) -> RResult, RString> { + let udaf = udaf.inner().as_ref().clone(); + + let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering)); + let result = rresult_return!(result + .map(|func| func.with_beneficial_ordering(beneficial_ordering)) + .transpose()) + .flatten() + .map(|func| FFI_AggregateUDF::from(Arc::new(func))); + + RResult::ROk(result.into()) +} + +unsafe extern "C" fn state_fields_fn_wrapper( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_fields: RVec, + return_field: WrappedSchema, + ordering_fields: RVec>, + is_distinct: bool, +) -> RResult>, RString> { + let udaf = udaf.inner(); + + let input_fields = &rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields)); + let return_field = rresult_return!(Field::try_from(&return_field.0)).into(); + + let ordering_fields = &rresult_return!(ordering_fields + .into_iter() + .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref())) + .collect::, DecodeError>>()); + + let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields)) + .into_iter() + .map(Arc::new) + .collect::>(); + + let args = StateFieldsArgs { + name: name.as_str(), + input_fields, + return_field, + ordering_fields, + is_distinct, + }; + + let state_fields = rresult_return!(udaf.state_fields(args)); + let state_fields = rresult_return!(state_fields + .iter() + .map(|f| f.as_ref()) + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()) + .into_iter() + .map(|field| field.encode_to_vec().into()) + .collect(); + + RResult::ROk(state_fields) +} + +unsafe extern "C" fn order_sensitivity_fn_wrapper( + udaf: &FFI_AggregateUDF, +) -> FFI_AggregateOrderSensitivity { + udaf.inner().order_sensitivity().into() +} + +unsafe extern "C" fn coerce_types_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec, +) -> RResult, RString> { + let udaf = udaf.inner(); + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let arg_fields = arg_types + .iter() + .map(|dt| Field::new("f", dt.clone(), true)) + .map(Arc::new) + .collect::>(); + let return_types = rresult_return!(fields_with_aggregate_udf(&arg_fields, udaf)) + .into_iter() + .map(|f| f.data_type().to_owned()) + .collect::>(); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} + +unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { + let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF { + Arc::clone(udaf.inner()).into() +} + +impl Clone for FFI_AggregateUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_AggregateUDF { + fn from(udaf: Arc) -> Self { + let name = udaf.name().into(); + let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect(); + let is_nullable = udaf.is_nullable(); + let volatility = udaf.signature().volatility.into(); + + let private_data = Box::new(AggregateUDFPrivateData { udaf }); + + Self { + name, + is_nullable, + volatility, + aliases, + return_type: return_type_fn_wrapper, + accumulator: accumulator_fn_wrapper, + create_sliding_accumulator: create_sliding_accumulator_fn_wrapper, + create_groups_accumulator: create_groups_accumulator_fn_wrapper, + groups_accumulator_supported: 
groups_accumulator_supported_fn_wrapper, + with_beneficial_ordering: with_beneficial_ordering_fn_wrapper, + state_fields: state_fields_fn_wrapper, + order_sensitivity: order_sensitivity_fn_wrapper, + coerce_types: coerce_types_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_AggregateUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAggregateUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_AggregateUDF. +#[derive(Debug)] +pub struct ForeignAggregateUDF { + signature: Signature, + aliases: Vec, + udaf: FFI_AggregateUDF, +} + +unsafe impl Send for ForeignAggregateUDF {} +unsafe impl Sync for ForeignAggregateUDF {} + +impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { + type Error = DataFusionError; + + fn try_from(udaf: &FFI_AggregateUDF) -> Result { + let signature = Signature::user_defined((&udaf.volatility).into()); + let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Self { + udaf: udaf.clone(), + signature, + aliases, + }) + } +} + +impl AggregateUDFImpl for ForeignAggregateUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + self.udaf.name.as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + + let result = unsafe { (self.udaf.return_type)(&self.udaf, arg_types) }; + + let result = df_result!(result); + + result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) + } + + fn is_nullable(&self) -> bool { + self.udaf.is_nullable + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let args = acc_args.try_into()?; + unsafe { + df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| { + Box::new(ForeignAccumulator::from(accum)) as Box + }) + } + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + unsafe { + let name = RStr::from_str(args.name); + let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?; + let return_field = + WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); + let ordering_fields = args + .ordering_fields + .iter() + .map(|f| f.as_ref()) + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()? 
+ .into_iter() + .map(|proto_field| proto_field.encode_to_vec().into()) + .collect(); + + let fields = df_result!((self.udaf.state_fields)( + &self.udaf, + &name, + input_fields, + return_field, + ordering_fields, + args.is_distinct + ))?; + let fields = fields + .into_iter() + .map(|field_bytes| { + datafusion_proto_common::Field::decode(field_bytes.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string())) + }) + .collect::>>()?; + + parse_proto_fields_to_fields(fields.iter()) + .map(|fields| fields.into_iter().map(Arc::new).collect()) + .map_err(|e| DataFusionError::Execution(e.to_string())) + } + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + let args = match FFI_AccumulatorArgs::try_from(args) { + Ok(v) => v, + Err(e) => { + log::warn!("Attempting to convert accumulator arguments: {e}"); + return false; + } + }; + + unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) } + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let args = FFI_AccumulatorArgs::try_from(args)?; + + unsafe { + df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map( + |accum| { + Box::new(ForeignGroupsAccumulator::from(accum)) + as Box + }, + ) + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let args = args.try_into()?; + unsafe { + df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map( + |accum| Box::new(ForeignAccumulator::from(accum)) as Box, + ) + } + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + unsafe { + let result = df_result!((self.udaf.with_beneficial_ordering)( + &self.udaf, + beneficial_ordering + ))? + .into_option(); + + let result = result + .map(|func| ForeignAggregateUDF::try_from(&func)) + .transpose()?; + + Ok(result.map(|func| Arc::new(func) as Arc)) + } + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() } + } + + fn simplify(&self) -> Option { + None + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let result_types = + df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?) 
+ } + } +} + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_AggregateOrderSensitivity { + Insensitive, + HardRequirement, + Beneficial, +} + +impl From for AggregateOrderSensitivity { + fn from(value: FFI_AggregateOrderSensitivity) -> Self { + match value { + FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive, + FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +impl From for FFI_AggregateOrderSensitivity { + fn from(value: AggregateOrderSensitivity) -> Self { + match value { + AggregateOrderSensitivity::Insensitive => Self::Insensitive, + AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::Schema; + use datafusion::{ + common::create_array, + functions_aggregate::sum::Sum, + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::expressions::col, + scalar::ScalarValue, + }; + + use super::*; + + fn create_test_foreign_udaf( + original_udaf: impl AggregateUDFImpl + 'static, + ) -> Result { + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + Ok(foreign_udaf.into()) + } + + #[test] + fn test_round_trip_udaf() -> Result<()> { + let original_udaf = Sum::new(); + let original_name = original_udaf.name().to_owned(); + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + // Convert to FFI format + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + // Convert back to native format + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + let foreign_udaf: AggregateUDF = foreign_udaf.into(); + + assert_eq!(original_name, foreign_udaf.name()); + Ok(()) + } + + #[test] + fn test_foreign_udaf_aliases() -> Result<()> { + let foreign_udaf = + create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]); + + let return_type = foreign_udaf.return_type(&[DataType::Float64])?; + assert_eq!(return_type, DataType::Float64); + Ok(()) + } + + #[test] + fn test_foreign_udaf_accumulator() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let acc_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Float64, true).into(), + schema: &schema, + ignore_nulls: true, + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + let mut accumulator = foreign_udaf.accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); + + Ok(()) + } + + #[test] + fn test_beneficial_ordering() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf( + datafusion::functions_aggregate::first_last::FirstValue::new(), + )?; + + let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap(); + + assert_eq!( + foreign_udaf.order_sensitivity(), + AggregateOrderSensitivity::Beneficial + ); + + let a_field = Arc::new(Field::new("a", 
DataType::Float64, true)); + let state_fields = foreign_udaf.state_fields(StateFieldsArgs { + name: "a", + input_fields: &[Field::new("f", DataType::Float64, true).into()], + return_field: Field::new("f", DataType::Float64, true).into(), + ordering_fields: &[Arc::clone(&a_field)], + is_distinct: false, + })?; + + assert_eq!(state_fields.len(), 3); + assert_eq!(state_fields[1], a_field); + Ok(()) + } + + #[test] + fn test_sliding_accumulator() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let acc_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Float64, true).into(), + schema: &schema, + ignore_nulls: true, + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + + let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); + + Ok(()) + } + + fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) { + let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into(); + let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into(); + + assert_eq!(sensitivity, round_trip_sensitivity); + } + + #[test] + fn test_round_trip_all_order_sensitivities() { + test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial); + } +} diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 706b9fabedcb..303acc783b2e 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -15,23 +15,27 @@ // specific language governing permissions and limitations // under the License. 
-use std::{ffi::c_void, sync::Arc}; - +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; use abi_stable::{ std_types::{RResult, RString, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ array::ArrayRef, error::ArrowError, ffi::{from_ffi, to_ffi, FFI_ArrowSchema}, }; +use arrow_schema::FieldRef; +use datafusion::logical_expr::ReturnFieldArgs; use datafusion::{ error::DataFusionError, - logical_expr::{ - type_coercion::functions::data_types_with_scalar_udf, ReturnInfo, ReturnTypeArgs, - }, + logical_expr::type_coercion::functions::data_types_with_scalar_udf, }; use datafusion::{ error::Result, @@ -39,19 +43,11 @@ use datafusion::{ ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, }, }; -use return_info::FFI_ReturnInfo; use return_type_args::{ - FFI_ReturnTypeArgs, ForeignReturnTypeArgs, ForeignReturnTypeArgsOwned, -}; - -use crate::{ - arrow_wrappers::{WrappedArray, WrappedSchema}, - df_result, rresult, rresult_return, - util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, - volatility::FFI_Volatility, + FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned, }; +use std::{ffi::c_void, sync::Arc}; -pub mod return_info; pub mod return_type_args; /// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries. @@ -77,19 +73,21 @@ pub struct FFI_ScalarUDF { /// Determines the return info of the underlying [`ScalarUDF`]. Either this /// or return_type may be implemented on a UDF. - pub return_type_from_args: unsafe extern "C" fn( + pub return_field_from_args: unsafe extern "C" fn( udf: &Self, - args: FFI_ReturnTypeArgs, + args: FFI_ReturnFieldArgs, ) - -> RResult, + -> RResult, /// Execute the underlying [`ScalarUDF`] and return the result as a `FFI_ArrowArray` /// within an AbiStable wrapper. 
+ #[allow(clippy::type_complexity)] pub invoke_with_args: unsafe extern "C" fn( udf: &Self, args: RVec, + arg_fields: RVec, num_rows: usize, - return_type: WrappedSchema, + return_field: WrappedSchema, ) -> RResult, /// See [`ScalarUDFImpl`] for details on short_circuits @@ -140,19 +138,20 @@ unsafe extern "C" fn return_type_fn_wrapper( rresult!(return_type) } -unsafe extern "C" fn return_type_from_args_fn_wrapper( +unsafe extern "C" fn return_field_from_args_fn_wrapper( udf: &FFI_ScalarUDF, - args: FFI_ReturnTypeArgs, -) -> RResult { + args: FFI_ReturnFieldArgs, +) -> RResult { let private_data = udf.private_data as *const ScalarUDFPrivateData; let udf = &(*private_data).udf; - let args: ForeignReturnTypeArgsOwned = rresult_return!((&args).try_into()); - let args_ref: ForeignReturnTypeArgs = (&args).into(); + let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into()); + let args_ref: ForeignReturnFieldArgs = (&args).into(); let return_type = udf - .return_type_from_args((&args_ref).into()) - .and_then(FFI_ReturnInfo::try_from); + .return_field_from_args((&args_ref).into()) + .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from)) + .map(WrappedSchema); rresult!(return_type) } @@ -174,8 +173,9 @@ unsafe extern "C" fn coerce_types_fn_wrapper( unsafe extern "C" fn invoke_with_args_fn_wrapper( udf: &FFI_ScalarUDF, args: RVec, + arg_fields: RVec, number_rows: usize, - return_type: WrappedSchema, + return_field: WrappedSchema, ) -> RResult { let private_data = udf.private_data as *const ScalarUDFPrivateData; let udf = &(*private_data).udf; @@ -189,12 +189,23 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( .collect::>(); let args = rresult_return!(args); - let return_type = rresult_return!(DataType::try_from(&return_type.0)); + let return_field = rresult_return!(Field::try_from(&return_field.0)).into(); + + let arg_fields = arg_fields + .into_iter() + .map(|wrapped_field| { + Field::try_from(&wrapped_field.0) + .map(Arc::new) + .map_err(DataFusionError::from) + }) + .collect::>>(); + let arg_fields = rresult_return!(arg_fields); let args = ScalarFunctionArgs { args, + arg_fields, number_rows, - return_type: &return_type, + return_field, }; let result = rresult_return!(udf @@ -243,7 +254,7 @@ impl From> for FFI_ScalarUDF { short_circuits, invoke_with_args: invoke_with_args_fn_wrapper, return_type: return_type_fn_wrapper, - return_type_from_args: return_type_from_args_fn_wrapper, + return_field_from_args: return_field_from_args_fn_wrapper, coerce_types: coerce_types_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, @@ -316,21 +327,26 @@ impl ScalarUDFImpl for ForeignScalarUDF { result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let args: FFI_ReturnTypeArgs = args.try_into()?; + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let args: FFI_ReturnFieldArgs = args.try_into()?; - let result = unsafe { (self.udf.return_type_from_args)(&self.udf, args) }; + let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) }; let result = df_result!(result); - result.and_then(|r| r.try_into()) + result.and_then(|r| { + Field::try_from(&r.0) + .map(Arc::new) + .map_err(DataFusionError::from) + }) } fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, + arg_fields, number_rows, - return_type, + return_field, } = invoke_args; let args = args @@ -347,10 +363,27 @@ impl ScalarUDFImpl 
for ForeignScalarUDF { .collect::, ArrowError>>()? .into(); - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(return_type)?); + let arg_fields_wrapped = arg_fields + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, ArrowError>>()?; + + let arg_fields = arg_fields_wrapped + .into_iter() + .map(WrappedSchema) + .collect::>(); + + let return_field = return_field.as_ref().clone(); + let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?); let result = unsafe { - (self.udf.invoke_with_args)(&self.udf, args, number_rows, return_type) + (self.udf.invoke_with_args)( + &self.udf, + args, + arg_fields, + number_rows, + return_field, + ) }; let result = df_result!(result)?; @@ -389,7 +422,7 @@ mod tests { let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?; - assert!(original_udf.name() == foreign_udf.name()); + assert_eq!(original_udf.name(), foreign_udf.name()); Ok(()) } diff --git a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index a0897630e2ea..c437c9537be6 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -19,33 +19,30 @@ use abi_stable::{ std_types::{ROption, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use arrow_schema::FieldRef; use datafusion::{ - common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnTypeArgs, + common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, scalar::ScalarValue, }; -use crate::{ - arrow_wrappers::WrappedSchema, - util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, -}; +use crate::arrow_wrappers::WrappedSchema; +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; use prost::Message; -/// A stable struct for sharing a [`ReturnTypeArgs`] across FFI boundaries. +/// A stable struct for sharing a [`ReturnFieldArgs`] across FFI boundaries. 
#[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] -pub struct FFI_ReturnTypeArgs { - arg_types: RVec, +pub struct FFI_ReturnFieldArgs { + arg_fields: RVec, scalar_arguments: RVec>>, - nullables: RVec, } -impl TryFrom> for FFI_ReturnTypeArgs { +impl TryFrom> for FFI_ReturnFieldArgs { type Error = DataFusionError; - fn try_from(value: ReturnTypeArgs) -> Result { - let arg_types = vec_datatype_to_rvec_wrapped(value.arg_types)?; + fn try_from(value: ReturnFieldArgs) -> Result { + let arg_fields = vec_fieldref_to_rvec_wrapped(value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments .iter() @@ -62,35 +59,31 @@ impl TryFrom> for FFI_ReturnTypeArgs { .collect(); let scalar_arguments = scalar_arguments?.into_iter().map(ROption::from).collect(); - let nullables = value.nullables.into(); Ok(Self { - arg_types, + arg_fields, scalar_arguments, - nullables, }) } } // TODO(tsaucer) It would be good to find a better way around this, but it // appears a restriction based on the need to have a borrowed ScalarValue -// in the arguments when converted to ReturnTypeArgs -pub struct ForeignReturnTypeArgsOwned { - arg_types: Vec, +// in the arguments when converted to ReturnFieldArgs +pub struct ForeignReturnFieldArgsOwned { + arg_fields: Vec, scalar_arguments: Vec>, - nullables: Vec, } -pub struct ForeignReturnTypeArgs<'a> { - arg_types: &'a [DataType], +pub struct ForeignReturnFieldArgs<'a> { + arg_fields: &'a [FieldRef], scalar_arguments: Vec>, - nullables: &'a [bool], } -impl TryFrom<&FFI_ReturnTypeArgs> for ForeignReturnTypeArgsOwned { +impl TryFrom<&FFI_ReturnFieldArgs> for ForeignReturnFieldArgsOwned { type Error = DataFusionError; - fn try_from(value: &FFI_ReturnTypeArgs) -> Result { - let arg_types = rvec_wrapped_to_vec_datatype(&value.arg_types)?; + fn try_from(value: &FFI_ReturnFieldArgs) -> Result { + let arg_fields = rvec_wrapped_to_vec_fieldref(&value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments .iter() @@ -107,36 +100,31 @@ impl TryFrom<&FFI_ReturnTypeArgs> for ForeignReturnTypeArgsOwned { .collect(); let scalar_arguments = scalar_arguments?.into_iter().collect(); - let nullables = value.nullables.iter().cloned().collect(); - Ok(Self { - arg_types, + arg_fields, scalar_arguments, - nullables, }) } } -impl<'a> From<&'a ForeignReturnTypeArgsOwned> for ForeignReturnTypeArgs<'a> { - fn from(value: &'a ForeignReturnTypeArgsOwned) -> Self { +impl<'a> From<&'a ForeignReturnFieldArgsOwned> for ForeignReturnFieldArgs<'a> { + fn from(value: &'a ForeignReturnFieldArgsOwned) -> Self { Self { - arg_types: &value.arg_types, + arg_fields: &value.arg_fields, scalar_arguments: value .scalar_arguments .iter() .map(|opt| opt.as_ref()) .collect(), - nullables: &value.nullables, } } } -impl<'a> From<&'a ForeignReturnTypeArgs<'a>> for ReturnTypeArgs<'a> { - fn from(value: &'a ForeignReturnTypeArgs) -> Self { - ReturnTypeArgs { - arg_types: value.arg_types, +impl<'a> From<&'a ForeignReturnFieldArgs<'a>> for ReturnFieldArgs<'a> { + fn from(value: &'a ForeignReturnFieldArgs) -> Self { + ReturnFieldArgs { + arg_fields: value.arg_fields, scalar_arguments: &value.scalar_arguments, - nullables: value.nullables, } } } diff --git a/datafusion/ffi/src/udtf.rs b/datafusion/ffi/src/udtf.rs index 1e06247546be..ceedec2599a2 100644 --- a/datafusion/ffi/src/udtf.rs +++ b/datafusion/ffi/src/udtf.rs @@ -214,7 +214,7 @@ mod tests { let args = args .iter() .map(|arg| { - if let Expr::Literal(scalar) = arg { + if let Expr::Literal(scalar, _) = arg { 
Ok(scalar) } else { exec_err!("Expected only literal arguments to table udf") @@ -243,21 +243,21 @@ mod tests { ScalarValue::Utf8(s) => { let s_vec = vec![s.to_owned(); num_rows]; ( - Field::new(format!("field-{}", idx), DataType::Utf8, true), + Field::new(format!("field-{idx}"), DataType::Utf8, true), Arc::new(StringArray::from(s_vec)) as ArrayRef, ) } ScalarValue::UInt64(v) => { let v_vec = vec![v.to_owned(); num_rows]; ( - Field::new(format!("field-{}", idx), DataType::UInt64, true), + Field::new(format!("field-{idx}"), DataType::UInt64, true), Arc::new(UInt64Array::from(v_vec)) as ArrayRef, ) } ScalarValue::Float64(v) => { let v_vec = vec![v.to_owned(); num_rows]; ( - Field::new(format!("field-{}", idx), DataType::Float64, true), + Field::new(format!("field-{idx}"), DataType::Float64, true), Arc::new(Float64Array::from(v_vec)) as ArrayRef, ) } diff --git a/datafusion/ffi/src/udwf/mod.rs b/datafusion/ffi/src/udwf/mod.rs new file mode 100644 index 000000000000..504bf7a411f1 --- /dev/null +++ b/datafusion/ffi/src/udwf/mod.rs @@ -0,0 +1,432 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::datatypes::Schema; +use arrow::{ + compute::SortOptions, + datatypes::{DataType, SchemaRef}, +}; +use arrow_schema::{Field, FieldRef}; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::WindowUDFFieldArgs, type_coercion::functions::fields_with_window_udf, + PartitionEvaluator, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{Signature, WindowUDF, WindowUDFImpl}, +}; +use partition_evaluator::{FFI_PartitionEvaluator, ForeignPartitionEvaluator}; +use partition_evaluator_args::{ + FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs, +}; +mod partition_evaluator; +mod partition_evaluator_args; +mod range; + +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; +use crate::{ + arrow_wrappers::WrappedSchema, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; + +/// A stable struct for sharing a [`WindowUDF`] across FFI boundaries. 
+#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_WindowUDF { + /// FFI equivalent to the `name` of a [`WindowUDF`] + pub name: RString, + + /// FFI equivalent to the `aliases` of a [`WindowUDF`] + pub aliases: RVec<RString>, + + /// FFI equivalent to the `volatility` of a [`WindowUDF`] + pub volatility: FFI_Volatility, + + pub partition_evaluator: + unsafe extern "C" fn( + udwf: &Self, + args: FFI_PartitionEvaluatorArgs, + ) -> RResult<FFI_PartitionEvaluator, RString>, + + pub field: unsafe extern "C" fn( + udwf: &Self, + input_types: RVec<WrappedSchema>, + display_name: RString, + ) -> RResult<WrappedSchema, RString>, + + /// Performs type coercion. To simplify this interface, all UDFs are treated as having + /// user defined signatures, which in turn causes coerce_types to be called. This + /// call should be transparent to most users as the internal function performs the + /// appropriate calls on the underlying [`WindowUDF`] + pub coerce_types: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec<WrappedSchema>, + ) -> RResult<RVec<WrappedSchema>, RString>, + + pub sort_options: ROption<FFI_SortOptions>, + + /// Used to create a clone on the provider of the udf. This should + /// only need to be called by the receiver of the udf. + pub clone: unsafe extern "C" fn(udf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udf. + /// A [`ForeignWindowUDF`] should never attempt to access this data. + pub private_data: *mut c_void, +} +
+unsafe impl Send for FFI_WindowUDF {} +unsafe impl Sync for FFI_WindowUDF {} +
+pub struct WindowUDFPrivateData { + pub udf: Arc<WindowUDF>, +} +
+impl FFI_WindowUDF { + unsafe fn inner(&self) -> &Arc<WindowUDF> { + let private_data = self.private_data as *const WindowUDFPrivateData; + &(*private_data).udf + } +} +
+unsafe extern "C" fn partition_evaluator_fn_wrapper( + udwf: &FFI_WindowUDF, + args: FFI_PartitionEvaluatorArgs, +) -> RResult<FFI_PartitionEvaluator, RString> { + let inner = udwf.inner(); + + let args = rresult_return!(ForeignPartitionEvaluatorArgs::try_from(args)); + + let evaluator = rresult_return!(inner.partition_evaluator_factory((&args).into())); + + RResult::ROk(evaluator.into()) +} +
+unsafe extern "C" fn field_fn_wrapper( + udwf: &FFI_WindowUDF, + input_fields: RVec<WrappedSchema>, + display_name: RString, +) -> RResult<WrappedSchema, RString> { + let inner = udwf.inner(); + + let input_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields)); + + let field = rresult_return!(inner.field(WindowUDFFieldArgs::new( + &input_fields, + display_name.as_str() + ))); + + let schema = Arc::new(Schema::new(vec![field])); + + RResult::ROk(WrappedSchema::from(schema)) +} +
+unsafe extern "C" fn coerce_types_fn_wrapper( + udwf: &FFI_WindowUDF, + arg_types: RVec<WrappedSchema>, +) -> RResult<RVec<WrappedSchema>, RString> { + let inner = udwf.inner(); + + let arg_fields = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)) + .into_iter() + .map(|dt| Field::new("f", dt, false)) + .map(Arc::new) + .collect::<Vec<_>>(); + + let return_fields = rresult_return!(fields_with_window_udf(&arg_fields, inner)); + let return_types = return_fields + .into_iter() + .map(|f| f.data_type().to_owned()) + .collect::<Vec<_>>(); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} +
+unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) { + let private_data = Box::from_raw(udwf.private_data as *mut WindowUDFPrivateData); + drop(private_data); +} +
+unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF { + let private_data = Box::new(WindowUDFPrivateData { + udf: Arc::clone(udwf.inner()), + }); + + FFI_WindowUDF { + name: udwf.name.clone(), + aliases: udwf.aliases.clone(), + volatility: udwf.volatility.clone(), + partition_evaluator: partition_evaluator_fn_wrapper, + sort_options: udwf.sort_options.clone(), + coerce_types: coerce_types_fn_wrapper, + field: field_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } +} +
+impl Clone for FFI_WindowUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} +
+impl From<Arc<WindowUDF>> for FFI_WindowUDF { + fn from(udf: Arc<WindowUDF>) -> Self { + let name = udf.name().into(); + let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect(); + let volatility = udf.signature().volatility.into(); + let sort_options = udf.sort_options().map(|v| (&v).into()).into(); + + let private_data = Box::new(WindowUDFPrivateData { udf }); + + Self { + name, + aliases, + volatility, + partition_evaluator: partition_evaluator_fn_wrapper, + sort_options, + coerce_types: coerce_types_fn_wrapper, + field: field_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} +
+impl Drop for FFI_WindowUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} +
+/// This struct is used to access a UDF provided by a foreign +/// library across an FFI boundary. +/// +/// The ForeignWindowUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_WindowUDF. +#[derive(Debug)] +pub struct ForeignWindowUDF { + name: String, + aliases: Vec<String>, + udf: FFI_WindowUDF, + signature: Signature, +} +
+unsafe impl Send for ForeignWindowUDF {} +unsafe impl Sync for ForeignWindowUDF {} +
+impl TryFrom<&FFI_WindowUDF> for ForeignWindowUDF { + type Error = DataFusionError; + + fn try_from(udf: &FFI_WindowUDF) -> Result<Self, Self::Error> { + let name = udf.name.to_owned().into(); + let signature = Signature::user_defined((&udf.volatility).into()); + + let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Self { + name, + udf: udf.clone(), + aliases, + signature, + }) + } +} +
+impl WindowUDFImpl for ForeignWindowUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
+ } + } + + fn partition_evaluator( + &self, + args: datafusion::logical_expr::function::PartitionEvaluatorArgs, + ) -> Result> { + let evaluator = unsafe { + let args = FFI_PartitionEvaluatorArgs::try_from(args)?; + (self.udf.partition_evaluator)(&self.udf, args) + }; + + df_result!(evaluator).map(|evaluator| { + Box::new(ForeignPartitionEvaluator::from(evaluator)) + as Box + }) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + unsafe { + let input_types = vec_fieldref_to_rvec_wrapped(field_args.input_fields())?; + let schema = df_result!((self.udf.field)( + &self.udf, + input_types, + field_args.name().into() + ))?; + let schema: SchemaRef = schema.into(); + + match schema.fields().is_empty() { + true => Err(DataFusionError::Execution( + "Unable to retrieve field in WindowUDF via FFI".to_string(), + )), + false => Ok(schema.field(0).to_owned().into()), + } + } + } + + fn sort_options(&self) -> Option { + let options: Option<&FFI_SortOptions> = self.udf.sort_options.as_ref().into(); + options.map(|s| s.into()) + } +} + +#[repr(C)] +#[derive(Debug, StableAbi, Clone)] +#[allow(non_camel_case_types)] +pub struct FFI_SortOptions { + pub descending: bool, + pub nulls_first: bool, +} + +impl From<&SortOptions> for FFI_SortOptions { + fn from(value: &SortOptions) -> Self { + Self { + descending: value.descending, + nulls_first: value.nulls_first, + } + } +} + +impl From<&FFI_SortOptions> for SortOptions { + fn from(value: &FFI_SortOptions) -> Self { + Self { + descending: value.descending, + nulls_first: value.nulls_first, + } + } +} + +#[cfg(test)] +#[cfg(feature = "integration-tests")] +mod tests { + use crate::tests::create_record_batch; + use crate::udwf::{FFI_WindowUDF, ForeignWindowUDF}; + use arrow::array::{create_array, ArrayRef}; + use datafusion::functions_window::lead_lag::{lag_udwf, WindowShift}; + use datafusion::logical_expr::expr::Sort; + use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF, WindowUDFImpl}; + use datafusion::prelude::SessionContext; + use std::sync::Arc; + + fn create_test_foreign_udwf( + original_udwf: impl WindowUDFImpl + 'static, + ) -> datafusion::common::Result { + let original_udwf = Arc::new(WindowUDF::from(original_udwf)); + + let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into(); + + let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?; + Ok(foreign_udwf.into()) + } + + #[test] + fn test_round_trip_udwf() -> datafusion::common::Result<()> { + let original_udwf = lag_udwf(); + let original_name = original_udwf.name().to_owned(); + + // Convert to FFI format + let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into(); + + // Convert back to native format + let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?; + let foreign_udwf: WindowUDF = foreign_udwf.into(); + + assert_eq!(original_name, foreign_udwf.name()); + Ok(()) + } + + #[tokio::test] + async fn test_lag_udwf() -> datafusion::common::Result<()> { + let udwf = create_test_foreign_udwf(WindowShift::lag())?; + + let ctx = SessionContext::default(); + let df = ctx.read_batch(create_record_batch(-5, 5))?; + + let df = df.select(vec![ + col("a"), + udwf.call(vec![col("a")]) + .order_by(vec![Sort::new(col("a"), true, true)]) + .build() + .unwrap() + .alias("lag_a"), + ])?; + + df.clone().show().await?; + + let result = df.collect().await?; + let expected = + create_array!(Int32, [None, Some(-5), Some(-4), Some(-3), Some(-2)]) + as ArrayRef; + + assert_eq!(result.len(), 1); + assert_eq!(result[0].column(1), &expected); + + Ok(()) + } +} 
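For orientation, a minimal sketch (not part of the diff) of how the new `udwf` module above is expected to be consumed: a foreign library hands over an `FFI_WindowUDF`, and the caller rebuilds a native `WindowUDF` from it. The helper name `register_foreign_udwf` is hypothetical; it assumes `SessionContext::register_udwf` from the main datafusion crate.

use datafusion::error::Result;
use datafusion::logical_expr::WindowUDF;
use datafusion::prelude::SessionContext;
use datafusion_ffi::udwf::{FFI_WindowUDF, ForeignWindowUDF};

// Rebuild a native `WindowUDF` from the FFI struct provided by a foreign
// library and register it with a context so it can be used in queries.
fn register_foreign_udwf(ctx: &SessionContext, ffi_udwf: &FFI_WindowUDF) -> Result<()> {
    // Only the stable function pointers inside FFI_WindowUDF are used from here on.
    let foreign: ForeignWindowUDF = ffi_udwf.try_into()?;
    ctx.register_udwf(WindowUDF::from(foreign));
    Ok(())
}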
diff --git a/datafusion/ffi/src/udwf/partition_evaluator.rs b/datafusion/ffi/src/udwf/partition_evaluator.rs new file mode 100644 index 000000000000..995d00cce30e --- /dev/null +++ b/datafusion/ffi/src/udwf/partition_evaluator.rs @@ -0,0 +1,320 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, ops::Range}; + +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::{array::ArrayRef, error::ArrowError}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::{window_state::WindowAggState, PartitionEvaluator}, + scalar::ScalarValue, +}; +use prost::Message; + +use super::range::FFI_Range; + +/// A stable struct for sharing [`PartitionEvaluator`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`PartitionEvaluator`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_PartitionEvaluator { + pub evaluate_all: unsafe extern "C" fn( + evaluator: &mut Self, + values: RVec, + num_rows: usize, + ) -> RResult, + + pub evaluate: unsafe extern "C" fn( + evaluator: &mut Self, + values: RVec, + range: FFI_Range, + ) -> RResult, RString>, + + pub evaluate_all_with_rank: unsafe extern "C" fn( + evaluator: &Self, + num_rows: usize, + ranks_in_partition: RVec, + ) + -> RResult, + + pub get_range: unsafe extern "C" fn( + evaluator: &Self, + idx: usize, + n_rows: usize, + ) -> RResult, + + pub is_causal: bool, + + pub supports_bounded_execution: bool, + pub uses_window_frame: bool, + pub include_rank: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(evaluator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the evaluator. + /// A [`ForeignPartitionEvaluator`] should never attempt to access this data. 
+ pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_PartitionEvaluator {} +unsafe impl Sync for FFI_PartitionEvaluator {} + +pub struct PartitionEvaluatorPrivateData { + pub evaluator: Box, +} + +impl FFI_PartitionEvaluator { + unsafe fn inner_mut(&mut self) -> &mut Box<(dyn PartitionEvaluator + 'static)> { + let private_data = self.private_data as *mut PartitionEvaluatorPrivateData; + &mut (*private_data).evaluator + } + + unsafe fn inner(&self) -> &(dyn PartitionEvaluator + 'static) { + let private_data = self.private_data as *mut PartitionEvaluatorPrivateData; + (*private_data).evaluator.as_ref() + } +} + +unsafe extern "C" fn evaluate_all_fn_wrapper( + evaluator: &mut FFI_PartitionEvaluator, + values: RVec, + num_rows: usize, +) -> RResult { + let inner = evaluator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + let return_array = inner + .evaluate_all(&values_arrays, num_rows) + .and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from)); + + rresult!(return_array) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + evaluator: &mut FFI_PartitionEvaluator, + values: RVec, + range: FFI_Range, +) -> RResult, RString> { + let inner = evaluator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + // let return_array = (inner.evaluate(&values_arrays, &range.into())); + // .and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from)); + let scalar_result = rresult_return!(inner.evaluate(&values_arrays, &range.into())); + let proto_result: datafusion_proto::protobuf::ScalarValue = + rresult_return!((&scalar_result).try_into()); + + RResult::ROk(proto_result.encode_to_vec().into()) +} + +unsafe extern "C" fn evaluate_all_with_rank_fn_wrapper( + evaluator: &FFI_PartitionEvaluator, + num_rows: usize, + ranks_in_partition: RVec, +) -> RResult { + let inner = evaluator.inner(); + + let ranks_in_partition = ranks_in_partition + .into_iter() + .map(Range::from) + .collect::>(); + + let return_array = inner + .evaluate_all_with_rank(num_rows, &ranks_in_partition) + .and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from)); + + rresult!(return_array) +} + +unsafe extern "C" fn get_range_fn_wrapper( + evaluator: &FFI_PartitionEvaluator, + idx: usize, + n_rows: usize, +) -> RResult { + let inner = evaluator.inner(); + let range = inner.get_range(idx, n_rows).map(FFI_Range::from); + + rresult!(range) +} + +unsafe extern "C" fn release_fn_wrapper(evaluator: &mut FFI_PartitionEvaluator) { + let private_data = + Box::from_raw(evaluator.private_data as *mut PartitionEvaluatorPrivateData); + drop(private_data); +} + +impl From> for FFI_PartitionEvaluator { + fn from(evaluator: Box) -> Self { + let is_causal = evaluator.is_causal(); + let supports_bounded_execution = evaluator.supports_bounded_execution(); + let include_rank = evaluator.include_rank(); + let uses_window_frame = evaluator.uses_window_frame(); + + let private_data = PartitionEvaluatorPrivateData { evaluator }; + + Self { + evaluate: evaluate_fn_wrapper, + evaluate_all: evaluate_all_fn_wrapper, + evaluate_all_with_rank: evaluate_all_with_rank_fn_wrapper, + get_range: get_range_fn_wrapper, + is_causal, + supports_bounded_execution, + include_rank, + uses_window_frame, + release: release_fn_wrapper, + 
private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } + } +} +
+impl Drop for FFI_PartitionEvaluator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} +
+/// This struct is used to access a PartitionEvaluator provided by a foreign +/// library across an FFI boundary. +/// +/// The ForeignPartitionEvaluator is to be used by the caller of the evaluator, so it has +/// no knowledge or access to the private data. All interaction with the evaluator +/// must occur through the functions defined in FFI_PartitionEvaluator. +#[derive(Debug)] +pub struct ForeignPartitionEvaluator { + evaluator: FFI_PartitionEvaluator, +} +
+unsafe impl Send for ForeignPartitionEvaluator {} +unsafe impl Sync for ForeignPartitionEvaluator {} +
+impl From<FFI_PartitionEvaluator> for ForeignPartitionEvaluator { + fn from(evaluator: FFI_PartitionEvaluator) -> Self { + Self { evaluator } + } +} +
+impl PartitionEvaluator for ForeignPartitionEvaluator { + fn memoize(&mut self, _state: &mut WindowAggState) -> Result<()> { + // Exposing `memoize` increases the surface area of the FFI work, + // so for now we do not support it. + Ok(()) + } + + fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> { + let range = unsafe { (self.evaluator.get_range)(&self.evaluator, idx, n_rows) }; + df_result!(range).map(Range::from) + } + + /// Get whether the evaluator needs future data for its result (if so, returns `false`) + fn is_causal(&self) -> bool { + self.evaluator.is_causal + } + + fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> { + let result = unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::<Result<Vec<_>, ArrowError>>()?; + (self.evaluator.evaluate_all)(&mut self.evaluator, values, num_rows) + }; + + let array = df_result!(result)?; + + Ok(array.try_into()?) + } + + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &Range<usize>, + ) -> Result<ScalarValue> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::<Result<Vec<_>, ArrowError>>()?; + + let scalar_bytes = df_result!((self.evaluator.evaluate)( + &mut self.evaluator, + values, + range.to_owned().into() + ))?; + + let proto_scalar = + datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from) + } + } + + fn evaluate_all_with_rank( + &self, + num_rows: usize, + ranks_in_partition: &[Range<usize>], + ) -> Result<ArrayRef> { + let result = unsafe { + let ranks_in_partition = ranks_in_partition + .iter() + .map(|rank| FFI_Range::from(rank.to_owned())) + .collect(); + (self.evaluator.evaluate_all_with_rank)( + &self.evaluator, + num_rows, + ranks_in_partition, + ) + }; + + let array = df_result!(result)?; + + Ok(array.try_into()?) + } + + fn supports_bounded_execution(&self) -> bool { + self.evaluator.supports_bounded_execution + } + + fn uses_window_frame(&self) -> bool { + self.evaluator.uses_window_frame + } + + fn include_rank(&self) -> bool { + self.evaluator.include_rank + } +} +
+#[cfg(test)] +mod tests {} diff --git a/datafusion/ffi/src/udwf/partition_evaluator_args.rs b/datafusion/ffi/src/udwf/partition_evaluator_args.rs new file mode 100644 index 000000000000..dffeb23741b6 --- /dev/null +++ b/datafusion/ffi/src/udwf/partition_evaluator_args.rs @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{collections::HashMap, sync::Arc}; + +use crate::arrow_wrappers::WrappedSchema; +use abi_stable::{std_types::RVec, StableAbi}; +use arrow::{ + datatypes::{DataType, Field, Schema, SchemaRef}, + error::ArrowError, + ffi::FFI_ArrowSchema, +}; +use arrow_schema::FieldRef; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::function::PartitionEvaluatorArgs, + physical_plan::{expressions::Column, PhysicalExpr}, + prelude::SessionContext, +}; +use datafusion_proto::{ + physical_plan::{ + from_proto::parse_physical_expr, to_proto::serialize_physical_exprs, + DefaultPhysicalExtensionCodec, + }, + protobuf::PhysicalExprNode, +}; +use prost::Message; + +/// A stable struct for sharing [`PartitionEvaluatorArgs`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`PartitionEvaluatorArgs`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_PartitionEvaluatorArgs { + input_exprs: RVec>, + input_fields: RVec, + is_reversed: bool, + ignore_nulls: bool, + schema: WrappedSchema, +} + +impl TryFrom> for FFI_PartitionEvaluatorArgs { + type Error = DataFusionError; + fn try_from(args: PartitionEvaluatorArgs) -> Result { + // This is a bit of a hack. Since PartitionEvaluatorArgs does not carry a schema + // around, and instead passes the data types directly we are unable to decode the + // protobuf PhysicalExpr correctly. In evaluating the code the only place these + // appear to be really used are the Column data types. So here we will find all + // of the required columns and create a schema that has empty fields except for + // the ones we require. Ideally we would enhance PartitionEvaluatorArgs to just + // pass along the schema, but that is a larger breaking change. + let required_columns: HashMap = args + .input_exprs() + .iter() + .zip(args.input_fields()) + .filter_map(|(expr, field)| { + expr.as_any() + .downcast_ref::() + .map(|column| (column.index(), (column.name(), field.data_type()))) + }) + .collect(); + + let max_column = required_columns.keys().max(); + let fields: Vec<_> = max_column + .map(|max_column| { + (0..(max_column + 1)) + .map(|idx| match required_columns.get(&idx) { + Some((name, data_type)) => { + Field::new(*name, (*data_type).clone(), true) + } + None => Field::new( + format!("ffi_partition_evaluator_col_{idx}"), + DataType::Null, + true, + ), + }) + .collect() + }) + .unwrap_or_default(); + + let schema = Arc::new(Schema::new(fields)); + + let codec = DefaultPhysicalExtensionCodec {}; + let input_exprs = serialize_physical_exprs(args.input_exprs(), &codec)? + .into_iter() + .map(|expr_node| expr_node.encode_to_vec().into()) + .collect(); + + let input_fields = args + .input_fields() + .iter() + .map(|input_type| FFI_ArrowSchema::try_from(input_type).map(WrappedSchema)) + .collect::, ArrowError>>()? 
+ .into(); + + let schema: WrappedSchema = schema.into(); + + Ok(Self { + input_exprs, + input_fields, + schema, + is_reversed: args.is_reversed(), + ignore_nulls: args.ignore_nulls(), + }) + } +} + +/// This struct mirrors PartitionEvaluatorArgs except that it contains owned data. +/// It is necessary to create this struct so that we can parse the protobuf +/// data across the FFI boundary and turn it into owned data that +/// PartitionEvaluatorArgs can then reference. +pub struct ForeignPartitionEvaluatorArgs { + input_exprs: Vec>, + input_fields: Vec, + is_reversed: bool, + ignore_nulls: bool, +} + +impl TryFrom for ForeignPartitionEvaluatorArgs { + type Error = DataFusionError; + + fn try_from(value: FFI_PartitionEvaluatorArgs) -> Result { + let default_ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + + let schema: SchemaRef = value.schema.into(); + + let input_exprs = value + .input_exprs + .into_iter() + .map(|input_expr_bytes| PhysicalExprNode::decode(input_expr_bytes.as_ref())) + .collect::, prost::DecodeError>>() + .map_err(|e| DataFusionError::Execution(e.to_string()))? + .iter() + .map(|expr_node| { + parse_physical_expr(expr_node, &default_ctx, &schema, &codec) + }) + .collect::>>()?; + + let input_fields = input_exprs + .iter() + .map(|expr| expr.return_field(&schema)) + .collect::>>()?; + + Ok(Self { + input_exprs, + input_fields, + is_reversed: value.is_reversed, + ignore_nulls: value.ignore_nulls, + }) + } +} + +impl<'a> From<&'a ForeignPartitionEvaluatorArgs> for PartitionEvaluatorArgs<'a> { + fn from(value: &'a ForeignPartitionEvaluatorArgs) -> Self { + PartitionEvaluatorArgs::new( + &value.input_exprs, + &value.input_fields, + value.is_reversed, + value.ignore_nulls, + ) + } +} + +#[cfg(test)] +mod tests {} diff --git a/datafusion/ffi/src/udf/return_info.rs b/datafusion/ffi/src/udwf/range.rs similarity index 50% rename from datafusion/ffi/src/udf/return_info.rs rename to datafusion/ffi/src/udwf/range.rs index cf76ddd1db76..1ddcc4199fe2 100644 --- a/datafusion/ffi/src/udf/return_info.rs +++ b/datafusion/ffi/src/udwf/range.rs @@ -15,39 +15,50 @@ // specific language governing permissions and limitations // under the License. -use abi_stable::StableAbi; -use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; -use datafusion::{error::DataFusionError, logical_expr::ReturnInfo}; +use std::ops::Range; -use crate::arrow_wrappers::WrappedSchema; +use abi_stable::StableAbi; -/// A stable struct for sharing a [`ReturnInfo`] across FFI boundaries. +/// A stable struct for sharing [`Range`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`Range`]. 
#[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] -pub struct FFI_ReturnInfo { - return_type: WrappedSchema, - nullable: bool, +pub struct FFI_Range { + pub start: usize, + pub end: usize, } -impl TryFrom for FFI_ReturnInfo { - type Error = DataFusionError; +impl From> for FFI_Range { + fn from(value: Range) -> Self { + Self { + start: value.start, + end: value.end, + } + } +} - fn try_from(value: ReturnInfo) -> Result { - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(value.return_type())?); - Ok(Self { - return_type, - nullable: value.nullable(), - }) +impl From for Range { + fn from(value: FFI_Range) -> Self { + Self { + start: value.start, + end: value.end, + } } } -impl TryFrom for ReturnInfo { - type Error = DataFusionError; +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_round_trip_ffi_range() { + let original = Range { start: 10, end: 30 }; - fn try_from(value: FFI_ReturnInfo) -> Result { - let return_type = DataType::try_from(&value.return_type.0)?; + let ffi_range: FFI_Range = original.clone().into(); + let round_trip: Range = ffi_range.into(); - Ok(ReturnInfo::new(return_type, value.nullable)) + assert_eq!(original, round_trip); } } diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index 9d5f2aefe324..abe369c57298 100644 --- a/datafusion/ffi/src/util.rs +++ b/datafusion/ffi/src/util.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::arrow_wrappers::WrappedSchema; use abi_stable::std_types::RVec; +use arrow::datatypes::Field; use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; +use arrow_schema::FieldRef; +use std::sync::Arc; -use crate::arrow_wrappers::WrappedSchema; - -/// This macro is a helpful conversion utility to conver from an abi_stable::RResult to a +/// This macro is a helpful conversion utility to convert from an abi_stable::RResult to a /// DataFusion result. #[macro_export] macro_rules! df_result { @@ -64,6 +66,31 @@ macro_rules! rresult_return { }; } +/// This is a utility function to convert a slice of [`Field`] to its equivalent +/// FFI friendly counterpart, [`WrappedSchema`] +pub fn vec_fieldref_to_rvec_wrapped( + fields: &[FieldRef], +) -> Result, arrow::error::ArrowError> { + Ok(fields + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, arrow::error::ArrowError>>()? + .into_iter() + .map(WrappedSchema) + .collect()) +} + +/// This is a utility function to convert an FFI friendly vector of [`WrappedSchema`] +/// to their equivalent [`Field`]. 
+pub fn rvec_wrapped_to_vec_fieldref( + fields: &RVec, +) -> Result, arrow::error::ArrowError> { + fields + .iter() + .map(|d| Field::try_from(&d.0).map(Arc::new)) + .collect() +} + /// This is a utility function to convert a slice of [`DataType`] to its equivalent /// FFI friendly counterpart, [`WrappedSchema`] pub fn vec_datatype_to_rvec_wrapped( @@ -116,7 +143,7 @@ mod tests { assert!(returned_err_result.is_err()); assert!( returned_err_result.unwrap_err().to_string() - == format!("Execution error: {}", ERROR_VALUE) + == format!("Execution error: {ERROR_VALUE}") ); let ok_result: Result = Ok(VALID_VALUE.to_string()); @@ -129,7 +156,7 @@ mod tests { let returned_err_r_result = wrap_result(err_result); assert!( returned_err_r_result - == RResult::RErr(format!("Execution error: {}", ERROR_VALUE).into()) + == RResult::RErr(format!("Execution error: {ERROR_VALUE}").into()) ); } } diff --git a/datafusion/ffi/src/volatility.rs b/datafusion/ffi/src/volatility.rs index 0aaf68a174cf..f1705da294a3 100644 --- a/datafusion/ffi/src/volatility.rs +++ b/datafusion/ffi/src/volatility.rs @@ -19,7 +19,7 @@ use abi_stable::StableAbi; use datafusion::logical_expr::Volatility; #[repr(C)] -#[derive(Debug, StableAbi)] +#[derive(Debug, StableAbi, Clone)] #[allow(non_camel_case_types)] pub enum FFI_Volatility { Immutable, diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index c6df324e9a17..1ef16fbaa4d8 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -19,7 +19,6 @@ /// when the feature integtation-tests is built #[cfg(feature = "integration-tests")] mod tests { - use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::SessionContext; use datafusion_ffi::catalog_provider::ForeignCatalogProvider; diff --git a/datafusion/ffi/tests/ffi_udaf.rs b/datafusion/ffi/tests/ffi_udaf.rs new file mode 100644 index 000000000000..31b1f473913c --- /dev/null +++ b/datafusion/ffi/tests/ffi_udaf.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +/// Add an additional module here for convenience to scope this to only +/// when the feature integration-tests is built +#[cfg(feature = "integration-tests")] +mod tests { + use arrow::array::Float64Array; + use datafusion::common::record_batch; + use datafusion::error::{DataFusionError, Result}; + use datafusion::logical_expr::AggregateUDF; + use datafusion::prelude::{col, SessionContext}; + + use datafusion_ffi::tests::utils::get_module; + use datafusion_ffi::udaf::ForeignAggregateUDF; +
+ #[tokio::test] + async fn test_ffi_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_sum_func = + module + .create_sum_udaf() + .ok_or(DataFusionError::NotImplemented( + "External library failed to implement create_sum_udaf".to_string(), + ))?(); + let foreign_sum_func: ForeignAggregateUDF = (&ffi_sum_func).try_into()?; + + let udaf: AggregateUDF = foreign_sum_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ("b", Float64, vec![1.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0]) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("sum_b")], + )? + .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + + let expected = record_batch!( + ("a", Int32, vec![1, 2, 4]), + ("sum_b", Float64, vec![1.0, 4.0, 16.0]) + )?; + + assert_eq!(result[0], expected); + + Ok(()) + } +
+ #[tokio::test] + async fn test_ffi_grouping_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_stddev_func = + module + .create_stddev_udaf() + .ok_or(DataFusionError::NotImplemented( + "External library failed to implement create_stddev_udaf".to_string(), + ))?(); + let foreign_stddev_func: ForeignAggregateUDF = (&ffi_stddev_func).try_into()?; + + let udaf: AggregateUDF = foreign_stddev_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ( + "b", + Float64, + vec![ + 1.0, + 2.0, + 2.0 + 2.0_f64.sqrt(), + 4.0, + 4.0, + 4.0 + 3.0_f64.sqrt(), + 4.0 + 3.0_f64.sqrt() + ] + ) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("stddev_b")], + )? + .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + let result = result[0].column_by_name("stddev_b").unwrap(); + let result = result + .as_any() + .downcast_ref::<Float64Array>() + .unwrap() + .values(); + + assert!(result.first().unwrap().is_nan()); + assert!((result.get(1).unwrap() - 1.0).abs() < 0.00001); + assert!((result.get(2).unwrap() - 1.0).abs() < 0.00001); + + Ok(()) + } +} diff --git a/datafusion/ffi/tests/ffi_udwf.rs b/datafusion/ffi/tests/ffi_udwf.rs new file mode 100644 index 000000000000..db9ebba0fdfb --- /dev/null +++ b/datafusion/ffi/tests/ffi_udwf.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Add an additional module here for convenience to scope this to only +/// when the feature integtation-tests is built +#[cfg(feature = "integration-tests")] +mod tests { + use arrow::array::{create_array, ArrayRef}; + use datafusion::error::{DataFusionError, Result}; + use datafusion::logical_expr::expr::Sort; + use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF}; + use datafusion::prelude::SessionContext; + use datafusion_ffi::tests::create_record_batch; + use datafusion_ffi::tests::utils::get_module; + use datafusion_ffi::udwf::ForeignWindowUDF; + + #[tokio::test] + async fn test_rank_udwf() -> Result<()> { + let module = get_module()?; + + let ffi_rank_func = + module + .create_rank_udwf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_scalar_udf" + .to_string(), + ))?(); + let foreign_rank_func: ForeignWindowUDF = (&ffi_rank_func).try_into()?; + + let udwf: WindowUDF = foreign_rank_func.into(); + + let ctx = SessionContext::default(); + let df = ctx.read_batch(create_record_batch(-5, 5))?; + + let df = df.select(vec![ + col("a"), + udwf.call(vec![]) + .order_by(vec![Sort::new(col("a"), true, true)]) + .build() + .unwrap() + .alias("rank_a"), + ])?; + + df.clone().show().await?; + + let result = df.collect().await?; + let expected = create_array!(UInt64, [1, 2, 3, 4, 5]) as ArrayRef; + + assert_eq!(result.len(), 1); + assert_eq!(result[0].column(1), &expected); + + Ok(()) + } +} diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs index a230bb028909..01b16f1b0a8c 100644 --- a/datafusion/functions-aggregate-common/src/accumulator.rs +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -27,8 +27,8 @@ use std::sync::Arc; /// ordering expressions. #[derive(Debug)] pub struct AccumulatorArgs<'a> { - /// The return type of the aggregate function. - pub return_type: &'a DataType, + /// The return field of the aggregate function. + pub return_field: FieldRef, /// The schema of the input arguments pub schema: &'a Schema, @@ -71,6 +71,13 @@ pub struct AccumulatorArgs<'a> { pub exprs: &'a [Arc], } +impl AccumulatorArgs<'_> { + /// Returns the return type of the aggregate function. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} + /// Factory that returns an accumulator for the given aggregate function. pub type AccumulatorFactoryFunction = Arc Result> + Send + Sync>; @@ -81,15 +88,22 @@ pub struct StateFieldsArgs<'a> { /// The name of the aggregate function. pub name: &'a str, - /// The input types of the aggregate function. - pub input_types: &'a [DataType], + /// The input fields of the aggregate function. 
+ pub input_fields: &'a [FieldRef], - /// The return type of the aggregate function. - pub return_type: &'a DataType, + /// The return fields of the aggregate function. + pub return_field: FieldRef, /// The ordering fields of the aggregate function. - pub ordering_fields: &'a [Field], + pub ordering_fields: &'a [FieldRef], /// Whether the aggregate function is distinct. pub is_distinct: bool, } + +impl StateFieldsArgs<'_> { + /// The return type of the aggregate function. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs index 7d772f7c649d..25b40382299b 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs @@ -16,9 +16,11 @@ // under the License. mod bytes; +mod dict; mod native; pub use bytes::BytesDistinctCountAccumulator; pub use bytes::BytesViewDistinctCountAccumulator; +pub use dict::DictionaryCountAccumulator; pub use native::FloatDistinctCountAccumulator; pub use native::PrimitiveDistinctCountAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs new file mode 100644 index 000000000000..089d8d5acded --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::downcast_dictionary_array; +use datafusion_common::{arrow_datafusion_err, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError}; +use datafusion_expr_common::accumulator::Accumulator; + +#[derive(Debug)] +pub struct DictionaryCountAccumulator { + inner: Box, +} + +impl DictionaryCountAccumulator { + pub fn new(inner: Box) -> Self { + Self { inner } + } +} + +impl Accumulator for DictionaryCountAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + let values: Vec<_> = values + .iter() + .map(|dict| { + downcast_dictionary_array! 
{ + dict => { + let buff: BooleanArray = dict.occupancy().into(); + arrow::compute::filter( + dict.values(), + &buff + ).map_err(|e| arrow_datafusion_err!(e)) + }, + _ => internal_err!("DictionaryCountAccumulator only supports dictionary arrays") + } + }) + .collect::, _>>()?; + self.inner.update_batch(values.as_slice()) + } + + fn evaluate(&mut self) -> datafusion_common::Result { + self.inner.evaluate() + } + + fn size(&self) -> usize { + self.inner.size() + } + + fn state(&mut self) -> datafusion_common::Result> { + self.inner.state() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + self.inner.merge_batch(states) + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index e629e99e1657..987ba57f7719 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -636,7 +636,7 @@ mod test { #[test] fn accumulate_fuzz() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..100 { Fixture::new_random(&mut rng).run(); } @@ -661,23 +661,23 @@ mod test { impl Fixture { fn new_random(rng: &mut ThreadRng) -> Self { // Number of input values in a batch - let num_values: usize = rng.gen_range(1..200); + let num_values: usize = rng.random_range(1..200); // number of distinct groups - let num_groups: usize = rng.gen_range(2..1000); + let num_groups: usize = rng.random_range(2..1000); let max_group = num_groups - 1; let group_indices: Vec = (0..num_values) - .map(|_| rng.gen_range(0..max_group)) + .map(|_| rng.random_range(0..max_group)) .collect(); - let values: Vec = (0..num_values).map(|_| rng.gen()).collect(); + let values: Vec = (0..num_values).map(|_| rng.random()).collect(); // 10% chance of false // 10% change of null // 80% chance of true let filter: BooleanArray = (0..num_values) .map(|_| { - let filter_value = rng.gen_range(0.0..1.0); + let filter_value = rng.random_range(0.0..1.0); if filter_value < 0.1 { Some(false) } else if filter_value < 0.2 { @@ -690,14 +690,14 @@ mod test { // random values with random number and location of nulls // random null percentage - let null_pct: f32 = rng.gen_range(0.0..1.0); + let null_pct: f32 = rng.random_range(0.0..1.0); let values_with_nulls: Vec> = (0..num_values) .map(|_| { - let is_null = null_pct < rng.gen_range(0.0..1.0); + let is_null = null_pct < rng.random_range(0.0..1.0); if is_null { None } else { - Some(rng.gen()) + Some(rng.random()) } }) .collect(); diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 6a8946034cbc..c8c7736bba14 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -20,7 +20,7 @@ use arrow::array::{ Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray, BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray, - StringViewArray, + StringViewArray, StructArray, }; use arrow::buffer::NullBuffer; use arrow::datatypes::DataType; @@ -193,6 +193,18 @@ pub fn set_nulls_dyn(input: &dyn Array, nulls: Option) -> Result { + let input = input.as_struct(); + // safety: values / offsets came from a valid struct array + 
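// A standalone sketch of the technique above (not part of the patch): the values that a
// dictionary array actually references are recovered by converting `occupancy()` into a
// filter mask and applying it to `values()`; the filtered array is what gets forwarded
// to the wrapped accumulator.
use arrow::array::{ArrayRef, BooleanArray, DictionaryArray};
use arrow::datatypes::Int32Type;

fn referenced_values(dict: &DictionaryArray<Int32Type>) -> arrow::error::Result<ArrayRef> {
    // one bit per entry in `values()`, set when at least one key points at it
    let mask: BooleanArray = dict.occupancy().into();
    arrow::compute::filter(dict.values(), &mask)
}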
// and we checked nulls has the same length as values + unsafe { + Arc::new(StructArray::new_unchecked( + input.fields().clone(), + input.columns().to_vec(), + nulls, + )) + } + } _ => { return not_impl_err!("Applying nulls {:?}", input.data_type()); } diff --git a/datafusion/functions-aggregate-common/src/utils.rs b/datafusion/functions-aggregate-common/src/utils.rs index 083dac615b5d..229d9a900105 100644 --- a/datafusion/functions-aggregate-common/src/utils.rs +++ b/datafusion/functions-aggregate-common/src/utils.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, AsArray}; -use arrow::datatypes::ArrowNativeType; +use arrow::datatypes::{ArrowNativeType, FieldRef}; use arrow::{ array::ArrowNativeTypeOp, compute::SortOptions, @@ -92,7 +92,7 @@ pub fn ordering_fields( ordering_req: &LexOrdering, // Data type of each expression in the ordering requirement data_types: &[DataType], -) -> Vec { +) -> Vec { ordering_req .iter() .zip(data_types.iter()) @@ -104,6 +104,7 @@ pub fn ordering_fields( true, ) }) + .map(Arc::new) .collect() } diff --git a/datafusion/functions-aggregate/benches/array_agg.rs b/datafusion/functions-aggregate/benches/array_agg.rs index e22be611d8d7..6dadb12aba86 100644 --- a/datafusion/functions-aggregate/benches/array_agg.rs +++ b/datafusion/functions-aggregate/benches/array_agg.rs @@ -27,7 +27,7 @@ use datafusion_expr::Accumulator; use datafusion_functions_aggregate::array_agg::ArrayAggAccumulator; use arrow::buffer::OffsetBuffer; -use rand::distributions::{Distribution, Standard}; +use rand::distr::{Distribution, StandardUniform}; use rand::prelude::StdRng; use rand::Rng; use rand::SeedableRng; @@ -43,7 +43,7 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) { b.iter(|| { #[allow(clippy::unit_arg)] black_box( - ArrayAggAccumulator::try_new(&list_item_data_type) + ArrayAggAccumulator::try_new(&list_item_data_type, false) .unwrap() .merge_batch(&[values.clone()]) .unwrap(), @@ -55,23 +55,23 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) { pub fn create_primitive_array(size: usize, null_density: f32) -> PrimitiveArray where T: ArrowPrimitiveType, - Standard: Distribution, + StandardUniform: Distribution, { let mut rng = seedable_rng(); (0..size) .map(|_| { - if rng.gen::() < null_density { + if rng.random::() < null_density { None } else { - Some(rng.gen()) + Some(rng.random()) } }) .collect() } /// Create List array with the given item data type, null density, null locations and zero length lists density -/// Creates an random (but fixed-seeded) array of a given size and null density +/// Creates a random (but fixed-seeded) array of a given size and null density pub fn create_list_array( size: usize, null_density: f32, @@ -79,20 +79,20 @@ pub fn create_list_array( ) -> ListArray where T: ArrowPrimitiveType, - Standard: Distribution, + StandardUniform: Distribution, { let mut nulls_builder = NullBufferBuilder::new(size); - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); let offsets = OffsetBuffer::from_lengths((0..size).map(|_| { - let is_null = rng.gen::() < null_density; + let is_null = rng.random::() < null_density; - let mut length = rng.gen_range(1..10); + let mut length = rng.random_range(1..10); if is_null { nulls_builder.append_null(); - if rng.gen::() <= zero_length_lists_probability { + if rng.random::() <= zero_length_lists_probability { length = 0; } } else { diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs 
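// The rand 0.8 -> 0.9 renames applied throughout these tests and benches, collected in
// one sketch (not part of the patch):
use rand::distr::{Distribution, StandardUniform}; // was rand::distributions::{Distribution, Standard}
use rand::Rng;

fn sample_once() -> (f64, usize, bool) {
    let mut rng = rand::rng(); // was rand::thread_rng()
    (
        rng.random(),                     // was rng.gen()
        rng.random_range(1..200),         // was rng.gen_range(1..200)
        StandardUniform.sample(&mut rng), // Distribution::sample itself is unchanged
    )
}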
index 8bde7d04c44d..cffa50bdda12 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -17,18 +17,23 @@ use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; -use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; +use arrow::util::bench_util::{ + create_boolean_array, create_dict_from_values, create_primitive_array, + create_string_array_with_len, +}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; +use datafusion_expr::{ + function::AccumulatorArgs, Accumulator, AggregateUDFImpl, GroupsAccumulator, +}; use datafusion_functions_aggregate::count::Count; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; -fn prepare_accumulator() -> Box { +fn prepare_group_accumulator() -> Box { let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)])); let accumulator_args = AccumulatorArgs { - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), schema: &schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), @@ -44,13 +49,34 @@ fn prepare_accumulator() -> Box { .unwrap() } +fn prepare_accumulator() -> Box { + let schema = Arc::new(Schema::new(vec![Field::new( + "f", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )])); + let accumulator_args = AccumulatorArgs { + return_field: Arc::new(Field::new_list_field(DataType::Int64, true)), + schema: &schema, + ignore_nulls: false, + ordering_req: &LexOrdering::default(), + is_reversed: false, + name: "COUNT(f)", + is_distinct: true, + exprs: &[col("f", &schema).unwrap()], + }; + let count_fn = Count::new(); + + count_fn.accumulator(accumulator_args).unwrap() +} + fn convert_to_state_bench( c: &mut Criterion, name: &str, values: ArrayRef, opt_filter: Option<&BooleanArray>, ) { - let accumulator = prepare_accumulator(); + let accumulator = prepare_group_accumulator(); c.bench_function(name, |b| { b.iter(|| { black_box( @@ -89,6 +115,18 @@ fn count_benchmark(c: &mut Criterion) { values, Some(&filter), ); + + let arr = create_string_array_with_len::(20, 0.0, 50); + let values = + Arc::new(create_dict_from_values::(200_000, 0.8, &arr)) as ArrayRef; + + let mut accumulator = prepare_accumulator(); + c.bench_function("count low cardinality dict 20% nulls, no filter", |b| { + b.iter(|| { + #[allow(clippy::unit_arg)] + black_box(accumulator.update_batch(&[values.clone()]).unwrap()) + }) + }); } criterion_group!(benches, count_benchmark); diff --git a/datafusion/functions-aggregate/benches/sum.rs b/datafusion/functions-aggregate/benches/sum.rs index fab53ae94b25..25df78b15f11 100644 --- a/datafusion/functions-aggregate/benches/sum.rs +++ b/datafusion/functions-aggregate/benches/sum.rs @@ -26,9 +26,10 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; fn prepare_accumulator(data_type: &DataType) -> Box { - let schema = Arc::new(Schema::new(vec![Field::new("f", data_type.clone(), true)])); + let field = Field::new("f", data_type.clone(), true).into(); + let schema = Arc::new(Schema::new(vec![Arc::clone(&field)])); let accumulator_args = AccumulatorArgs { - return_type: data_type, + return_field: field, schema: &schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), diff --git 
a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index c97dba1925ca..0d5dcd5c2085 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -23,7 +23,7 @@ use arrow::array::{ GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ - ArrowPrimitiveType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + ArrowPrimitiveType, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; @@ -322,12 +322,13 @@ impl AggregateUDFImpl for ApproxDistinct { Ok(DataType::UInt64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, "hll_registers"), DataType::Binary, false, - )]) + ) + .into()]) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 787e08bae286..0f2e3039ca9f 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -17,11 +17,11 @@ //! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution +use arrow::datatypes::DataType::{Float64, UInt64}; +use arrow::datatypes::{DataType, Field, FieldRef}; use std::any::Any; use std::fmt::Debug; - -use arrow::datatypes::DataType::{Float64, UInt64}; -use arrow::datatypes::{DataType, Field}; +use std::sync::Arc; use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -45,7 +45,7 @@ make_udaf_expr_and_func!( /// APPROX_MEDIAN aggregate expression #[user_doc( doc_section(label = "Approximate Functions"), - description = "Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`.", + description = "Returns the approximate median (50th percentile) of input values. 
It is an alias of `approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY x)`.", syntax_example = "approx_median(expression)", sql_example = r#"```sql > SELECT approx_median(column_name) FROM table_name; @@ -91,7 +91,7 @@ impl AggregateUDFImpl for ApproxMedian { self } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new(format_state_name(args.name, "max_size"), UInt64, false), Field::new(format_state_name(args.name, "sum"), Float64, false), @@ -103,7 +103,10 @@ impl AggregateUDFImpl for ApproxMedian { Field::new_list_field(Float64, true), false, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn name(&self) -> &str { diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 1fad5f73703c..024c0a823fa9 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use arrow::array::{Array, RecordBatch}; use arrow::compute::{filter, is_not_null}; +use arrow::datatypes::FieldRef; use arrow::{ array::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, @@ -29,11 +30,11 @@ use arrow::{ }, datatypes::{DataType, Field, Schema}, }; - use datafusion_common::{ downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, Result, ScalarValue, }; +use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; @@ -51,29 +52,39 @@ create_func!(ApproxPercentileCont, approx_percentile_cont_udaf); /// Computes the approximate percentile continuous of a set of numbers pub fn approx_percentile_cont( - expression: Expr, + order_by: Sort, percentile: Expr, centroids: Option, ) -> Expr { + let expr = order_by.expr.clone(); + let args = if let Some(centroids) = centroids { - vec![expression, percentile, centroids] + vec![expr, percentile, centroids] } else { - vec![expression, percentile] + vec![expr, percentile] }; - approx_percentile_cont_udaf().call(args) + + Expr::AggregateFunction(AggregateFunction::new_udf( + approx_percentile_cont_udaf(), + args, + false, + None, + Some(vec![order_by]), + None, + )) } #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont(expression, percentile, centroids)", + syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql -> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; -+-------------------------------------------------+ -| approx_percentile_cont(column_name, 0.75, 100) | -+-------------------------------------------------+ -| 65.0 | -+-------------------------------------------------+ +> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++-----------------------------------------------------------------------+ +| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | ++-----------------------------------------------------------------------+ +| 65.0 | ++-----------------------------------------------------------------------+ ```"#, standard_argument(name = "expression",), argument( 
@@ -130,6 +141,19 @@ impl ApproxPercentileCont { args: AccumulatorArgs, ) -> Result { let percentile = validate_input_percentile_expr(&args.exprs[1])?; + + let is_descending = args + .ordering_req + .first() + .map(|sort_expr| sort_expr.options.descending) + .unwrap_or(false); + + let percentile = if is_descending { + 1.0 - percentile + } else { + percentile + }; + let tdigest_max_size = if args.exprs.len() == 3 { Some(validate_input_max_size_expr(&args.exprs[2])?) } else { @@ -232,7 +256,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { #[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "max_size"), @@ -264,7 +288,10 @@ impl AggregateUDFImpl for ApproxPercentileCont { Field::new_list_field(DataType::Float64, true), false, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn name(&self) -> &str { @@ -292,6 +319,14 @@ impl AggregateUDFImpl for ApproxPercentileCont { Ok(arg_types[0].clone()) } + fn supports_null_handling_clause(&self) -> bool { + false + } + + fn is_ordered_set_aggregate(&self) -> bool { + true + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 16dac2c1b8f0..5180d4588962 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -20,11 +20,8 @@ use std::fmt::{Debug, Formatter}; use std::mem::size_of_val; use std::sync::Arc; -use arrow::{ - array::ArrayRef, - datatypes::{DataType, Field}, -}; - +use arrow::datatypes::FieldRef; +use arrow::{array::ArrayRef, datatypes::DataType}; use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -52,14 +49,14 @@ make_udaf_expr_and_func!( #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont_with_weight(expression, weight, percentile)", + syntax_example = "approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql -> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; -+----------------------------------------------------------------------+ -| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | -+----------------------------------------------------------------------+ -| 78.5 | -+----------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++---------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) | ++---------------------------------------------------------------------------------------------+ +| 78.5 | ++---------------------------------------------------------------------------------------------+ ```"#, standard_argument(name = "expression", prefix = "The"), 
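// A usage sketch for the new ordered-set form (table and column names here are
// hypothetical): the percentile (and optional centroid count) are now regular
// arguments, the input expression moves into WITHIN GROUP, and with a descending
// ORDER BY the requested percentile is mirrored (p -> 1 - p) as shown above.
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

async fn p90_descending(ctx: &SessionContext) -> Result<()> {
    let df = ctx
        .sql("SELECT approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY x DESC) FROM t")
        .await?;
    df.show().await?;
    Ok(())
}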
argument( @@ -174,10 +171,18 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { #[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.approx_percentile_cont.state_fields(args) } + fn supports_null_handling_clause(&self) -> bool { + false + } + + fn is_ordered_set_aggregate(&self) -> bool { + true + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index d658744c1ba5..71278767a83f 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -17,9 +17,11 @@ //! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`] -use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray}; -use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Field, Fields}; +use arrow::array::{ + new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray, +}; +use arrow::compute::{filter, SortOptions}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; @@ -107,39 +109,46 @@ impl AggregateUDFImpl for ArrayAgg { )))) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { return Ok(vec![Field::new_list( format_state_name(args.name, "distinct_array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), true, - )]); + ) + .into()]); } let mut fields = vec![Field::new_list( format_state_name(args.name, "array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), true, - )]; + ) + .into()]; if args.ordering_fields.is_empty() { return Ok(fields); } let orderings = args.ordering_fields.to_vec(); - fields.push(Field::new_list( - format_state_name(args.name, "array_agg_orderings"), - Field::new_list_field(DataType::Struct(Fields::from(orderings)), true), - false, - )); + fields.push( + Field::new_list( + format_state_name(args.name, "array_agg_orderings"), + Field::new_list_field(DataType::Struct(Fields::from(orderings)), true), + false, + ) + .into(), + ); Ok(fields) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; + let ignore_nulls = + acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?; if acc_args.is_distinct { // Limitation similar to Postgres. 
The aggregation function can only mix @@ -166,14 +175,19 @@ impl AggregateUDFImpl for ArrayAgg { } sort_option = Some(order.options) } + return Ok(Box::new(DistinctArrayAggAccumulator::try_new( &data_type, sort_option, + ignore_nulls, )?)); } if acc_args.ordering_req.is_empty() { - return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?)); + return Ok(Box::new(ArrayAggAccumulator::try_new( + &data_type, + ignore_nulls, + )?)); } let ordering_dtypes = acc_args @@ -187,6 +201,7 @@ impl AggregateUDFImpl for ArrayAgg { &ordering_dtypes, acc_args.ordering_req.clone(), acc_args.is_reversed, + ignore_nulls, ) .map(|acc| Box::new(acc) as _) } @@ -204,18 +219,20 @@ impl AggregateUDFImpl for ArrayAgg { pub struct ArrayAggAccumulator { values: Vec, datatype: DataType, + ignore_nulls: bool, } impl ArrayAggAccumulator { /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { + pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result { Ok(Self { values: vec![], datatype: datatype.clone(), + ignore_nulls, }) } - /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list) + /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non-empty list) /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option { let offsets = list_array.value_offsets(); @@ -239,7 +256,7 @@ impl ArrayAggAccumulator { return Some(list_array.values().slice(0, 0)); } - // According to the Arrow spec, null values can point to non empty lists + // According to the Arrow spec, null values can point to non-empty lists // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value // Unwrapping is safe as we just checked if there is a null value @@ -247,7 +264,7 @@ impl ArrayAggAccumulator { let mut valid_slices_iter = nulls.valid_slices(); - // This is safe as we validated that that are at least 1 valid value in the array + // This is safe as we validated that there is at least 1 valid value in the array let (start, end) = valid_slices_iter.next().unwrap(); let start_offset = offsets[start]; @@ -257,7 +274,7 @@ impl ArrayAggAccumulator { let mut end_offset_of_last_valid_value = offsets[end]; for (start, end) in valid_slices_iter { - // If there is a null value that point to a non empty list than the start offset of the valid value + // If there is a null value that point to a non-empty list than the start offset of the valid value // will be different that the end offset of the last valid value if offsets[start] != end_offset_of_last_valid_value { return None; @@ -288,10 +305,23 @@ impl Accumulator for ArrayAggAccumulator { return internal_err!("expects single batch"); } - let val = Arc::clone(&values[0]); + let val = &values[0]; + let nulls = if self.ignore_nulls { + val.logical_nulls() + } else { + None + }; + + let val = match nulls { + Some(nulls) if nulls.null_count() >= val.len() => return Ok(()), + Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?, + None => Arc::clone(val), + }; + if !val.is_empty() { self.values.push(val); } + Ok(()) } @@ -360,17 +390,20 @@ struct DistinctArrayAggAccumulator { values: HashSet, 
datatype: DataType, sort_options: Option, + ignore_nulls: bool, } impl DistinctArrayAggAccumulator { pub fn try_new( datatype: &DataType, sort_options: Option, + ignore_nulls: bool, ) -> Result { Ok(Self { values: HashSet::new(), datatype: datatype.clone(), sort_options, + ignore_nulls, }) } } @@ -385,11 +418,20 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - let array = &values[0]; + let val = &values[0]; + let nulls = if self.ignore_nulls { + val.logical_nulls() + } else { + None + }; - for i in 0..array.len() { - let scalar = ScalarValue::try_from_array(&array, i)?; - self.values.insert(scalar); + let nulls = nulls.as_ref(); + if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { + for i in 0..val.len() { + if nulls.is_none_or(|nulls| nulls.is_valid(i)) { + self.values.insert(ScalarValue::try_from_array(val, i)?); + } + } } Ok(()) @@ -471,6 +513,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator { ordering_req: LexOrdering, /// Whether the aggregation is running in reverse. reverse: bool, + /// Whether the aggregation should ignore null values. + ignore_nulls: bool, } impl OrderSensitiveArrayAggAccumulator { @@ -481,6 +525,7 @@ impl OrderSensitiveArrayAggAccumulator { ordering_dtypes: &[DataType], ordering_req: LexOrdering, reverse: bool, + ignore_nulls: bool, ) -> Result { let mut datatypes = vec![datatype.clone()]; datatypes.extend(ordering_dtypes.iter().cloned()); @@ -490,6 +535,7 @@ impl OrderSensitiveArrayAggAccumulator { datatypes, ordering_req, reverse, + ignore_nulls, }) } } @@ -500,11 +546,22 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { return Ok(()); } - let n_row = values[0].len(); - for index in 0..n_row { - let row = get_row_at_idx(values, index)?; - self.values.push(row[0].clone()); - self.ordering_values.push(row[1..].to_vec()); + let val = &values[0]; + let ord = &values[1..]; + let nulls = if self.ignore_nulls { + val.logical_nulls() + } else { + None + }; + + let nulls = nulls.as_ref(); + if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { + for i in 0..val.len() { + if nulls.is_none_or(|nulls| nulls.is_valid(i)) { + self.values.push(ScalarValue::try_from_array(val, i)?); + self.ordering_values.push(get_row_at_idx(ord, i)?) 
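// A standalone sketch of the null-skipping step used by the accumulators above (not
// part of the patch): when `ignore_nulls` is set, the incoming batch is reduced to its
// valid rows before anything is pushed into the aggregation state.
use std::sync::Arc;
use arrow::array::{ArrayRef, BooleanArray};
use arrow::compute::filter;

fn drop_nulls(values: &ArrayRef) -> arrow::error::Result<ArrayRef> {
    match values.logical_nulls() {
        // the validity bitmap doubles as a filter mask: true = keep the row
        Some(nulls) => filter(values, &BooleanArray::new(nulls.inner().clone(), None)),
        None => Ok(Arc::clone(values)),
    }
}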
+ } + } } Ok(()) @@ -639,7 +696,6 @@ impl OrderSensitiveArrayAggAccumulator { fn evaluate_orderings(&self) -> Result { let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); let num_columns = fields.len(); - let struct_field = Fields::from(fields.clone()); let mut column_wise_ordering_values = vec![]; for i in 0..num_columns { @@ -656,6 +712,7 @@ impl OrderSensitiveArrayAggAccumulator { column_wise_ordering_values.push(array); } + let struct_field = Fields::from(fields); let ordering_array = StructArray::try_new(struct_field, column_wise_ordering_values, None)?; Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) @@ -932,7 +989,7 @@ mod tests { } struct ArrayAggAccumulatorBuilder { - data_type: DataType, + return_field: FieldRef, distinct: bool, ordering: LexOrdering, schema: Schema, @@ -945,15 +1002,13 @@ mod tests { fn new(data_type: DataType) -> Self { Self { - data_type: data_type.clone(), - distinct: Default::default(), + return_field: Field::new("f", data_type.clone(), true).into(), + distinct: false, ordering: Default::default(), schema: Schema { fields: Fields::from(vec![Field::new( "col", - DataType::List(FieldRef::new(Field::new( - "item", data_type, true, - ))), + DataType::new_list(data_type, true), true, )]), metadata: Default::default(), @@ -979,7 +1034,7 @@ mod tests { fn build(&self) -> Result> { ArrayAgg::default().accumulator(AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: false, ordering_req: &self.ordering, @@ -1007,7 +1062,7 @@ mod tests { fn print_nulls(sort: Vec>) -> Vec { sort.into_iter() - .map(|v| v.unwrap_or("NULL".to_string())) + .map(|v| v.unwrap_or_else(|| "NULL".to_string())) .collect() } diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 141771b0412f..3c1d33e093b5 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -24,8 +24,9 @@ use arrow::array::{ use arrow::compute::sum; use arrow::datatypes::{ - i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, - Float64Type, UInt64Type, + i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, + DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, + DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, }; use datafusion_common::{ exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, @@ -120,7 +121,7 @@ impl AggregateUDFImpl for Avg { let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; // instantiate specialized accumulator based for the type - match (&data_type, acc_args.return_type) { + match (&data_type, acc_args.return_field.data_type()) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -145,15 +146,25 @@ impl AggregateUDFImpl for Avg { target_precision: *target_precision, target_scale: *target_scale, })), + + (Duration(time_unit), Duration(result_unit)) => { + Ok(Box::new(DurationAvgAccumulator { + sum: None, + count: 0, + time_unit: *time_unit, + result_unit: *result_unit, + })) + } + _ => exec_err!( "AvgAccumulator for ({} --> {})", &data_type, - acc_args.return_type + acc_args.return_field.data_type() ), } } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ 
-162,16 +173,19 @@ impl AggregateUDFImpl for Avg { ), Field::new( format_state_name(args.name, "sum"), - args.input_types[0].clone(), + args.input_fields[0].data_type().clone(), true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { matches!( - args.return_type, - DataType::Float64 | DataType::Decimal128(_, _) + args.return_field.data_type(), + DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_) ) } @@ -183,11 +197,11 @@ impl AggregateUDFImpl for Avg { let data_type = args.exprs[0].data_type(args.schema)?; // instantiate specialized accumulator based for the type - match (&data_type, args.return_type) { + match (&data_type, args.return_field.data_type()) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), |sum: f64, count: u64| Ok(sum / count as f64), ))) } @@ -206,7 +220,7 @@ impl AggregateUDFImpl for Avg { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), avg_fn, ))) } @@ -227,15 +241,54 @@ impl AggregateUDFImpl for Avg { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), avg_fn, ))) } + (Duration(time_unit), Duration(_result_unit)) => { + let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64); + + match time_unit { + TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::< + DurationSecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationMillisecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationMicrosecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationNanosecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + } + } + _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", &data_type, - args.return_type + args.return_field.data_type() ), } } @@ -399,6 +452,105 @@ impl Accumulator for DecimalAvgAccumu } } +/// An accumulator to compute the average for duration values +#[derive(Debug)] +struct DurationAvgAccumulator { + sum: Option, + count: u64, + time_unit: TimeUnit, + result_unit: TimeUnit, +} + +impl Accumulator for DurationAvgAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count += (array.len() - array.null_count()) as u64; + + let sum_value = match self.time_unit { + TimeUnit::Second => sum(array.as_primitive::()), + TimeUnit::Millisecond => sum(array.as_primitive::()), + TimeUnit::Microsecond => sum(array.as_primitive::()), + TimeUnit::Nanosecond => sum(array.as_primitive::()), + }; + + if let Some(x) = sum_value { + let v = self.sum.get_or_insert(0); + *v += x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let avg = self.sum.map(|sum| sum / self.count as i64); + + match self.result_unit { + TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)), + TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)), + TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)), + TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)), + } + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result> { + let duration_value = match 
self.time_unit { + TimeUnit::Second => ScalarValue::DurationSecond(self.sum), + TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum), + TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum), + TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum), + }; + + Ok(vec![ScalarValue::from(self.count), duration_value]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.count += sum(states[0].as_primitive::()).unwrap_or_default(); + + let sum_value = match self.time_unit { + TimeUnit::Second => sum(states[1].as_primitive::()), + TimeUnit::Millisecond => { + sum(states[1].as_primitive::()) + } + TimeUnit::Microsecond => { + sum(states[1].as_primitive::()) + } + TimeUnit::Nanosecond => { + sum(states[1].as_primitive::()) + } + }; + + if let Some(x) = sum_value { + let v = self.sum.get_or_insert(0); + *v += x; + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count -= (array.len() - array.null_count()) as u64; + + let sum_value = match self.time_unit { + TimeUnit::Second => sum(array.as_primitive::()), + TimeUnit::Millisecond => sum(array.as_primitive::()), + TimeUnit::Microsecond => sum(array.as_primitive::()), + TimeUnit::Nanosecond => sum(array.as_primitive::()), + }; + + if let Some(x) = sum_value { + self.sum = Some(self.sum.unwrap() - x); + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } +} + /// An accumulator to compute the average of `[PrimitiveArray]`. /// Stores values as native types, and does overflow checking /// diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 50ab50abc9e2..4512162ba5d3 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -25,8 +25,8 @@ use std::mem::{size_of, size_of_val}; use ahash::RandomState; use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; use arrow::datatypes::{ - ArrowNativeType, ArrowNumericType, DataType, Field, Int16Type, Int32Type, Int64Type, - Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowNativeType, ArrowNumericType, DataType, Field, FieldRef, Int16Type, Int32Type, + Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use datafusion_common::cast::as_list_array; @@ -87,7 +87,7 @@ macro_rules! accumulator_helper { /// `is_distinct` is boolean value indicating whether the operation is distinct or not. macro_rules! downcast_bitwise_accumulator { ($args:ident, $opr:expr, $is_distinct: expr) => { - match $args.return_type { + match $args.return_field.data_type() { DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct), DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct), DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct), @@ -101,7 +101,7 @@ macro_rules! 
downcast_bitwise_accumulator { "{} not supported for {}: {}", stringify!($opr), $args.name, - $args.return_type + $args.return_field.data_type() ) } } @@ -205,7 +205,7 @@ enum BitwiseOperationType { impl Display for BitwiseOperationType { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + write!(f, "{self:?}") } } @@ -263,7 +263,7 @@ impl AggregateUDFImpl for BitwiseOperation { downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if self.operation == BitwiseOperationType::Xor && args.is_distinct { Ok(vec![Field::new_list( format_state_name( @@ -271,15 +271,17 @@ impl AggregateUDFImpl for BitwiseOperation { format!("{} distinct", self.name()).as_str(), ), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_type.clone(), true), + Field::new_list_field(args.return_type().clone(), true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( format_state_name(args.name, self.name()), - args.return_type.clone(), + args.return_field.data_type().clone(), true, - )]) + ) + .into()]) } } @@ -291,7 +293,7 @@ impl AggregateUDFImpl for BitwiseOperation { &self, args: AccumulatorArgs, ) -> Result> { - let data_type = args.return_type; + let data_type = args.return_field.data_type(); let operation = &self.operation; downcast_integer! { data_type => (group_accumulator_helper, data_type, operation), diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 1b33a7900c00..e5de6d76217f 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -24,8 +24,8 @@ use arrow::array::ArrayRef; use arrow::array::BooleanArray; use arrow::compute::bool_and as compute_bool_and; use arrow::compute::bool_or as compute_bool_or; -use arrow::datatypes::DataType; use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::internal_err; use datafusion_common::{downcast_value, not_impl_err}; @@ -150,12 +150,13 @@ impl AggregateUDFImpl for BoolAnd { Ok(Box::::default()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, self.name()), DataType::Boolean, true, - )]) + ) + .into()]) } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { @@ -166,14 +167,14 @@ impl AggregateUDFImpl for BoolAnd { &self, args: AccumulatorArgs, ) -> Result> { - match args.return_type { + match args.return_field.data_type() { DataType::Boolean => { Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x && y, true))) } _ => not_impl_err!( "GroupsAccumulator not supported for {} with {}", args.name, - args.return_type + args.return_field.data_type() ), } } @@ -288,12 +289,13 @@ impl AggregateUDFImpl for BoolOr { Ok(Box::::default()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, self.name()), DataType::Boolean, true, - )]) + ) + .into()]) } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { @@ -304,7 +306,7 @@ impl AggregateUDFImpl for BoolOr { &self, args: AccumulatorArgs, ) -> Result> { - match args.return_type { + match args.return_field.data_type() { DataType::Boolean => 
Ok(Box::new(BooleanGroupsAccumulator::new( |x, y| x || y, false, @@ -312,7 +314,7 @@ impl AggregateUDFImpl for BoolOr { _ => not_impl_err!( "GroupsAccumulator not supported for {} with {}", args.name, - args.return_type + args.return_field.data_type() ), } } diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index ac57256ce882..0a7345245ca8 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -27,7 +27,7 @@ use arrow::array::{ UInt64Array, }; use arrow::compute::{and, filter, is_not_null, kernels::cast}; -use arrow::datatypes::{Float64Type, UInt64Type}; +use arrow::datatypes::{FieldRef, Float64Type, UInt64Type}; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, @@ -117,7 +117,7 @@ impl AggregateUDFImpl for Correlation { Ok(Box::new(CorrelationAccumulator::try_new()?)) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -130,7 +130,10 @@ impl AggregateUDFImpl for Correlation { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 2d995b4a4179..6b7199c44b32 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -40,6 +40,7 @@ use arrow::{ }, }; +use arrow::datatypes::FieldRef; use arrow::{ array::{Array, BooleanArray, Int64Array, PrimitiveArray}, buffer::BooleanBuffer, @@ -56,8 +57,8 @@ use datafusion_expr::{ Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition, }; use datafusion_functions_aggregate_common::aggregate::count_distinct::{ - BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, - PrimitiveDistinctCountAccumulator, + BytesDistinctCountAccumulator, DictionaryCountAccumulator, + FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; use datafusion_physical_expr_common::binary_map::OutputType; @@ -100,7 +101,7 @@ pub fn count_distinct(expr: Expr) -> Expr { /// let expr = col(expr.schema_name().to_string()); /// ``` pub fn count_all() -> Expr { - count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)") + count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)") } /// Creates window aggregation to count all rows. 
@@ -123,9 +124,9 @@ pub fn count_all() -> Expr { /// let expr = col(expr.schema_name().to_string()); /// ``` pub fn count_all_window() -> Expr { - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![Expr::Literal(COUNT_STAR_EXPANSION)], + vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], )) } @@ -179,6 +180,107 @@ impl Count { } } } +fn get_count_accumulator(data_type: &DataType) -> Box { + match data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator + DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::UInt16 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal128Type, + >::new(data_type)), + DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal256Type, + >::new(data_type)), + + DataType::Date32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Date64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time32(TimeUnit::Millisecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time32(TimeUnit::Second) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time64(TimeUnit::Microsecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time64(TimeUnit::Nanosecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Second, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + + DataType::Float16 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float32 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float64 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + + DataType::Utf8 => { + Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::Utf8View => { + Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) + } + DataType::LargeUtf8 => { + Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::Binary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( + OutputType::BinaryView, + )), + DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + + // Use the generic 
accumulator based on `ScalarValue` for all other types + _ => Box::new(DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: data_type.clone(), + }), + } +} impl AggregateUDFImpl for Count { fn as_any(&self) -> &dyn std::any::Any { @@ -201,20 +303,27 @@ impl AggregateUDFImpl for Count { false } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { + let dtype: DataType = match &args.input_fields[0].data_type() { + DataType::Dictionary(_, values_type) => (**values_type).clone(), + &dtype => dtype.clone(), + }; + Ok(vec![Field::new_list( format_state_name(args.name, "count distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(dtype, true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( format_state_name(args.name, "count"), DataType::Int64, false, - )]) + ) + .into()]) } } @@ -228,114 +337,13 @@ impl AggregateUDFImpl for Count { } let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?; - Ok(match data_type { - // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator - DataType::Int8 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int16 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt8 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt16 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal128Type, - >::new(data_type)), - DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal256Type, - >::new(data_type)), - - DataType::Date32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Date64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Time32(TimeUnit::Millisecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Time32(TimeUnit::Second) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Time64(TimeUnit::Microsecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Time64(TimeUnit::Nanosecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Second, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Float16 => { - Box::new(FloatDistinctCountAccumulator::::new()) - } - DataType::Float32 => { - Box::new(FloatDistinctCountAccumulator::::new()) - } - DataType::Float64 => 
{ - Box::new(FloatDistinctCountAccumulator::::new()) - } - - DataType::Utf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) - } - DataType::Utf8View => { - Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) - } - DataType::LargeUtf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + Ok(match data_type { + DataType::Dictionary(_, values_type) => { + let inner = get_count_accumulator(values_type); + Box::new(DictionaryCountAccumulator::new(inner)) } - DataType::Binary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( - OutputType::BinaryView, - )), - DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - - // Use the generic accumulator based on `ScalarValue` for all other types - _ => Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: data_type.clone(), - }), + _ => get_count_accumulator(data_type), }) } @@ -755,7 +763,12 @@ impl Accumulator for DistinctCountAccumulator { #[cfg(test)] mod tests { use super::*; - use arrow::array::NullArray; + use arrow::array::{Int32Array, NullArray}; + use arrow::datatypes::{DataType, Field, Int32Type, Schema}; + use datafusion_expr::function::AccumulatorArgs; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr::LexOrdering; + use std::sync::Arc; #[test] fn count_accumulator_nulls() -> Result<()> { @@ -764,4 +777,49 @@ mod tests { assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); Ok(()) } + + #[test] + fn test_nested_dictionary() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "dict_col", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + )), + ), + true, + )])); + + // Using Count UDAF's accumulator + let count = Count::new(); + let expr = Arc::new(Column::new("dict_col", 0)); + let args = AccumulatorArgs { + schema: &schema, + exprs: &[expr], + is_distinct: true, + name: "count", + ignore_nulls: false, + is_reversed: false, + return_field: Arc::new(Field::new_list_field(DataType::Int64, true)), + ordering_req: &LexOrdering::default(), + }; + + let inner_dict = arrow::array::DictionaryArray::::from_iter([ + "a", "b", "c", "d", "a", "b", + ]); + + let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]); + let dict_of_dict = arrow::array::DictionaryArray::::try_new( + keys, + Arc::new(inner_dict), + )?; + + let mut acc = count.accumulator(args)?; + acc.update_batch(&[Arc::new(dict_of_dict)])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4))); + + Ok(()) + } } diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index d4ae27533c6d..9f37a73e5429 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -17,15 +17,12 @@ //! [`CovarianceSample`]: covariance sample aggregations. 
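// A minimal composition sketch (not part of the patch; assumes these crates are direct
// dependencies and the i32 offset parameter used for Utf8 elsewhere in this file): for a
// Dictionary(Int32, Utf8) column, COUNT(DISTINCT ...) now wraps the matching value-type
// accumulator in the new `DictionaryCountAccumulator`, which unwraps the dictionary and
// forwards only the referenced values.
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
    BytesDistinctCountAccumulator, DictionaryCountAccumulator,
};
use datafusion_physical_expr_common::binary_map::OutputType;

fn utf8_dict_distinct_counter() -> DictionaryCountAccumulator {
    let inner = Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8));
    DictionaryCountAccumulator::new(inner)
}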
-use std::fmt::Debug; -use std::mem::size_of_val; - +use arrow::datatypes::FieldRef; use arrow::{ array::{ArrayRef, Float64Array, UInt64Array}, compute::kernels::cast, datatypes::{DataType, Field}, }; - use datafusion_common::{ downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, @@ -38,6 +35,9 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate_common::stats::StatsType; use datafusion_macros::user_doc; +use std::fmt::Debug; +use std::mem::size_of_val; +use std::sync::Arc; make_udaf_expr_and_func!( CovarianceSample, @@ -120,7 +120,7 @@ impl AggregateUDFImpl for CovarianceSample { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -131,7 +131,10 @@ impl AggregateUDFImpl for CovarianceSample { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { @@ -210,7 +213,7 @@ impl AggregateUDFImpl for CovariancePopulation { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -221,7 +224,10 @@ impl AggregateUDFImpl for CovariancePopulation { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index ec8c440b77e5..e8022245dba5 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -29,8 +29,8 @@ use arrow::array::{ use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions}; use arrow::datatypes::{ - DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, Float16Type, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, FieldRef, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, @@ -161,7 +161,7 @@ impl AggregateUDFImpl for FirstValue { acc_args.ordering_req.is_empty() || self.requirement_satisfied; FirstValueAccumulator::try_new( - acc_args.return_type, + acc_args.return_field.data_type(), &ordering_dtypes, acc_args.ordering_req.clone(), acc_args.ignore_nulls, @@ -169,14 +169,15 @@ impl AggregateUDFImpl for FirstValue { .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( format_state_name(args.name, "first_value"), - args.return_type.clone(), + args.return_type().clone(), true, - )]; + ) + .into()]; fields.extend(args.ordering_fields.to_vec()); - fields.push(Field::new("is_set", DataType::Boolean, true)); + fields.push(Field::new("is_set", DataType::Boolean, true).into()); Ok(fields) } @@ 
-184,7 +185,7 @@ impl AggregateUDFImpl for FirstValue { // TODO: extract to function use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -225,13 +226,13 @@ impl AggregateUDFImpl for FirstValue { Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( args.ordering_req.clone(), args.ignore_nulls, - args.return_type, + args.return_field.data_type(), &ordering_dtypes, true, )?)) } - match args.return_type { + match args.return_field.data_type() { DataType::Int8 => create_accumulator::(args), DataType::Int16 => create_accumulator::(args), DataType::Int32 => create_accumulator::(args), @@ -279,7 +280,7 @@ impl AggregateUDFImpl for FirstValue { _ => { internal_err!( "GroupsAccumulator not supported for first_value({})", - args.return_type + args.return_field.data_type() ) } } @@ -752,7 +753,7 @@ where fn size(&self) -> usize { self.vals.capacity() * size_of::() - + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes + + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes + self.is_sets.capacity() / 8 + self.size_of_orderings + self.min_of_each_group_buf.0.capacity() * size_of::() @@ -827,9 +828,14 @@ impl FirstValueAccumulator { } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) { - self.first = row[0].clone(); - self.orderings = row[1..].to_vec(); + fn update_with_new_row(&mut self, mut row: Vec) { + // Ensure any Array based scalars hold have a single value to reduce memory pressure + row.iter_mut().for_each(|s| { + s.compact(); + }); + + self.first = row.remove(0); + self.orderings = row; self.is_set = true; } @@ -888,7 +894,7 @@ impl Accumulator for FirstValueAccumulator { if !self.is_set { if let Some(first_idx) = self.get_first_idx(values)? { let row = get_row_at_idx(values, first_idx)?; - self.update_with_new_row(&row); + self.update_with_new_row(row); } } else if !self.requirement_satisfied { if let Some(first_idx) = self.get_first_idx(values)? { @@ -901,7 +907,7 @@ impl Accumulator for FirstValueAccumulator { )? .is_gt() { - self.update_with_new_row(&row); + self.update_with_new_row(row); } } } @@ -925,7 +931,7 @@ impl Accumulator for FirstValueAccumulator { let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b)); if let Some(first_idx) = min { - let first_row = get_row_at_idx(&filtered_states, first_idx)?; + let mut first_row = get_row_at_idx(&filtered_states, first_idx)?; // When collecting orderings, we exclude the is_set flag from the state. let first_ordering = &first_row[1..is_set_idx]; let sort_options = get_sort_options(self.ordering_req.as_ref()); @@ -936,7 +942,9 @@ impl Accumulator for FirstValueAccumulator { // Update with first value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. 
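Taking the row by value lets `update_with_new_row` call `ScalarValue::compact` on each element before storing it. That matters because a scalar extracted from one row of a large array can keep the whole backing buffer alive; compacting copies out just the referenced value, which is what the `test_first_list_acc_size` / `test_last_list_acc_size` tests added further down verify. A small sketch of the effect, using `compact` as it is used in the hunk above (the exact byte counts are not asserted, only that compaction does not grow the scalar):

use std::iter::repeat_with;

use arrow::array::ListArray;
use arrow::datatypes::Int64Type;
use datafusion_common::{Result, ScalarValue};

fn compact_demo() -> Result<()> {
    // A 10_000-row list array; a scalar taken from row 0 initially
    // references the array's full buffers.
    let big = ListArray::from_iter_primitive::<Int64Type, _, _>(
        repeat_with(|| Some(vec![Some(1_i64)])).take(10_000),
    );
    let mut first = ScalarValue::try_from_array(&big, 0)?;
    let before = first.size();
    // After compaction the scalar owns only the single row it represents.
    first.compact();
    let after = first.size();
    assert!(after <= before);
    Ok(())
}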
- self.update_with_new_row(&first_row[0..is_set_idx]); + assert!(is_set_idx <= first_row.len()); + first_row.resize(is_set_idx, ScalarValue::Null); + self.update_with_new_row(first_row); } } Ok(()) @@ -1031,7 +1039,7 @@ impl AggregateUDFImpl for LastValue { acc_args.ordering_req.is_empty() || self.requirement_satisfied; LastValueAccumulator::try_new( - acc_args.return_type, + acc_args.return_field.data_type(), &ordering_dtypes, acc_args.ordering_req.clone(), acc_args.ignore_nulls, @@ -1039,21 +1047,22 @@ impl AggregateUDFImpl for LastValue { .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let StateFieldsArgs { name, - input_types, - return_type: _, + input_fields, + return_field: _, ordering_fields, is_distinct: _, } = args; let mut fields = vec![Field::new( format_state_name(name, "last_value"), - input_types[0].clone(), + input_fields[0].data_type().clone(), true, - )]; + ) + .into()]; fields.extend(ordering_fields.to_vec()); - fields.push(Field::new("is_set", DataType::Boolean, true)); + fields.push(Field::new("is_set", DataType::Boolean, true).into()); Ok(fields) } @@ -1085,7 +1094,7 @@ impl AggregateUDFImpl for LastValue { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -1125,13 +1134,13 @@ impl AggregateUDFImpl for LastValue { Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( args.ordering_req.clone(), args.ignore_nulls, - args.return_type, + args.return_field.data_type(), &ordering_dtypes, false, )?)) } - match args.return_type { + match args.return_field.data_type() { DataType::Int8 => create_accumulator::(args), DataType::Int16 => create_accumulator::(args), DataType::Int32 => create_accumulator::(args), @@ -1179,7 +1188,7 @@ impl AggregateUDFImpl for LastValue { _ => { internal_err!( "GroupsAccumulator not supported for last_value({})", - args.return_type + args.return_field.data_type() ) } } @@ -1226,9 +1235,14 @@ impl LastValueAccumulator { } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) { - self.last = row[0].clone(); - self.orderings = row[1..].to_vec(); + fn update_with_new_row(&mut self, mut row: Vec) { + // Ensure any Array based scalars hold have a single value to reduce memory pressure + row.iter_mut().for_each(|s| { + s.compact(); + }); + + self.last = row.remove(0); + self.orderings = row; self.is_set = true; } @@ -1289,7 +1303,7 @@ impl Accumulator for LastValueAccumulator { if !self.is_set || self.requirement_satisfied { if let Some(last_idx) = self.get_last_idx(values)? { let row = get_row_at_idx(values, last_idx)?; - self.update_with_new_row(&row); + self.update_with_new_row(row); } } else if let Some(last_idx) = self.get_last_idx(values)? { let row = get_row_at_idx(values, last_idx)?; @@ -1302,7 +1316,7 @@ impl Accumulator for LastValueAccumulator { )? .is_lt() { - self.update_with_new_row(&row); + self.update_with_new_row(row); } } @@ -1326,7 +1340,7 @@ impl Accumulator for LastValueAccumulator { let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b)); if let Some(last_idx) = max { - let last_row = get_row_at_idx(&filtered_states, last_idx)?; + let mut last_row = get_row_at_idx(&filtered_states, last_idx)?; // When collecting orderings, we exclude the is_set flag from the state. 
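The accumulator constructors in these hunks now read the output type through `args.return_field.data_type()` instead of a bare `args.return_type`, i.e. `AccumulatorArgs` carries a full `FieldRef` describing the return value (the updated tests later in this patch construct it the same way). A short sketch of what a call site supplies; the field name "f" and the nullability flag are arbitrary for illustration:

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, FieldRef};

fn return_field_example() -> FieldRef {
    // Metadata and nullability travel with the Field; accumulators recover
    // the plain DataType via `.data_type()`.
    Arc::new(Field::new("f", DataType::Float64, true))
}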
let last_ordering = &last_row[1..is_set_idx]; let sort_options = get_sort_options(self.ordering_req.as_ref()); @@ -1339,7 +1353,9 @@ impl Accumulator for LastValueAccumulator { // Update with last value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. - self.update_with_new_row(&last_row[0..is_set_idx]); + assert!(is_set_idx <= last_row.len()); + last_row.resize(is_set_idx, ScalarValue::Null); + self.update_with_new_row(last_row); } } Ok(()) @@ -1382,7 +1398,13 @@ fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec Result<()> { + fn size_after_batch(values: &[ArrayRef]) -> Result { + let mut first_accumulator = FirstValueAccumulator::try_new( + &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))), + &[], + LexOrdering::default(), + false, + )?; + + first_accumulator.update_batch(values)?; + + Ok(first_accumulator.size()) + } + + let batch1 = ListArray::from_iter_primitive::( + repeat_with(|| Some(vec![Some(1)])).take(10000), + ); + let batch2 = + ListArray::from_iter_primitive::([Some(vec![Some(1)])]); + + let size1 = size_after_batch(&[Arc::new(batch1)])?; + let size2 = size_after_batch(&[Arc::new(batch2)])?; + assert_eq!(size1, size2); + + Ok(()) + } + + #[test] + fn test_last_list_acc_size() -> Result<()> { + fn size_after_batch(values: &[ArrayRef]) -> Result { + let mut last_accumulator = LastValueAccumulator::try_new( + &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))), + &[], + LexOrdering::default(), + false, + )?; + + last_accumulator.update_batch(values)?; + + Ok(last_accumulator.size()) + } + + let batch1 = ListArray::from_iter_primitive::( + repeat_with(|| Some(vec![Some(1)])).take(10000), + ); + let batch2 = + ListArray::from_iter_primitive::([Some(vec![Some(1)])]); + + let size1 = size_after_batch(&[Arc::new(batch1)])?; + let size2 = size_after_batch(&[Arc::new(batch2)])?; + assert_eq!(size1, size2); + + Ok(()) + } } diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 445774ff11e7..0727cf33036a 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -20,8 +20,8 @@ use std::any::Any; use std::fmt; -use arrow::datatypes::DataType; use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::{not_impl_err, Result}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; @@ -105,12 +105,13 @@ impl AggregateUDFImpl for Grouping { Ok(DataType::Int32) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, "grouping"), DataType::Int32, true, - )]) + ) + .into()]) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 7944280291eb..b5bb69f6da9d 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -220,8 +220,7 @@ mod tests { for alias in func.aliases() { assert!( names.insert(alias.to_string().to_lowercase()), - "duplicate function name: {}", - alias + "duplicate function name: {alias}" ); } } diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index b464dde6ccab..18f27c3c4ae3 100644 --- 
a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#[macro_export] macro_rules! make_udaf_expr { ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { // "fluent expr_fn" style function @@ -34,6 +35,7 @@ macro_rules! make_udaf_expr { }; } +#[macro_export] macro_rules! make_udaf_expr_and_func { ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN); @@ -59,6 +61,7 @@ macro_rules! make_udaf_expr_and_func { }; } +#[macro_export] macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default()); diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index ba6b63260e06..bfaea4b2398c 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -35,7 +35,7 @@ use arrow::{ use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; -use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; +use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, FieldRef}; use datafusion_common::{ internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue, @@ -125,9 +125,9 @@ impl AggregateUDFImpl for Median { Ok(arg_types[0].clone()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { //Intermediate state is a list of the elements we have collected so far - let field = Field::new_list_field(args.input_types[0].clone(), true); + let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true); let state_name = if args.is_distinct { "distinct_median" } else { @@ -138,7 +138,8 @@ impl AggregateUDFImpl for Median { format_state_name(args.name, state_name), DataType::List(Arc::new(field)), true, - )]) + ) + .into()]) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index ea4cad548803..bb46aa540461 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -19,17 +19,19 @@ //! 
[`Min`] and [`MinAccumulator`] accumulator for the `min` function mod min_max_bytes; +mod min_max_struct; use arrow::array::{ - ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray, - DurationNanosecondArray, DurationSecondArray, Float16Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, - LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray, - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, + Date64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray, + DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, + LargeBinaryArray, LargeStringArray, StringArray, StringViewArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }; use arrow::compute; use arrow::datatypes::{ @@ -55,6 +57,7 @@ use arrow::datatypes::{ }; use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; +use crate::min_max::min_max_struct::MinMaxStructAccumulator; use datafusion_common::ScalarValue; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, @@ -231,7 +234,9 @@ impl AggregateUDFImpl for Max { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(MaxAccumulator::try_new(acc_args.return_type)?)) + Ok(Box::new(MaxAccumulator::try_new( + acc_args.return_field.data_type(), + )?)) } fn aliases(&self) -> &[String] { @@ -241,7 +246,7 @@ impl AggregateUDFImpl for Max { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -266,6 +271,7 @@ impl AggregateUDFImpl for Max { | LargeBinary | BinaryView | Duration(_) + | Struct(_) ) } @@ -275,7 +281,7 @@ impl AggregateUDFImpl for Max { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = args.return_type; + let data_type = args.return_field.data_type(); match data_type { Int8 => primitive_max_accumulator!(data_type, i8, Int8Type), Int16 => primitive_max_accumulator!(data_type, i16, Int16Type), @@ -341,7 +347,9 @@ impl AggregateUDFImpl for Max { Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone()))) } - + Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_max( + data_type.clone(), + ))), // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), } @@ -351,7 +359,9 @@ impl AggregateUDFImpl for Max { &self, args: AccumulatorArgs, ) -> Result> { - Ok(Box::new(SlidingMaxAccumulator::try_new(args.return_type)?)) + Ok(Box::new(SlidingMaxAccumulator::try_new( + args.return_field.data_type(), + )?)) } fn 
is_descending(&self) -> Option { @@ -610,10 +620,69 @@ fn min_batch(values: &ArrayRef) -> Result { min_binary_view ) } + DataType::Struct(_) => min_max_batch_generic(values, Ordering::Greater)?, + DataType::List(_) => min_max_batch_generic(values, Ordering::Greater)?, + DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Greater)?, + DataType::FixedSizeList(_, _) => { + min_max_batch_generic(values, Ordering::Greater)? + } + DataType::Dictionary(_, _) => { + let values = values.as_any_dictionary().values(); + min_batch(values)? + } _ => min_max_batch!(values, min), }) } +fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result { + if array.len() == array.null_count() { + return ScalarValue::try_from(array.data_type()); + } + let mut extreme = ScalarValue::try_from_array(array, 0)?; + for i in 1..array.len() { + let current = ScalarValue::try_from_array(array, i)?; + if current.is_null() { + continue; + } + if extreme.is_null() { + extreme = current; + continue; + } + if let Some(cmp) = extreme.partial_cmp(¤t) { + if cmp == ordering { + extreme = current; + } + } + } + + Ok(extreme) +} + +macro_rules! min_max_generic { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + if $VALUE.is_null() { + let mut delta_copy = $DELTA.clone(); + // When the new value won we want to compact it to + // avoid storing the entire input + delta_copy.compact(); + delta_copy + } else if $DELTA.is_null() { + $VALUE.clone() + } else { + match $VALUE.partial_cmp(&$DELTA) { + Some(choose_min_max!($OP)) => { + // When the new value won we want to compact it to + // avoid storing the entire input + let mut delta_copy = $DELTA.clone(); + delta_copy.compact(); + delta_copy + } + _ => $VALUE.clone(), + } + } + }}; +} + /// dynamically-typed max(array) -> ScalarValue pub fn max_batch(values: &ArrayRef) -> Result { Ok(match values.data_type() { @@ -653,6 +722,14 @@ pub fn max_batch(values: &ArrayRef) -> Result { max_binary ) } + DataType::Struct(_) => min_max_batch_generic(values, Ordering::Less)?, + DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?, + DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?, + DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?, + DataType::Dictionary(_, _) => { + let values = values.as_any_dictionary().values(); + max_batch(values)? + } _ => min_max_batch!(values, max), }) } @@ -923,6 +1000,37 @@ macro_rules! 
min_max { ) => { typed_min_max!(lhs, rhs, DurationNanosecond, $OP) } + + ( + lhs @ ScalarValue::Struct(_), + rhs @ ScalarValue::Struct(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + ( + lhs @ ScalarValue::List(_), + rhs @ ScalarValue::List(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + + ( + lhs @ ScalarValue::LargeList(_), + rhs @ ScalarValue::LargeList(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + + ( + lhs @ ScalarValue::FixedSizeList(_), + rhs @ ScalarValue::FixedSizeList(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + e => { return internal_err!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", @@ -1098,7 +1206,9 @@ impl AggregateUDFImpl for Min { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(MinAccumulator::try_new(acc_args.return_type)?)) + Ok(Box::new(MinAccumulator::try_new( + acc_args.return_field.data_type(), + )?)) } fn aliases(&self) -> &[String] { @@ -1108,7 +1218,7 @@ impl AggregateUDFImpl for Min { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -1133,6 +1243,7 @@ impl AggregateUDFImpl for Min { | LargeBinary | BinaryView | Duration(_) + | Struct(_) ) } @@ -1142,7 +1253,7 @@ impl AggregateUDFImpl for Min { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = args.return_type; + let data_type = args.return_field.data_type(); match data_type { Int8 => primitive_min_accumulator!(data_type, i8, Int8Type), Int16 => primitive_min_accumulator!(data_type, i16, Int16Type), @@ -1208,7 +1319,9 @@ impl AggregateUDFImpl for Min { Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone()))) } - + Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_min( + data_type.clone(), + ))), // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), } @@ -1218,7 +1331,9 @@ impl AggregateUDFImpl for Min { &self, args: AccumulatorArgs, ) -> Result> { - Ok(Box::new(SlidingMinAccumulator::try_new(args.return_type)?)) + Ok(Box::new(SlidingMinAccumulator::try_new( + args.return_field.data_type(), + )?)) } fn is_descending(&self) -> Option { @@ -1627,8 +1742,11 @@ make_udaf_expr_and_func!( #[cfg(test)] mod tests { use super::*; - use arrow::datatypes::{ - IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, + use arrow::{ + array::DictionaryArray, + datatypes::{ + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, + }, }; use std::sync::Arc; @@ -1768,10 +1886,10 @@ mod tests { use rand::Rng; fn get_random_vec_i32(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut input = Vec::with_capacity(len); for _i in 0..len { - input.push(rng.gen_range(0..100)); + input.push(rng.random_range(0..100)); } input } @@ -1854,9 +1972,31 @@ mod tests { #[test] fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> { let data_type = - DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)); + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); let result = get_min_max_result_type(&[data_type])?; - assert_eq!(result, vec![DataType::Int32]); + assert_eq!(result, vec![DataType::Utf8]); + Ok(()) + } + + #[test] + fn test_min_max_dictionary() -> Result<()> { + let values = StringArray::from(vec!["b", "c", 
"a", "🦀", "d"]); + let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]); + let dict_array = + DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap(); + let dict_array_ref = Arc::new(dict_array) as ArrayRef; + let rt_type = + get_min_max_result_type(&[dict_array_ref.data_type().clone()])?[0].clone(); + + let mut min_acc = MinAccumulator::try_new(&rt_type)?; + min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; + let min_result = min_acc.evaluate()?; + assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string()))); + + let mut max_acc = MaxAccumulator::try_new(&rt_type)?; + max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; + let max_result = max_acc.evaluate()?; + assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string()))); Ok(()) } } diff --git a/datafusion/functions-aggregate/src/min_max/min_max_struct.rs b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs new file mode 100644 index 000000000000..8038f2f01d90 --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs @@ -0,0 +1,544 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{cmp::Ordering, sync::Arc}; + +use arrow::{ + array::{ + Array, ArrayData, ArrayRef, AsArray, BooleanArray, MutableArrayData, StructArray, + }, + datatypes::DataType, +}; +use datafusion_common::{ + internal_err, + scalar::{copy_array_data, partial_cmp_struct}, + Result, +}; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; + +/// Accumulator for MIN/MAX operations on Struct data types. +/// +/// This accumulator tracks the minimum or maximum struct value encountered +/// during aggregation, depending on the `is_min` flag. +/// +/// The comparison is done based on the struct fields in order. +pub(crate) struct MinMaxStructAccumulator { + /// Inner data storage. 
+ inner: MinMaxStructState, + /// if true, is `MIN` otherwise is `MAX` + is_min: bool, +} + +impl MinMaxStructAccumulator { + pub fn new_min(data_type: DataType) -> Self { + Self { + inner: MinMaxStructState::new(data_type), + is_min: true, + } + } + + pub fn new_max(data_type: DataType) -> Self { + Self { + inner: MinMaxStructState::new(data_type), + is_min: false, + } + } +} + +impl GroupsAccumulator for MinMaxStructAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = &values[0]; + assert_eq!(array.len(), group_indices.len()); + assert_eq!(array.data_type(), &self.inner.data_type); + // apply filter if needed + let array = apply_filter_as_nulls(array, opt_filter)?; + + fn struct_min(a: &StructArray, b: &StructArray) -> bool { + matches!(partial_cmp_struct(a, b), Some(Ordering::Less)) + } + + fn struct_max(a: &StructArray, b: &StructArray) -> bool { + matches!(partial_cmp_struct(a, b), Some(Ordering::Greater)) + } + + if self.is_min { + self.inner.update_batch( + array.as_struct(), + group_indices, + total_num_groups, + struct_min, + ) + } else { + self.inner.update_batch( + array.as_struct(), + group_indices, + total_num_groups, + struct_max, + ) + } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let (_, min_maxes) = self.inner.emit_to(emit_to); + let fields = match &self.inner.data_type { + DataType::Struct(fields) => fields, + _ => return internal_err!("Data type is not a struct"), + }; + let null_array = StructArray::new_null(fields.clone(), 1); + let min_maxes_data: Vec = min_maxes + .iter() + .map(|v| match v { + Some(v) => v.to_data(), + None => null_array.to_data(), + }) + .collect(); + let min_maxes_refs: Vec<&ArrayData> = min_maxes_data.iter().collect(); + let mut copy = MutableArrayData::new(min_maxes_refs, true, min_maxes_data.len()); + + for (i, item) in min_maxes_data.iter().enumerate() { + copy.extend(i, 0, item.len()); + } + let result = copy.freeze(); + assert_eq!(&self.inner.data_type, result.data_type()); + Ok(Arc::new(StructArray::from(result))) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + // min/max are their own states (no transition needed) + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // min/max are their own states (no transition needed) + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + // Min/max do not change the values as they are their own states + // apply the filter by combining with the null mask, if any + let output = apply_filter_as_nulls(&values[0], opt_filter)?; + Ok(vec![output]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +#[derive(Debug)] +struct MinMaxStructState { + /// The minimum/maximum value for each group + min_max: Vec>, + /// The data type of the array + data_type: DataType, + /// The total bytes of the string data (for pre-allocating the final array, + /// and tracking memory usage) + total_data_bytes: usize, +} + +#[derive(Debug, Clone)] +enum MinMaxLocation { + /// the min/max value is stored in the existing `min_max` array + ExistingMinMax, + /// the min/max value is stored in the 
input array at the given index + Input(StructArray), +} + +/// Implement the MinMaxStructState with a comparison function +/// for comparing structs +impl MinMaxStructState { + /// Create a new MinMaxStructState + /// + /// # Arguments: + /// * `data_type`: The data type of the arrays that will be passed to this accumulator + fn new(data_type: DataType) -> Self { + Self { + min_max: vec![], + data_type, + total_data_bytes: 0, + } + } + + /// Set the specified group to the given value, updating memory usage appropriately + fn set_value(&mut self, group_index: usize, new_val: &StructArray) { + let new_val = StructArray::from(copy_array_data(&new_val.to_data())); + match self.min_max[group_index].as_mut() { + None => { + self.total_data_bytes += new_val.get_array_memory_size(); + self.min_max[group_index] = Some(new_val); + } + Some(existing_val) => { + // Copy data over to avoid re-allocating + self.total_data_bytes -= existing_val.get_array_memory_size(); + self.total_data_bytes += new_val.get_array_memory_size(); + *existing_val = new_val; + } + } + } + + /// Updates the min/max values for the given string values + /// + /// `cmp` is the comparison function to use, called like `cmp(new_val, existing_val)` + /// returns true if the `new_val` should replace `existing_val` + fn update_batch( + &mut self, + array: &StructArray, + group_indices: &[usize], + total_num_groups: usize, + mut cmp: F, + ) -> Result<()> + where + F: FnMut(&StructArray, &StructArray) -> bool + Send + Sync, + { + self.min_max.resize(total_num_groups, None); + // Minimize value copies by calculating the new min/maxes for each group + // in this batch (either the existing min/max or the new input value) + // and updating the owned values in `self.min_maxes` at most once + let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; + + // Figure out the new min value for each group + for (index, group_index) in (0..array.len()).zip(group_indices.iter()) { + let group_index = *group_index; + if array.is_null(index) { + continue; + } + let new_val = array.slice(index, 1); + + let existing_val = match &locations[group_index] { + // previous input value was the min/max, so compare it + MinMaxLocation::Input(existing_val) => existing_val, + MinMaxLocation::ExistingMinMax => { + let Some(existing_val) = self.min_max[group_index].as_ref() else { + // no existing min/max, so this is the new min/max + locations[group_index] = MinMaxLocation::Input(new_val); + continue; + }; + existing_val + } + }; + + // Compare the new value to the existing value, replacing if necessary + if cmp(&new_val, existing_val) { + locations[group_index] = MinMaxLocation::Input(new_val); + } + } + + // Update self.min_max with any new min/max values we found in the input + for (group_index, location) in locations.iter().enumerate() { + match location { + MinMaxLocation::ExistingMinMax => {} + MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val), + } + } + Ok(()) + } + + /// Emits the specified min_max values + /// + /// Returns (data_capacity, min_maxes), updating the current value of total_data_bytes + /// + /// - `data_capacity`: the total length of all strings and their contents, + /// - `min_maxes`: the actual min/max values for each group + fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec>) { + match emit_to { + EmitTo::All => { + ( + std::mem::take(&mut self.total_data_bytes), // reset total bytes and min_max + std::mem::take(&mut self.min_max), + ) + } + EmitTo::First(n) => { + let first_min_maxes: Vec<_> = 
self.min_max.drain(..n).collect(); + let first_data_capacity: usize = first_min_maxes + .iter() + .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0)) + .sum(); + self.total_data_bytes -= first_data_capacity; + (first_data_capacity, first_min_maxes) + } + } + } + + fn size(&self) -> usize { + self.total_data_bytes + self.min_max.len() * size_of::>() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray, StructArray}; + use arrow::datatypes::{DataType, Field, Fields, Int32Type}; + use std::sync::Arc; + + fn create_test_struct_array( + int_values: Vec>, + str_values: Vec>, + ) -> StructArray { + let int_array = Int32Array::from(int_values); + let str_array = StringArray::from(str_values); + + let fields = vec![ + Field::new("int_field", DataType::Int32, true), + Field::new("str_field", DataType::Utf8, true), + ]; + + StructArray::new( + Fields::from(fields), + vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ], + None, + ) + } + + fn create_nested_struct_array( + int_values: Vec>, + str_values: Vec>, + ) -> StructArray { + let inner_struct = create_test_struct_array(int_values, str_values); + + let fields = vec![Field::new("inner", inner_struct.data_type().clone(), true)]; + + StructArray::new( + Fields::from(fields), + vec![Arc::new(inner_struct) as ArrayRef], + None, + ) + } + + #[test] + fn test_min_max_simple_struct() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_nested_struct() { + let array = create_nested_struct_array( + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let inner = 
min_result.column(0).as_struct(); + let int_array = inner.column(0).as_primitive::(); + let str_array = inner.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let inner = max_result.column(0).as_struct(); + let int_array = inner.column(0).as_primitive::(); + let str_array = inner.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_with_nulls() { + let array = create_test_struct_array( + vec![Some(1), None, Some(3)], + vec![Some("a"), None, Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_multiple_groups() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3), Some(4)], + vec![Some("a"), Some("b"), Some("c"), Some("d")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 1, 0, 1]; + + min_accumulator + .update_batch(&values, &group_indices, None, 2) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 2) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 2); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + assert_eq!(int_array.value(1), 2); + assert_eq!(str_array.value(1), "b"); + + assert_eq!(max_result.len(), 2); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + assert_eq!(int_array.value(1), 4); + assert_eq!(str_array.value(1), "d"); + } + + #[test] + fn test_min_max_with_filter() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3), Some(4)], + vec![Some("a"), Some("b"), Some("c"), Some("d")], + ); + + // Create a filter that only keeps even numbers + let filter = BooleanArray::from(vec![false, true, false, true]); + + let mut min_accumulator = + 
MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, Some(&filter), 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, Some(&filter), 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 2); + assert_eq!(str_array.value(0), "b"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 4); + assert_eq!(str_array.value(0), "d"); + } +} diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index d84bd02a6baf..1525b2f991a1 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -24,7 +24,7 @@ use std::mem::{size_of, size_of_val}; use std::sync::Arc; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; -use arrow::datatypes::{DataType, Field, Fields}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; @@ -164,11 +164,11 @@ impl AggregateUDFImpl for NthValueAgg { .map(|acc| Box::new(acc) as _) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new_list( format_state_name(self.name(), "nth_value"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), false, )]; let orderings = args.ordering_fields.to_vec(); @@ -179,7 +179,7 @@ impl AggregateUDFImpl for NthValueAgg { false, )); } - Ok(fields) + Ok(fields.into_iter().map(Arc::new).collect()) } fn aliases(&self) -> &[String] { @@ -400,7 +400,6 @@ impl Accumulator for NthValueAccumulator { impl NthValueAccumulator { fn evaluate_orderings(&self) -> Result { let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); - let struct_field = Fields::from(fields.clone()); let mut column_wise_ordering_values = vec![]; let num_columns = fields.len(); @@ -418,6 +417,7 @@ impl NthValueAccumulator { column_wise_ordering_values.push(array); } + let struct_field = Fields::from(fields); let ordering_array = StructArray::try_new(struct_field, column_wise_ordering_values, None)?; diff --git a/datafusion/functions-aggregate/src/planner.rs b/datafusion/functions-aggregate/src/planner.rs index c8cb84118995..f0e37f6b1dbe 100644 --- a/datafusion/functions-aggregate/src/planner.rs +++ b/datafusion/functions-aggregate/src/planner.rs @@ -100,7 +100,7 @@ impl ExprPlanner for AggregateFunctionPlanner { let new_expr = Expr::AggregateFunction(AggregateFunction::new_udf( func, - vec![Expr::Literal(COUNT_STAR_EXPANSION)], + vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], 
distinct, filter, order_by, diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index 82575d15e50b..0f84aa1323f5 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -18,6 +18,7 @@ //! Defines physical expressions that can evaluated at runtime during query execution use arrow::array::Float64Array; +use arrow::datatypes::FieldRef; use arrow::{ array::{ArrayRef, UInt64Array}, compute::cast, @@ -38,7 +39,7 @@ use datafusion_expr::{ use std::any::Any; use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; macro_rules! make_regr_udaf_expr_and_func { ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { @@ -278,7 +279,7 @@ impl AggregateUDFImpl for Regr { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -310,7 +311,10 @@ impl AggregateUDFImpl for Regr { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index adf86a128cfb..f948df840e73 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -23,8 +23,8 @@ use std::mem::align_of_val; use std::sync::Arc; use arrow::array::Float64Array; +use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; - use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -109,7 +109,7 @@ impl AggregateUDFImpl for Stddev { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -122,7 +122,10 @@ impl AggregateUDFImpl for Stddev { true, ), Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -217,7 +220,7 @@ impl AggregateUDFImpl for StddevPop { &self.signature } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -230,7 +233,10 @@ impl AggregateUDFImpl for StddevPop { true, ), Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -436,7 +442,7 @@ mod tests { schema: &Schema, ) -> Result { let args1 = AccumulatorArgs { - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), @@ -447,7 +453,7 @@ mod tests { }; let args2 = AccumulatorArgs { - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 
a7594b9ccb01..4682e574bfa2 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -19,8 +19,8 @@ use crate::array_agg::ArrayAgg; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::cast::as_generic_string_array; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::Result; use datafusion_common::{internal_err, not_impl_err, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; @@ -95,9 +95,15 @@ impl StringAgg { TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]), TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]), ], Volatility::Immutable, ), @@ -129,7 +135,7 @@ impl AggregateUDFImpl for StringAgg { Ok(DataType::LargeUtf8) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.array_agg.state_fields(args) } @@ -154,7 +160,12 @@ impl AggregateUDFImpl for StringAgg { }; let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs { - return_type: &DataType::new_list(acc_args.return_type.clone(), true), + return_field: Field::new( + "f", + DataType::new_list(acc_args.return_field.data_type().clone(), true), + true, + ) + .into(), exprs: &filter_index(acc_args.exprs, 1), ..acc_args })?; @@ -206,6 +217,10 @@ impl Accumulator for StringAggAccumulator { .iter() .flatten() .collect(), + DataType::Utf8View => as_string_view_array(list.values())? + .iter() + .flatten() + .collect(), _ => { return internal_err!( "Expected elements to of type Utf8 or LargeUtf8, but got {}", @@ -436,7 +451,7 @@ mod tests { fn build(&self) -> Result> { StringAgg::new().accumulator(AccumulatorArgs { - return_type: &DataType::LargeUtf8, + return_field: Field::new("f", DataType::LargeUtf8, true).into(), schema: &self.schema, ignore_nulls: false, ordering_req: &self.ordering, diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 76a1315c2d88..37d208ffb03a 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -26,8 +26,8 @@ use std::mem::{size_of, size_of_val}; use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; use arrow::array::{ArrowNumericType, AsArray}; -use arrow::datatypes::ArrowNativeType; use arrow::datatypes::ArrowPrimitiveType; +use arrow::datatypes::{ArrowNativeType, FieldRef}; use arrow::datatypes::{ DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, @@ -63,17 +63,27 @@ make_udaf_expr_and_func!( /// `helper` is a macro accepting (ArrowPrimitiveType, DataType) macro_rules! 
downcast_sum { ($args:ident, $helper:ident) => { - match $args.return_type { - DataType::UInt64 => $helper!(UInt64Type, $args.return_type), - DataType::Int64 => $helper!(Int64Type, $args.return_type), - DataType::Float64 => $helper!(Float64Type, $args.return_type), - DataType::Decimal128(_, _) => $helper!(Decimal128Type, $args.return_type), - DataType::Decimal256(_, _) => $helper!(Decimal256Type, $args.return_type), + match $args.return_field.data_type().clone() { + DataType::UInt64 => { + $helper!(UInt64Type, $args.return_field.data_type().clone()) + } + DataType::Int64 => { + $helper!(Int64Type, $args.return_field.data_type().clone()) + } + DataType::Float64 => { + $helper!(Float64Type, $args.return_field.data_type().clone()) + } + DataType::Decimal128(_, _) => { + $helper!(Decimal128Type, $args.return_field.data_type().clone()) + } + DataType::Decimal256(_, _) => { + $helper!(Decimal256Type, $args.return_field.data_type().clone()) + } _ => { not_impl_err!( "Sum not supported for {}: {}", $args.name, - $args.return_type + $args.return_field.data_type() ) } } @@ -191,20 +201,22 @@ impl AggregateUDFImpl for Sum { } } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { Ok(vec![Field::new_list( format_state_name(args.name, "sum distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_type.clone(), true), + Field::new_list_field(args.return_type().clone(), true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( format_state_name(args.name, "sum"), - args.return_type.clone(), + args.return_type().clone(), true, - )]) + ) + .into()]) } } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 53e3e0cc56cd..586b2dab0ae6 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -18,15 +18,13 @@ //! [`VarianceSample`]: variance sample aggregations. //! [`VariancePopulation`]: variance population aggregations. 
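Throughout this patch `state_fields` now returns `Vec<FieldRef>` instead of `Vec<Field>`. Since `FieldRef` is arrow's alias for `Arc<Field>`, the existing field constructors only need an `.into()` per field, or a single `map(Arc::new)` over the vector, exactly as the hunks above do. A minimal sketch of the pattern; the state names below are illustrative only:

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, FieldRef};

// FieldRef = Arc<Field>, so plain Fields convert with `.into()` or Arc::new.
fn example_state_fields(name: &str) -> Vec<FieldRef> {
    vec![
        Field::new(format!("{name}[count]"), DataType::UInt64, true),
        Field::new(format!("{name}[mean]"), DataType::Float64, true),
        Field::new(format!("{name}[m2]"), DataType::Float64, true),
    ]
    .into_iter()
    .map(Arc::new)
    .collect()
}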
+use arrow::datatypes::FieldRef; use arrow::{ array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, buffer::NullBuffer, compute::kernels::cast, datatypes::{DataType, Field}, }; -use std::mem::{size_of, size_of_val}; -use std::{fmt::Debug, sync::Arc}; - use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, @@ -38,6 +36,8 @@ use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, }; use datafusion_macros::user_doc; +use std::mem::{size_of, size_of_val}; +use std::{fmt::Debug, sync::Arc}; make_udaf_expr_and_func!( VarianceSample, @@ -107,13 +107,16 @@ impl AggregateUDFImpl for VarianceSample { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -200,13 +203,16 @@ impl AggregateUDFImpl for VariancePopulation { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 2774b24b902a..55dd7ad14460 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -36,7 +36,7 @@ fn keys(rng: &mut ThreadRng) -> Vec { let mut keys = HashSet::with_capacity(1000); while keys.len() < 1000 { - keys.insert(rng.gen_range(0..10000).to_string()); + keys.insert(rng.random_range(0..10000).to_string()); } keys.into_iter().collect() @@ -46,20 +46,23 @@ fn values(rng: &mut ThreadRng) -> Vec { let mut values = HashSet::with_capacity(1000); while values.len() < 1000 { - values.insert(rng.gen_range(0..10000)); + values.insert(rng.random_range(0..10000)); } values.into_iter().collect() } fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_map_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let keys = keys(&mut rng); let values = values(&mut rng); let mut buffer = Vec::new(); for i in 0..1000 { - buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + buffer.push(Expr::Literal( + ScalarValue::Utf8(Some(keys[i].clone())), + None, + )); + buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])), None)); } let planner = NestedFunctionPlanner {}; @@ -74,7 +77,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("map_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let field = Arc::new(Field::new_list_field(DataType::Utf8, true)); let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); let key_list = ListArray::new( @@ -94,17 +97,23 
@@ fn criterion_benchmark(c: &mut Criterion) { let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); - let return_type = &map_udf() + let return_type = map_udf() .return_type(&[DataType::Utf8, DataType::Int32]) .expect("should get return type"); + let arg_fields = vec![ + Field::new("a", keys.data_type(), true).into(), + Field::new("a", values.data_type(), true).into(), + ]; + let return_field = Field::new("f", return_type, true).into(); b.iter(|| { black_box( map_udf() .invoke_with_args(ScalarFunctionArgs { args: vec![keys.clone(), values.clone()], + arg_fields: arg_fields.clone(), number_rows: 1, - return_type, + return_field: Arc::clone(&return_field), }) .expect("map should work on valid values"), ); diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 5ef1491313b1..3b9b705e72c5 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -133,7 +133,7 @@ impl ScalarUDFImpl for ArrayHas { // if the haystack is a constant list, we can use an inlist expression which is more // efficient because the haystack is not varying per-row - if let Expr::Literal(ScalarValue::List(array)) = haystack { + if let Expr::Literal(ScalarValue::List(array), _) = haystack { // TODO: support LargeList // (not supported by `convert_array_to_scalar_vec`) // (FixedSizeList not supported either, but seems to have worked fine when attempting to @@ -147,7 +147,7 @@ impl ScalarUDFImpl for ArrayHas { let list = scalar_values .into_iter() .flatten() - .map(Expr::Literal) + .map(|v| Expr::Literal(v, None)) .collect(); return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList { diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index f2f23841586c..98bda81ef25f 100644 --- a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -23,12 +23,12 @@ use arrow::array::{ }; use arrow::datatypes::{ DataType, - DataType::{FixedSizeList, LargeList, List, Map, UInt64}, + DataType::{LargeList, List, Map, Null, UInt64}, }; use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array}; -use datafusion_common::utils::take_function_args; +use datafusion_common::exec_err; +use datafusion_common::utils::{take_function_args, ListCoercion}; use datafusion_common::Result; -use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -52,7 +52,7 @@ impl Cardinality { vec![ TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ArrayFunctionArgument::Array], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), ], @@ -103,13 +103,8 @@ impl ScalarUDFImpl for Cardinality { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) | Map(_, _) => UInt64, - _ => { - return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList/Map."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(UInt64) } fn invoke_with_args( @@ -131,21 +126,22 @@ impl ScalarUDFImpl for Cardinality { /// Cardinality SQL function pub fn 
cardinality_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("cardinality", args)?; - match &array.data_type() { + match array.data_type() { + Null => Ok(Arc::new(UInt64Array::from_value(0, array.len()))), List(_) => { - let list_array = as_list_array(&array)?; + let list_array = as_list_array(array)?; generic_list_cardinality::(list_array) } LargeList(_) => { - let list_array = as_large_list_array(&array)?; + let list_array = as_large_list_array(array)?; generic_list_cardinality::(list_array) } Map(_, _) => { - let map_array = as_map_array(&array)?; + let map_array = as_map_array(array)?; generic_map_cardinality(map_array) } - other => { - exec_err!("cardinality does not support type '{:?}'", other) + arg_type => { + exec_err!("cardinality does not support type {arg_type}") } } } diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index f4b9208e5c83..dd8784d36c48 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -17,30 +17,32 @@ //! [`ScalarUDFImpl`] definitions for `array_append`, `array_prepend` and `array_concat` functions. +use std::any::Any; use std::sync::Arc; -use std::{any::Any, cmp::Ordering}; +use crate::make_array::make_array_inner; +use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; use arrow::array::{ - Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, NullBufferBuilder, - OffsetSizeTrait, + Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, NullArray, + NullBufferBuilder, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; -use datafusion_common::utils::ListCoercion; +use datafusion_common::utils::{ + base_type, coerced_type_with_base_type_only, ListCoercion, +}; use datafusion_common::Result; use datafusion_common::{ cast::as_generic_list_array, - exec_err, not_impl_err, plan_err, + exec_err, plan_err, utils::{list_ndims, take_function_args}, }; +use datafusion_expr::binary::type_union_resolution; use datafusion_expr::{ - ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, - ScalarUDFImpl, Signature, TypeSignature, Volatility, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; -use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; - make_udf_expr_and_func!( ArrayAppend, array_append, @@ -106,7 +108,12 @@ impl ScalarUDFImpl for ArrayAppend { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + let [array_type, element_type] = take_function_args(self.name(), arg_types)?; + if array_type.is_null() { + Ok(DataType::new_list(element_type.clone(), true)) + } else { + Ok(array_type.clone()) + } } fn invoke_with_args( @@ -166,18 +173,7 @@ impl Default for ArrayPrepend { impl ArrayPrepend { pub fn new() -> Self { Self { - signature: Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::Array { - arguments: vec![ - ArrayFunctionArgument::Element, - ArrayFunctionArgument::Array, - ], - array_coercion: Some(ListCoercion::FixedSizedListToList), - }, - ), - volatility: Volatility::Immutable, - }, + signature: Signature::element_and_array(Volatility::Immutable), aliases: vec![ String::from("list_prepend"), String::from("array_push_front"), @@ -201,7 +197,12 @@ impl ScalarUDFImpl for ArrayPrepend { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[1].clone()) + let 
[element_type, array_type] = take_function_args(self.name(), arg_types)?; + if array_type.is_null() { + Ok(DataType::new_list(element_type.clone(), true)) + } else { + Ok(array_type.clone()) + } } fn invoke_with_args( @@ -263,7 +264,7 @@ impl Default for ArrayConcat { impl ArrayConcat { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), aliases: vec![ String::from("array_cat"), String::from("list_concat"), @@ -287,39 +288,40 @@ impl ScalarUDFImpl for ArrayConcat { } fn return_type(&self, arg_types: &[DataType]) -> Result { - let mut expr_type = DataType::Null; let mut max_dims = 0; + let mut large_list = false; + let mut element_types = Vec::with_capacity(arg_types.len()); for arg_type in arg_types { - let DataType::List(field) = arg_type else { - return plan_err!( - "The array_concat function can only accept list as the args." - ); - }; - if !field.data_type().equals_datatype(&DataType::Null) { - let dims = list_ndims(arg_type); - expr_type = match max_dims.cmp(&dims) { - Ordering::Greater => expr_type, - Ordering::Equal => { - if expr_type == DataType::Null { - arg_type.clone() - } else if !expr_type.equals_datatype(arg_type) { - return plan_err!( - "It is not possible to concatenate arrays of different types. Expected: {}, got: {}", expr_type, arg_type - ); - } else { - expr_type - } - } - - Ordering::Less => { - max_dims = dims; - arg_type.clone() - } - }; + match arg_type { + DataType::Null | DataType::List(_) | DataType::FixedSizeList(..) => (), + DataType::LargeList(_) => large_list = true, + arg_type => { + return plan_err!("{} does not support type {arg_type}", self.name()) + } } + + max_dims = max_dims.max(list_ndims(arg_type)); + element_types.push(base_type(arg_type)) } - Ok(expr_type) + if max_dims == 0 { + Ok(DataType::Null) + } else if let Some(mut return_type) = type_union_resolution(&element_types) { + for _ in 1..max_dims { + return_type = DataType::new_list(return_type, true) + } + + if large_list { + Ok(DataType::new_large_list(return_type, true)) + } else { + Ok(DataType::new_list(return_type, true)) + } + } else { + plan_err!( + "Failed to unify argument types of {}: {arg_types:?}", + self.name() + ) + } } fn invoke_with_args( @@ -333,6 +335,16 @@ impl ScalarUDFImpl for ArrayConcat { &self.aliases } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let base_type = base_type(&self.return_type(arg_types)?); + let coercion = Some(&ListCoercion::FixedSizedListToList); + let arg_types = arg_types.iter().map(|arg_type| { + coerced_type_with_base_type_only(arg_type, &base_type, coercion) + }); + + Ok(arg_types.collect()) + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -341,24 +353,27 @@ impl ScalarUDFImpl for ArrayConcat { /// Array_concat/Array_cat SQL function pub(crate) fn array_concat_inner(args: &[ArrayRef]) -> Result { if args.is_empty() { - return exec_err!("array_concat expects at least one arguments"); + return exec_err!("array_concat expects at least one argument"); } - let mut new_args = vec![]; + let mut all_null = true; + let mut large_list = false; for arg in args { - let ndim = list_ndims(arg.data_type()); - let base_type = datafusion_common::utils::base_type(arg.data_type()); - if ndim == 0 { - return not_impl_err!("Array is not type '{base_type:?}'."); - } - if !base_type.eq(&DataType::Null) { - new_args.push(Arc::clone(arg)); + match arg.data_type() { + DataType::Null => continue, + DataType::LargeList(_) => large_list 
= true, + _ => (), } + + all_null = false } - match &args[0].data_type() { - DataType::LargeList(_) => concat_internal::(new_args.as_slice()), - _ => concat_internal::(new_args.as_slice()), + if all_null { + Ok(Arc::new(NullArray::new(args[0].len()))) + } else if large_list { + concat_internal::(args) + } else { + concat_internal::(args) } } @@ -427,21 +442,23 @@ fn concat_internal(args: &[ArrayRef]) -> Result { /// Array_append SQL function pub(crate) fn array_append_inner(args: &[ArrayRef]) -> Result { - let [array, _] = take_function_args("array_append", args)?; - + let [array, values] = take_function_args("array_append", args)?; match array.data_type() { + DataType::Null => make_array_inner(&[Arc::clone(values)]), + DataType::List(_) => general_append_and_prepend::(args, true), DataType::LargeList(_) => general_append_and_prepend::(args, true), - _ => general_append_and_prepend::(args, true), + arg_type => exec_err!("array_append does not support type {arg_type}"), } } /// Array_prepend SQL function pub(crate) fn array_prepend_inner(args: &[ArrayRef]) -> Result { - let [_, array] = take_function_args("array_prepend", args)?; - + let [values, array] = take_function_args("array_prepend", args)?; match array.data_type() { + DataType::Null => make_array_inner(&[Arc::clone(values)]), + DataType::List(_) => general_append_and_prepend::(args, false), DataType::LargeList(_) => general_append_and_prepend::(args, false), - _ => general_append_and_prepend::(args, false), + arg_type => exec_err!("array_prepend does not support type {arg_type}"), } } diff --git a/datafusion/functions-nested/src/dimension.rs b/datafusion/functions-nested/src/dimension.rs index a7d033641413..d1e6b1be4cfa 100644 --- a/datafusion/functions-nested/src/dimension.rs +++ b/datafusion/functions-nested/src/dimension.rs @@ -17,24 +17,26 @@ //! [`ScalarUDFImpl`] definitions for array_dims and array_ndims functions. 
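A sketch (not part of the patch) of the return-type rule that the rewritten `ArrayConcat::return_type` above applies for `max_dims >= 1` (the patch returns `Null` when every argument is zero-dimensional): unify the base element types of all arguments, then re-wrap the unified type in `max_dims` levels of list, using a large list if any argument was one.

use arrow::datatypes::DataType;

fn rewrap(element: DataType, max_dims: usize, large: bool) -> DataType {
    let mut ty = element;
    // one level is added by the final wrap below, hence `1..max_dims`
    for _ in 1..max_dims {
        ty = DataType::new_list(ty, true);
    }
    if large {
        DataType::new_large_list(ty, true)
    } else {
        DataType::new_list(ty, true)
    }
}

// e.g. concatenating List(Int64) with List(Float64) unifies the elements
// (here to Float64) and returns List(Float64); two List(List(Int64)) inputs
// return List(List(Int64)).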
-use arrow::array::{ - Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array, -}; +use arrow::array::{Array, ArrayRef, ListArray, UInt64Array}; use arrow::datatypes::{ DataType, - DataType::{FixedSizeList, LargeList, List, UInt64}, - Field, UInt64Type, + DataType::{FixedSizeList, LargeList, List, Null, UInt64}, + UInt64Type, }; use std::any::Any; -use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, +}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use crate::utils::{compute_array_dims, make_scalar_function}; +use datafusion_common::utils::list_ndims; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; +use itertools::Itertools; use std::sync::Arc; make_udf_expr_and_func!( @@ -77,7 +79,7 @@ impl Default for ArrayDims { impl ArrayDims { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature::arrays(1, None, Volatility::Immutable), aliases: vec!["list_dims".to_string()], } } @@ -95,15 +97,8 @@ impl ScalarUDFImpl for ArrayDims { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => { - List(Arc::new(Field::new_list_field(UInt64, true))) - } - _ => { - return plan_err!("The array_dims function can only accept List/LargeList/FixedSizeList."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::new_list(UInt64, true)) } fn invoke_with_args( @@ -156,7 +151,7 @@ pub(super) struct ArrayNdims { impl ArrayNdims { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature::arrays(1, None, Volatility::Immutable), aliases: vec![String::from("list_ndims")], } } @@ -174,13 +169,8 @@ impl ScalarUDFImpl for ArrayNdims { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The array_ndims function can only accept List/LargeList/FixedSizeList."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(UInt64) } fn invoke_with_args( @@ -202,61 +192,42 @@ impl ScalarUDFImpl for ArrayNdims { /// Array_dims SQL function pub fn array_dims_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("array_dims", args)?; - - let data = match array.data_type() { - List(_) => { - let array = as_list_array(&array)?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? - } - LargeList(_) => { - let array = as_large_list_array(&array)?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? - } - array_type => { - return exec_err!("array_dims does not support type '{array_type:?}'"); + let data: Vec<_> = match array.data_type() { + List(_) => as_list_array(&array)? + .iter() + .map(compute_array_dims) + .try_collect()?, + LargeList(_) => as_large_list_array(&array)? + .iter() + .map(compute_array_dims) + .try_collect()?, + FixedSizeList(..) => as_fixed_size_list_array(&array)? 
+ .iter() + .map(compute_array_dims) + .try_collect()?, + arg_type => { + return exec_err!("array_dims does not support type {arg_type}"); } }; let result = ListArray::from_iter_primitive::(data); - - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(result)) } /// Array_ndims SQL function pub fn array_ndims_inner(args: &[ArrayRef]) -> Result { - let [array_dim] = take_function_args("array_ndims", args)?; + let [array] = take_function_args("array_ndims", args)?; - fn general_list_ndims( - array: &GenericListArray, - ) -> Result { - let mut data = Vec::new(); - let ndims = datafusion_common::utils::list_ndims(array.data_type()); - - for arr in array.iter() { - if arr.is_some() { - data.push(Some(ndims)) - } else { - data.push(None) - } - } - - Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) + fn general_list_ndims(array: &ArrayRef) -> Result { + let ndims = list_ndims(array.data_type()); + let data = vec![ndims; array.len()]; + let result = UInt64Array::new(data.into(), array.nulls().cloned()); + Ok(Arc::new(result)) } - match array_dim.data_type() { - List(_) => { - let array = as_list_array(&array_dim)?; - general_list_ndims::(array) - } - LargeList(_) => { - let array = as_large_list_array(&array_dim)?; - general_list_ndims::(array) - } - array_type => exec_err!("array_ndims does not support type {array_type:?}"), + + match array.data_type() { + Null => Ok(Arc::new(UInt64Array::new_null(array.len()))), + List(_) | LargeList(_) | FixedSizeList(..) => general_list_ndims(array), + arg_type => exec_err!("array_ndims does not support type {arg_type}"), } } diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs index cfc7fccdd70c..3392e194b176 100644 --- a/datafusion/functions-nested/src/distance.rs +++ b/datafusion/functions-nested/src/distance.rs @@ -23,21 +23,22 @@ use arrow::array::{ }; use arrow::datatypes::{ DataType, - DataType::{FixedSizeList, Float64, LargeList, List}, + DataType::{FixedSizeList, LargeList, List, Null}, }; use datafusion_common::cast::{ as_float32_array, as_float64_array, as_generic_list_array, as_int32_array, as_int64_array, }; -use datafusion_common::utils::coerced_fixed_size_list_to_list; +use datafusion_common::utils::{coerced_type_with_base_type_only, ListCoercion}; use datafusion_common::{ - exec_err, internal_datafusion_err, utils::take_function_args, Result, + exec_err, internal_datafusion_err, plan_err, utils::take_function_args, Result, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::{downcast_arg, downcast_named_arg}; use datafusion_macros::user_doc; +use itertools::Itertools; use std::any::Any; use std::sync::Arc; @@ -104,24 +105,26 @@ impl ScalarUDFImpl for ArrayDistance { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Ok(Float64), - _ => exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."), - } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { let [_, _] = take_function_args(self.name(), arg_types)?; - let mut result = Vec::new(); - for arg_type in arg_types { - match arg_type { - List(_) | LargeList(_) | FixedSizeList(_, _) => result.push(coerced_fixed_size_list_to_list(arg_type)), - _ => return exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."), + let coercion = 
Some(&ListCoercion::FixedSizedListToList); + let arg_types = arg_types.iter().map(|arg_type| { + if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + Ok(coerced_type_with_base_type_only( + arg_type, + &DataType::Float64, + coercion, + )) + } else { + plan_err!("{} does not support type {arg_type}", self.name()) } - } + }); - Ok(result) + arg_types.try_collect() } fn invoke_with_args( @@ -142,12 +145,11 @@ impl ScalarUDFImpl for ArrayDistance { pub fn array_distance_inner(args: &[ArrayRef]) -> Result { let [array1, array2] = take_function_args("array_distance", args)?; - - match (&array1.data_type(), &array2.data_type()) { + match (array1.data_type(), array2.data_type()) { (List(_), List(_)) => general_array_distance::(args), (LargeList(_), LargeList(_)) => general_array_distance::(args), - (array_type1, array_type2) => { - exec_err!("array_distance does not support types '{array_type1:?}' and '{array_type2:?}'") + (arg_type1, arg_type2) => { + exec_err!("array_distance does not support types {arg_type1} and {arg_type2}") } } } @@ -243,7 +245,7 @@ fn compute_array_distance( /// Converts an array of any numeric type to a Float64Array. fn convert_to_f64_array(array: &ArrayRef) -> Result { match array.data_type() { - Float64 => Ok(as_float64_array(array)?.clone()), + DataType::Float64 => Ok(as_float64_array(array)?.clone()), DataType::Float32 => { let array = as_float32_array(array)?; let converted: Float64Array = diff --git a/datafusion/functions-nested/src/empty.rs b/datafusion/functions-nested/src/empty.rs index dcefd583e937..67c795886bde 100644 --- a/datafusion/functions-nested/src/empty.rs +++ b/datafusion/functions-nested/src/empty.rs @@ -18,13 +18,14 @@ //! [`ScalarUDFImpl`] definitions for array_empty function. use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, BooleanArray, OffsetSizeTrait}; +use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait}; +use arrow::buffer::BooleanBuffer; use arrow::datatypes::{ DataType, DataType::{Boolean, FixedSizeList, LargeList, List}, }; use datafusion_common::cast::as_generic_list_array; -use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -71,7 +72,7 @@ impl Default for ArrayEmpty { impl ArrayEmpty { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature::arrays(1, None, Volatility::Immutable), aliases: vec!["array_empty".to_string(), "list_empty".to_string()], } } @@ -89,13 +90,8 @@ impl ScalarUDFImpl for ArrayEmpty { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Boolean, - _ => { - return plan_err!("The array_empty function can only accept List/LargeList/FixedSizeList."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Boolean) } fn invoke_with_args( @@ -117,21 +113,25 @@ impl ScalarUDFImpl for ArrayEmpty { /// Array_empty SQL function pub fn array_empty_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("array_empty", args)?; - - let array_type = array.data_type(); - match array_type { + match array.data_type() { List(_) => general_array_empty::(array), LargeList(_) => general_array_empty::(array), - _ => exec_err!("array_empty does not support type '{array_type:?}'."), + FixedSizeList(_, size) => { 
+ let values = if *size == 0 { + BooleanBuffer::new_set(array.len()) + } else { + BooleanBuffer::new_unset(array.len()) + }; + Ok(Arc::new(BooleanArray::new(values, array.nulls().cloned()))) + } + arg_type => exec_err!("array_empty does not support type {arg_type}"), } } fn general_array_empty(array: &ArrayRef) -> Result { - let array = as_generic_list_array::(array)?; - - let builder = array + let result = as_generic_list_array::(array)? .iter() .map(|arr| arr.map(|arr| arr.is_empty())) .collect::(); - Ok(Arc::new(builder)) + Ok(Arc::new(result)) } diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 321dda55ce09..95bf5a7341d9 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -19,12 +19,12 @@ use arrow::array::{ Array, ArrayRef, ArrowNativeTypeOp, Capacities, GenericListArray, Int64Array, - MutableArrayData, NullBufferBuilder, OffsetSizeTrait, + MutableArrayData, NullArray, NullBufferBuilder, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; use arrow::datatypes::{ - DataType::{FixedSizeList, LargeList, List}, + DataType::{FixedSizeList, LargeList, List, Null}, Field, }; use datafusion_common::cast::as_int64_array; @@ -163,13 +163,9 @@ impl ScalarUDFImpl for ArrayElement { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) - | LargeList(field) - | FixedSizeList(field, _) => Ok(field.data_type().clone()), - DataType::Null => Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true)))), - _ => plan_err!( - "ArrayElement can only accept List, LargeList or FixedSizeList as the first argument" - ), + Null => Ok(Null), + List(field) | LargeList(field) => Ok(field.data_type().clone()), + arg_type => plan_err!("{} does not support type {arg_type}", self.name()), } } @@ -200,6 +196,7 @@ fn array_element_inner(args: &[ArrayRef]) -> Result { let [array, indexes] = take_function_args("array_element", args)?; match &array.data_type() { + Null => Ok(Arc::new(NullArray::new(array.len()))), List(_) => { let array = as_list_array(&array)?; let indexes = as_int64_array(&indexes)?; @@ -210,10 +207,9 @@ fn array_element_inner(args: &[ArrayRef]) -> Result { let indexes = as_int64_array(&indexes)?; general_array_element::(array, indexes) } - _ => exec_err!( - "array_element does not support type: {:?}", - array.data_type() - ), + arg_type => { + exec_err!("array_element does not support type {arg_type}") + } } } @@ -225,6 +221,10 @@ where i64: TryInto, { let values = array.values(); + if values.data_type().is_null() { + return Ok(Arc::new(NullArray::new(array.len()))); + } + let original_data = values.to_data(); let capacity = Capacities::Array(original_data.len()); @@ -238,8 +238,7 @@ where { let index: O = index.try_into().map_err(|_| { DataFusionError::Execution(format!( - "array_element got invalid index: {}", - index + "array_element got invalid index: {index}" )) })?; // 0 ~ len - 1 diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index f288035948dc..c6fa2831f4f0 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -18,19 +18,18 @@ //! [`ScalarUDFImpl`] definitions for flatten function. 
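The `flatten_inner` rewrite that follows removes the recursive helper and flattens exactly one level by composing offsets: each outer offset is looked up in the inner offsets while the flat values are left untouched. A small illustration of that composition, outside this patch:

// outer = [0, 2, 3]: row 0 holds inner lists 0..2, row 1 holds inner list 2..3
// inner = [0, 2, 3, 4]: inner lists over the flat values [a, b, c, d]
// composed = [0, 3, 4]: after one flatten, row 0 spans values 0..3, row 1 spans 3..4
fn compose_offsets(outer: &[i32], inner: &[i32]) -> Vec<i32> {
    outer.iter().map(|&o| inner[o as usize]).collect()
}

fn main() {
    assert_eq!(compose_offsets(&[0, 2, 3], &[0, 2, 3, 4]), vec![0, 3, 4]);
}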
use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{ DataType, DataType::{FixedSizeList, LargeList, List, Null}, }; -use datafusion_common::cast::{ - as_generic_list_array, as_large_list_array, as_list_array, -}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::utils::ListCoercion; use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -77,9 +76,11 @@ impl Flatten { pub fn new() -> Self { Self { signature: Signature { - // TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::RecursiveArray, + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, ), volatility: Volatility::Immutable, }, @@ -102,25 +103,23 @@ impl ScalarUDFImpl for Flatten { } fn return_type(&self, arg_types: &[DataType]) -> Result { - fn get_base_type(data_type: &DataType) -> Result { - match data_type { - List(field) | FixedSizeList(field, _) - if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) => - { - get_base_type(field.data_type()) - } - LargeList(field) if matches!(field.data_type(), LargeList(_)) => { - get_base_type(field.data_type()) + let data_type = match &arg_types[0] { + List(field) | FixedSizeList(field, _) => match field.data_type() { + List(field) | FixedSizeList(field, _) => List(Arc::clone(field)), + _ => arg_types[0].clone(), + }, + LargeList(field) => match field.data_type() { + List(field) | LargeList(field) | FixedSizeList(field, _) => { + LargeList(Arc::clone(field)) } - Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), - FixedSizeList(field, _) => Ok(List(Arc::clone(field))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } + _ => arg_types[0].clone(), + }, + Null => Null, + _ => exec_err!( + "Not reachable, data_type should be List, LargeList or FixedSizeList" + )?, + }; - let data_type = get_base_type(&arg_types[0])?; Ok(data_type) } @@ -146,14 +145,64 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result { match array.data_type() { List(_) => { - let list_arr = as_list_array(&array)?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) + let (_field, offsets, values, nulls) = + as_list_array(&array)?.clone().into_parts(); + let values = cast_fsl_to_list(values)?; + + match values.data_type() { + List(_) => { + let (inner_field, inner_offsets, inner_values, _) = + as_list_array(&values)?.clone().into_parts(); + let offsets = get_offsets_for_flatten::(inner_offsets, offsets); + let flattened_array = GenericListArray::::new( + inner_field, + offsets, + inner_values, + nulls, + ); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + LargeList(_) => { + exec_err!("flatten does not support type '{:?}'", array.data_type())? 
+ } + _ => Ok(Arc::clone(array) as ArrayRef), + } } LargeList(_) => { - let list_arr = as_large_list_array(&array)?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) + let (_field, offsets, values, nulls) = + as_large_list_array(&array)?.clone().into_parts(); + let values = cast_fsl_to_list(values)?; + + match values.data_type() { + List(_) => { + let (inner_field, inner_offsets, inner_values, _) = + as_list_array(&values)?.clone().into_parts(); + let offsets = get_large_offsets_for_flatten(inner_offsets, offsets); + let flattened_array = GenericListArray::::new( + inner_field, + offsets, + inner_values, + nulls, + ); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + LargeList(_) => { + let (inner_field, inner_offsets, inner_values, nulls) = + as_large_list_array(&values)?.clone().into_parts(); + let offsets = get_offsets_for_flatten::(inner_offsets, offsets); + let flattened_array = GenericListArray::::new( + inner_field, + offsets, + inner_values, + nulls, + ); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + _ => Ok(Arc::clone(array) as ArrayRef), + } } Null => Ok(Arc::clone(array)), _ => { @@ -162,37 +211,6 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result { } } -fn flatten_internal( - list_arr: GenericListArray, - indexes: Option>, -) -> Result> { - let (field, offsets, values, _) = list_arr.clone().into_parts(); - let data_type = field.data_type(); - - match data_type { - // Recursively get the base offsets for flattened array - List(_) | LargeList(_) => { - let sub_list = as_generic_list_array::(&values)?; - if let Some(indexes) = indexes { - let offsets = get_offsets_for_flatten(offsets, indexes); - flatten_internal::(sub_list.clone(), Some(offsets)) - } else { - flatten_internal::(sub_list.clone(), Some(offsets)) - } - } - // Reach the base level, create a new list array - _ => { - if let Some(indexes) = indexes { - let offsets = get_offsets_for_flatten(offsets, indexes); - let list_arr = GenericListArray::::new(field, offsets, values, None); - Ok(list_arr) - } else { - Ok(list_arr) - } - } - } -} - // Create new offsets that are equivalent to `flatten` the array. fn get_offsets_for_flatten( offsets: OffsetBuffer, @@ -205,3 +223,25 @@ fn get_offsets_for_flatten( .collect(); OffsetBuffer::new(offsets.into()) } + +// Create new large offsets that are equivalent to `flatten` the array. +fn get_large_offsets_for_flatten( + offsets: OffsetBuffer, + indexes: OffsetBuffer


, +) -> OffsetBuffer { + let buffer = offsets.into_inner(); + let offsets: Vec = indexes + .iter() + .map(|i| buffer[i.to_usize().unwrap()].to_i64().unwrap()) + .collect(); + OffsetBuffer::new(offsets.into()) +} + +fn cast_fsl_to_list(array: ArrayRef) -> Result { + match array.data_type() { + FixedSizeList(field, _) => { + Ok(arrow::compute::cast(&array, &List(Arc::clone(field)))?) + } + _ => Ok(array), + } +} diff --git a/datafusion/functions-nested/src/length.rs b/datafusion/functions-nested/src/length.rs index 3c3a42da0d69..0da12684158e 100644 --- a/datafusion/functions-nested/src/length.rs +++ b/datafusion/functions-nested/src/length.rs @@ -19,13 +19,16 @@ use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, Int64Array, LargeListArray, ListArray, OffsetSizeTrait, UInt64Array, + Array, ArrayRef, FixedSizeListArray, Int64Array, LargeListArray, ListArray, + OffsetSizeTrait, UInt64Array, }; use arrow::datatypes::{ DataType, DataType::{FixedSizeList, LargeList, List, UInt64}, }; -use datafusion_common::cast::{as_generic_list_array, as_int64_array}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_generic_list_array, as_int64_array, +}; use datafusion_common::{exec_err, internal_datafusion_err, plan_err, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -119,6 +122,23 @@ impl ScalarUDFImpl for ArrayLength { } } +macro_rules! array_length_impl { + ($array:expr, $dimension:expr) => {{ + let array = $array; + let dimension = match $dimension { + Some(d) => as_int64_array(d)?.clone(), + None => Int64Array::from_value(1, array.len()), + }; + let result = array + .iter() + .zip(dimension.iter()) + .map(|(arr, dim)| compute_array_length(arr, dim)) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) + }}; +} + /// Array_length SQL function pub fn array_length_inner(args: &[ArrayRef]) -> Result { if args.len() != 1 && args.len() != 2 { @@ -128,26 +148,18 @@ pub fn array_length_inner(args: &[ArrayRef]) -> Result { match &args[0].data_type() { List(_) => general_array_length::(args), LargeList(_) => general_array_length::(args), + FixedSizeList(_, _) => fixed_size_array_length(args), array_type => exec_err!("array_length does not support type '{array_type:?}'"), } } +fn fixed_size_array_length(array: &[ArrayRef]) -> Result { + array_length_impl!(as_fixed_size_list_array(&array[0])?, array.get(1)) +} + /// Dispatch array length computation based on the offset type. 
fn general_array_length(array: &[ArrayRef]) -> Result { - let list_array = as_generic_list_array::(&array[0])?; - let dimension = if array.len() == 2 { - as_int64_array(&array[1])?.clone() - } else { - Int64Array::from_value(1, list_array.len()) - }; - - let result = list_array - .iter() - .zip(dimension.iter()) - .map(|(arr, dim)| compute_array_length(arr, dim)) - .collect::>()?; - - Ok(Arc::new(result) as ArrayRef) + array_length_impl!(as_generic_list_array::(&array[0])?, array.get(1)) } /// Returns the length of a concrete array dimension @@ -185,6 +197,10 @@ fn compute_array_length( value = downcast_arg!(value, LargeListArray).value(0); current_dimension += 1; } + FixedSizeList(_, _) => { + value = downcast_arg!(value, FixedSizeListArray).value(0); + current_dimension += 1; + } _ => return Ok(None), } } diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index c9a61d98cd44..b05b53d2d8ee 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -201,8 +201,7 @@ mod tests { for alias in func.aliases() { assert!( names.insert(alias.to_string().to_lowercase()), - "duplicate function name: {}", - alias + "duplicate function name: {alias}" ); } } diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 4daaafc5a888..babb03919157 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -28,10 +28,7 @@ use arrow::array::{ }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; -use arrow::datatypes::{ - DataType::{List, Null}, - Field, -}; +use arrow::datatypes::{DataType::Null, Field}; use datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_common::{plan_err, Result}; use datafusion_expr::binary::{ @@ -105,16 +102,14 @@ impl ScalarUDFImpl for MakeArray { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match arg_types.len() { - 0 => Ok(empty_array_type()), - _ => { - // At this point, all the type in array should be coerced to the same one - Ok(List(Arc::new(Field::new_list_field( - arg_types[0].to_owned(), - true, - )))) - } - } + let element_type = if arg_types.is_empty() { + Null + } else { + // At this point, all the type in array should be coerced to the same one. 
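+ // (coerce_types below resolves all arguments to a single common type, e.g.
+ // make_array(1, 2.5) coerces both arguments to Float64, so inspecting the
+ // first argument is sufficient here)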
+ arg_types[0].to_owned() + }; + + Ok(DataType::new_list(element_type, true)) } fn invoke_with_args( @@ -129,26 +124,16 @@ impl ScalarUDFImpl for MakeArray { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let mut errors = vec![]; - match try_type_union_resolution_with_struct(arg_types) { - Ok(r) => return Ok(r), - Err(e) => { - errors.push(e); - } + if let Ok(unified) = try_type_union_resolution_with_struct(arg_types) { + return Ok(unified); } - if let Some(new_type) = type_union_resolution(arg_types) { - if new_type.is_null() { - Ok(vec![DataType::Int64; arg_types.len()]) - } else { - Ok(vec![new_type; arg_types.len()]) - } + if let Some(unified) = type_union_resolution(arg_types) { + Ok(vec![unified; arg_types.len()]) } else { plan_err!( - "Fail to find the valid type between {:?} for {}, errors are {:?}", - arg_types, - self.name(), - errors + "Failed to unify argument types of {}: {arg_types:?}", + self.name() ) } } @@ -158,35 +143,25 @@ impl ScalarUDFImpl for MakeArray { } } -// Empty array is a special case that is useful for many other array functions -pub(super) fn empty_array_type() -> DataType { - List(Arc::new(Field::new_list_field(DataType::Int64, true))) -} - /// `make_array_inner` is the implementation of the `make_array` function. /// Constructs an array using the input `data` as `ArrayRef`. /// Returns a reference-counted `Array` instance result. pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { - let mut data_type = Null; - for arg in arrays { - let arg_data_type = arg.data_type(); - if !arg_data_type.equals_datatype(&Null) { - data_type = arg_data_type.clone(); - break; - } - } + let data_type = arrays.iter().find_map(|arg| { + let arg_type = arg.data_type(); + (!arg_type.is_null()).then_some(arg_type) + }); - match data_type { + let data_type = data_type.unwrap_or(&Null); + if data_type.is_null() { // Either an empty array or all nulls: - Null => { - let length = arrays.iter().map(|a| a.len()).sum(); - // By default Int64 - let array = new_null_array(&DataType::Int64, length); - Ok(Arc::new( - SingleRowListArrayBuilder::new(array).build_list_array(), - )) - } - _ => array_array::(arrays, data_type), + let length = arrays.iter().map(|a| a.len()).sum(); + let array = new_null_array(&Null, length); + Ok(Arc::new( + SingleRowListArrayBuilder::new(array).build_list_array(), + )) + } else { + array_array::(arrays, data_type.clone()) } } diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index f82e4bfa1a89..8247fdd4a74c 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -19,15 +19,16 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{Array, ArrayRef, ListArray}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::utils::take_function_args; -use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_common::{cast::as_map_array, exec_err, internal_err, Result}; use datafusion_expr::{ ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; +use std::ops::Deref; use std::sync::Arc; make_udf_expr_and_func!( @@ -91,13 +92,23 @@ impl ScalarUDFImpl for MapValuesFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - let [map_type] = take_function_args(self.name(), arg_types)?; - let 
map_fields = get_map_entry_field(map_type)?; - Ok(DataType::List(Arc::new(Field::new_list_field( - map_fields.last().unwrap().data_type().clone(), - true, - )))) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + let [map_type] = take_function_args(self.name(), args.arg_fields)?; + + Ok(Field::new( + self.name(), + DataType::List(get_map_values_field_as_list_field(map_type.data_type())?), + // Nullable if the map is nullable + args.arg_fields.iter().any(|x| x.is_nullable()), + ) + .into()) } fn invoke_with_args( @@ -121,9 +132,139 @@ fn map_values_inner(args: &[ArrayRef]) -> Result { }; Ok(Arc::new(ListArray::new( - Arc::new(Field::new_list_field(map_array.value_type().clone(), true)), + get_map_values_field_as_list_field(map_arg.data_type())?, map_array.offsets().clone(), Arc::clone(map_array.values()), map_array.nulls().cloned(), ))) } + +fn get_map_values_field_as_list_field(map_type: &DataType) -> Result { + let map_fields = get_map_entry_field(map_type)?; + + let values_field = map_fields + .last() + .unwrap() + .deref() + .clone() + .with_name(Field::LIST_FIELD_DEFAULT_NAME); + + Ok(Arc::new(values_field)) +} + +#[cfg(test)] +mod tests { + use crate::map_values::MapValuesFunc; + use arrow::datatypes::{DataType, Field, FieldRef}; + use datafusion_common::ScalarValue; + use datafusion_expr::ScalarUDFImpl; + use std::sync::Arc; + + #[test] + fn return_type_field() { + fn get_map_field( + is_map_nullable: bool, + is_keys_nullable: bool, + is_values_nullable: bool, + ) -> FieldRef { + Field::new_map( + "something", + "entries", + Arc::new(Field::new("keys", DataType::Utf8, is_keys_nullable)), + Arc::new(Field::new( + "values", + DataType::LargeUtf8, + is_values_nullable, + )), + false, + is_map_nullable, + ) + .into() + } + + fn get_list_field( + name: &str, + is_list_nullable: bool, + list_item_type: DataType, + is_list_items_nullable: bool, + ) -> FieldRef { + Field::new_list( + name, + Arc::new(Field::new_list_field( + list_item_type, + is_list_items_nullable, + )), + is_list_nullable, + ) + .into() + } + + fn get_return_field(field: FieldRef) -> FieldRef { + let func = MapValuesFunc::new(); + let args = datafusion_expr::ReturnFieldArgs { + arg_fields: &[field], + scalar_arguments: &[None::<&ScalarValue>], + }; + + func.return_field_from_args(args).unwrap() + } + + // Test cases: + // + // | Input Map || Expected Output | + // | ------------------------------------------------------ || ----------------------------------------------------- | + // | map nullable | map keys nullable | map values nullable || expected list nullable | expected list items nullable | + // | ------------ | ----------------- | ------------------- || ---------------------- | ---------------------------- | + // | false | false | false || false | false | + // | false | false | true || false | true | + // | false | true | false || false | false | + // | false | true | true || false | true | + // | true | false | false || true | false | + // | true | false | true || true | true | + // | true | true | false || true | false | + // | true | true | true || true | true | + // + // --------------- + // We added the key nullability to show that it does not affect the nullability of the list or the list items. 
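+ // In short: the output list is nullable iff the input map is nullable, and
+ // the list items are nullable iff the map values are nullable; key
+ // nullability never matters.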
+ + assert_eq!( + get_return_field(get_map_field(false, false, false)), + get_list_field("map_values", false, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(false, false, true)), + get_list_field("map_values", false, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(false, true, false)), + get_list_field("map_values", false, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(false, true, true)), + get_list_field("map_values", false, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(true, false, false)), + get_list_field("map_values", true, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(true, false, true)), + get_list_field("map_values", true, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(true, true, false)), + get_list_field("map_values", true, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(true, true, true)), + get_list_field("map_values", true, DataType::LargeUtf8, true) + ); + } +} diff --git a/datafusion/functions-nested/src/max.rs b/datafusion/functions-nested/src/max.rs index 32957edc62b5..b667a7b42650 100644 --- a/datafusion/functions-nested/src/max.rs +++ b/datafusion/functions-nested/src/max.rs @@ -17,12 +17,13 @@ //! [`ScalarUDFImpl`] definitions for array_max function. use crate::utils::make_scalar_function; -use arrow::array::ArrayRef; +use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::List; -use datafusion_common::cast::as_list_array; +use arrow::datatypes::DataType::{LargeList, List}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::Result; +use datafusion_common::{exec_err, plan_err, ScalarValue}; use datafusion_doc::Documentation; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -91,17 +92,15 @@ impl ScalarUDFImpl for ArrayMax { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { - match &arg_types[0] { - List(field) => Ok(field.data_type().clone()), - _ => exec_err!("Not reachable, data_type should be List"), + fn return_type(&self, arg_types: &[DataType]) -> Result { + let [array] = take_function_args(self.name(), arg_types)?; + match array { + List(field) | LargeList(field) => Ok(field.data_type().clone()), + arg_type => plan_err!("{} does not support type {arg_type}", self.name()), } } - fn invoke_with_args( - &self, - args: ScalarFunctionArgs, - ) -> datafusion_common::Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_max_inner)(&args.args) } @@ -121,18 +120,25 @@ impl ScalarUDFImpl for ArrayMax { /// /// For example: /// > array_max(\[1, 3, 2]) -> 3 -pub fn array_max_inner(args: &[ArrayRef]) -> datafusion_common::Result { - let [arg1] = take_function_args("array_max", args)?; - - match arg1.data_type() { - List(_) => { - let input_list_array = as_list_array(&arg1)?; - let result_vec = input_list_array - .iter() - .flat_map(|arr| min_max::max_batch(&arr.unwrap())) - .collect_vec(); - ScalarValue::iter_to_array(result_vec) - } - _ => exec_err!("array_max does not support type: {:?}", arg1.data_type()), +pub fn array_max_inner(args: &[ArrayRef]) -> Result { + let [array] = 
take_function_args("array_max", args)?; + match array.data_type() { + List(_) => general_array_max(as_list_array(array)?), + LargeList(_) => general_array_max(as_large_list_array(array)?), + arg_type => exec_err!("array_max does not support type: {arg_type}"), } } + +fn general_array_max( + array: &GenericListArray, +) -> Result { + let null_value = ScalarValue::try_from(array.value_type())?; + let result_vec: Vec = array + .iter() + .map(|arr| { + arr.as_ref() + .map_or_else(|| Ok(null_value.clone()), min_max::max_batch) + }) + .try_collect()?; + ScalarValue::iter_to_array(result_vec) +} diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index a67945b1f1e1..4f9457aa59c6 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -17,16 +17,21 @@ //! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions. -use crate::make_array::{empty_array_type, make_array_inner}; use crate::utils::make_scalar_function; -use arrow::array::{new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow::array::{ + new_null_array, Array, ArrayRef, GenericListArray, LargeListArray, ListArray, + OffsetSizeTrait, +}; use arrow::buffer::OffsetBuffer; use arrow::compute; -use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null}; +use arrow::datatypes::DataType::{LargeList, List, Null}; use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::row::{RowConverter, SortField}; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result}; +use datafusion_common::utils::ListCoercion; +use datafusion_common::{ + exec_err, internal_err, plan_err, utils::take_function_args, Result, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -104,7 +109,11 @@ impl Default for ArrayUnion { impl ArrayUnion { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::arrays( + 2, + Some(ListCoercion::FixedSizedListToList), + Volatility::Immutable, + ), aliases: vec![String::from("list_union")], } } @@ -124,8 +133,10 @@ impl ScalarUDFImpl for ArrayUnion { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match (&arg_types[0], &arg_types[1]) { - (&Null, dt) => Ok(dt.clone()), + let [array1, array2] = take_function_args(self.name(), arg_types)?; + match (array1, array2) { + (Null, Null) => Ok(DataType::new_list(Null, true)), + (Null, dt) => Ok(dt.clone()), (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } @@ -183,7 +194,11 @@ pub(super) struct ArrayIntersect { impl ArrayIntersect { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::arrays( + 2, + Some(ListCoercion::FixedSizedListToList), + Volatility::Immutable, + ), aliases: vec![String::from("list_intersect")], } } @@ -203,10 +218,12 @@ impl ScalarUDFImpl for ArrayIntersect { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match (arg_types[0].clone(), arg_types[1].clone()) { - (Null, Null) | (Null, _) => Ok(Null), - (_, Null) => Ok(empty_array_type()), - (dt, _) => Ok(dt), + let [array1, array2] = take_function_args(self.name(), arg_types)?; + match (array1, array2) { + (Null, Null) => Ok(DataType::new_list(Null, true)), + (Null, dt) => Ok(dt.clone()), + (dt, Null) => Ok(dt.clone()), + (dt, _) => Ok(dt.clone()), } } @@ -273,16 +290,11 @@ 
impl ScalarUDFImpl for ArrayDistinct { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new( - Field::new_list_field(field.data_type().clone(), true), - ))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new_list_field( - field.data_type().clone(), - true, - )))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), + List(field) => Ok(DataType::new_list(field.data_type().clone(), true)), + LargeList(field) => { + Ok(DataType::new_large_list(field.data_type().clone(), true)) + } + arg_type => plan_err!("{} does not support type {arg_type}", self.name()), } } @@ -305,24 +317,18 @@ impl ScalarUDFImpl for ArrayDistinct { /// array_distinct SQL function /// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] fn array_distinct_inner(args: &[ArrayRef]) -> Result { - let [input_array] = take_function_args("array_distinct", args)?; - - // handle null - if input_array.data_type() == &Null { - return Ok(Arc::clone(input_array)); - } - - // handle for list & largelist - match input_array.data_type() { + let [array] = take_function_args("array_distinct", args)?; + match array.data_type() { + Null => Ok(Arc::clone(array)), List(field) => { - let array = as_list_array(&input_array)?; + let array = as_list_array(&array)?; general_array_distinct(array, field) } LargeList(field) => { - let array = as_large_list_array(&input_array)?; + let array = as_large_list_array(&array)?; general_array_distinct(array, field) } - array_type => exec_err!("array_distinct does not support type '{array_type:?}'"), + arg_type => exec_err!("array_distinct does not support type {arg_type}"), } } @@ -347,80 +353,76 @@ fn generic_set_lists( field: Arc, set_op: SetOp, ) -> Result { - if matches!(l.value_type(), Null) { + if l.is_empty() || l.value_type().is_null() { let field = Arc::new(Field::new_list_field(r.value_type(), true)); return general_array_distinct::(r, &field); - } else if matches!(r.value_type(), Null) { + } else if r.is_empty() || r.value_type().is_null() { let field = Arc::new(Field::new_list_field(l.value_type(), true)); return general_array_distinct::(l, &field); } - // Handle empty array at rhs case - // array_union(arr, []) -> arr; - // array_intersect(arr, []) -> []; - if r.value_length(0).is_zero() { - if set_op == SetOp::Union { - return Ok(Arc::new(l.clone()) as ArrayRef); - } else { - return Ok(Arc::new(r.clone()) as ArrayRef); - } - } - if l.value_type() != r.value_type() { return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'"); } - let dt = l.value_type(); - let mut offsets = vec![OffsetSize::usize_as(0)]; let mut new_arrays = vec![]; - - let converter = RowConverter::new(vec![SortField::new(dt)])?; + let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; for (first_arr, second_arr) in l.iter().zip(r.iter()) { - if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { - let l_values = converter.convert_columns(&[first_arr])?; - let r_values = converter.convert_columns(&[second_arr])?; - - let l_iter = l_values.iter().sorted().dedup(); - let values_set: HashSet<_> = l_iter.clone().collect(); - let mut rows = if set_op == SetOp::Union { - l_iter.collect::>() - } else { - vec![] - }; - for r_val in r_values.iter().sorted().dedup() { - match set_op { - SetOp::Union => { - if !values_set.contains(&r_val) { - rows.push(r_val); - } + let l_values = if let Some(first_arr) = first_arr { + 
converter.convert_columns(&[first_arr])? + } else { + converter.convert_columns(&[])? + }; + + let r_values = if let Some(second_arr) = second_arr { + converter.convert_columns(&[second_arr])? + } else { + converter.convert_columns(&[])? + }; + + let l_iter = l_values.iter().sorted().dedup(); + let values_set: HashSet<_> = l_iter.clone().collect(); + let mut rows = if set_op == SetOp::Union { + l_iter.collect() + } else { + vec![] + }; + + for r_val in r_values.iter().sorted().dedup() { + match set_op { + SetOp::Union => { + if !values_set.contains(&r_val) { + rows.push(r_val); } - SetOp::Intersect => { - if values_set.contains(&r_val) { - rows.push(r_val); - } + } + SetOp::Intersect => { + if values_set.contains(&r_val) { + rows.push(r_val); } } } - - let last_offset = match offsets.last().copied() { - Some(offset) => offset, - None => return internal_err!("offsets should not be empty"), - }; - offsets.push(last_offset + OffsetSize::usize_as(rows.len())); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.first() { - Some(array) => Arc::clone(array), - None => { - return internal_err!("{set_op}: failed to get array from rows"); - } - }; - new_arrays.push(array); } + + let last_offset = match offsets.last() { + Some(offset) => *offset, + None => return internal_err!("offsets should not be empty"), + }; + + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.first() { + Some(array) => Arc::clone(array), + None => { + return internal_err!("{set_op}: failed to get array from rows"); + } + }; + + new_arrays.push(array); } let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect(); let values = compute::concat(&new_arrays_ref)?; let arr = GenericListArray::::try_new(field, offsets, values, None)?; Ok(Arc::new(arr)) @@ -431,38 +433,59 @@ fn general_set_op( array2: &ArrayRef, set_op: SetOp, ) -> Result { + fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result { + let field = Arc::new(Field::new_list_field(data_type.clone(), true)); + let values = new_null_array(data_type, len); + if large { + Ok(Arc::new(LargeListArray::try_new( + field, + OffsetBuffer::new_zeroed(len), + values, + None, + )?)) + } else { + Ok(Arc::new(ListArray::try_new( + field, + OffsetBuffer::new_zeroed(len), + values, + None, + )?)) + } + } + match (array1.data_type(), array2.data_type()) { + (Null, Null) => Ok(Arc::new(ListArray::new_null( + Arc::new(Field::new_list_field(Null, true)), + array1.len(), + ))), (Null, List(field)) => { if set_op == SetOp::Intersect { - return Ok(new_empty_array(&Null)); + return empty_array(field.data_type(), array1.len(), false); } let array = as_list_array(&array2)?; general_array_distinct::(array, field) } - (List(field), Null) => { if set_op == SetOp::Intersect { - return make_array_inner(&[]); + return empty_array(field.data_type(), array1.len(), false); } let array = as_list_array(&array1)?; general_array_distinct::(array, field) } (Null, LargeList(field)) => { if set_op == SetOp::Intersect { - return Ok(new_empty_array(&Null)); + return empty_array(field.data_type(), array1.len(), true); } let array = as_large_list_array(&array2)?; general_array_distinct::(array, field) } (LargeList(field), Null) => { if set_op == SetOp::Intersect { - return make_array_inner(&[]); + return empty_array(field.data_type(), array1.len(), true); 
} let array = as_large_list_array(&array1)?; general_array_distinct::(array, field) } - (Null, Null) => Ok(new_empty_array(&Null)), - (List(field), List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; diff --git a/datafusion/functions-nested/src/sort.rs b/datafusion/functions-nested/src/sort.rs index 85737ef135bc..7b2f41c0541c 100644 --- a/datafusion/functions-nested/src/sort.rs +++ b/datafusion/functions-nested/src/sort.rs @@ -21,11 +21,11 @@ use crate::utils::make_scalar_function; use arrow::array::{new_null_array, Array, ArrayRef, ListArray, NullBufferBuilder}; use arrow::buffer::OffsetBuffer; use arrow::compute::SortColumn; -use arrow::datatypes::DataType::{FixedSizeList, LargeList, List}; use arrow::datatypes::{DataType, Field}; use arrow::{compute, compute::SortOptions}; use datafusion_common::cast::{as_list_array, as_string_array}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::utils::ListCoercion; +use datafusion_common::{exec_err, plan_err, Result}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -93,14 +93,14 @@ impl ArraySort { vec![ TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ArrayFunctionArgument::Array], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ ArrayFunctionArgument::Array, ArrayFunctionArgument::String, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ @@ -108,7 +108,7 @@ impl ArraySort { ArrayFunctionArgument::String, ArrayFunctionArgument::String, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), ], Volatility::Immutable, @@ -133,17 +133,13 @@ impl ScalarUDFImpl for ArraySort { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new( - Field::new_list_field(field.data_type().clone(), true), - ))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new_list_field( - field.data_type().clone(), - true, - )))), DataType::Null => Ok(DataType::Null), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), + DataType::List(field) => { + Ok(DataType::new_list(field.data_type().clone(), true)) + } + arg_type => { + plan_err!("{} does not support type {arg_type}", self.name()) + } } } @@ -169,6 +165,16 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { return exec_err!("array_sort expects one to three arguments"); } + if args[0].data_type().is_null() { + return Ok(Arc::clone(&args[0])); + } + + let list_array = as_list_array(&args[0])?; + let row_count = list_array.len(); + if row_count == 0 || list_array.value_type().is_null() { + return Ok(Arc::clone(&args[0])); + } + if args[1..].iter().any(|array| array.is_null(0)) { return Ok(new_null_array(args[0].data_type(), args[0].len())); } @@ -193,12 +199,6 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { _ => return exec_err!("array_sort expects 1 to 3 arguments"), }; - let list_array = as_list_array(&args[0])?; - let row_count = list_array.len(); - if row_count == 0 { - return Ok(Arc::clone(&args[0])); - } - let mut array_lengths = vec![]; let mut arrays = vec![]; let mut valid = NullBufferBuilder::new(row_count); diff 
--git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 74b21a3ceb47..ed08a8235874 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -22,17 +22,15 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Fields}; use arrow::array::{ - Array, ArrayRef, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar, - UInt32Array, + Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array, }; use arrow::buffer::OffsetBuffer; -use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, plan_err, Result, ScalarValue, +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, }; +use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::ColumnarValue; -use datafusion_functions::{downcast_arg, downcast_named_arg}; pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); @@ -234,8 +232,16 @@ pub(crate) fn compute_array_dims( loop { match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); + DataType::List(_) => { + value = as_list_array(&value)?.value(0); + res.push(Some(value.len() as u64)); + } + DataType::LargeList(_) => { + value = as_large_list_array(&value)?.value(0); + res.push(Some(value.len() as u64)); + } + DataType::FixedSizeList(..) => { + value = as_fixed_size_list_array(&value)?.value(0); res.push(Some(value.len() as u64)); } _ => return Ok(Some(res)), @@ -261,6 +267,7 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { #[cfg(test)] mod tests { use super::*; + use arrow::array::ListArray; use arrow::datatypes::Int64Type; use datafusion_common::utils::SingleRowListArrayBuilder; diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs index ee95567ab73d..ffb93cf59b16 100644 --- a/datafusion/functions-table/src/generate_series.rs +++ b/datafusion/functions-table/src/generate_series.rs @@ -199,8 +199,8 @@ impl TableFunctionImpl for GenerateSeriesFuncImpl { let mut normalize_args = Vec::new(); for expr in exprs { match expr { - Expr::Literal(ScalarValue::Null) => {} - Expr::Literal(ScalarValue::Int64(Some(n))) => normalize_args.push(*n), + Expr::Literal(ScalarValue::Null, _) => {} + Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n), _ => return plan_err!("First argument must be an integer literal"), }; } diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs index 76e27b045b0a..774cd5182b30 100644 --- a/datafusion/functions-window-common/src/expr.rs +++ b/datafusion/functions-window-common/src/expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::FieldRef; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; @@ -25,9 +25,9 @@ pub struct ExpressionArgs<'a> { /// The expressions passed as arguments to the user-defined window /// function. input_exprs: &'a [Arc], - /// The corresponding data types of expressions passed as arguments + /// The corresponding fields of expressions passed as arguments /// to the user-defined window function. 
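// Illustrative migration sketch, not part of this patch: call sites that previously
// passed `&[DataType]` into these argument structs now pass `&[FieldRef]`, typically
// by wrapping each type in an `arrow` Field. The field name and nullability below
// are placeholder assumptions.
fn types_to_field_refs_sketch() {
    use arrow::datatypes::{DataType, Field, FieldRef};

    let input_types = [DataType::Int32, DataType::Utf8];
    let input_fields: Vec<FieldRef> = input_types
        .iter()
        .map(|dt| Field::new("f", dt.clone(), true).into())
        .collect();
    assert_eq!(input_fields[1].data_type(), &DataType::Utf8);
}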
- input_types: &'a [DataType], + input_fields: &'a [FieldRef], } impl<'a> ExpressionArgs<'a> { @@ -42,11 +42,11 @@ impl<'a> ExpressionArgs<'a> { /// pub fn new( input_exprs: &'a [Arc], - input_types: &'a [DataType], + input_fields: &'a [FieldRef], ) -> Self { Self { input_exprs, - input_types, + input_fields, } } @@ -56,9 +56,9 @@ impl<'a> ExpressionArgs<'a> { self.input_exprs } - /// Returns the [`DataType`]s corresponding to the input expressions + /// Returns the [`FieldRef`]s corresponding to the input expressions /// to the user-defined window function. - pub fn input_types(&self) -> &'a [DataType] { - self.input_types + pub fn input_fields(&self) -> &'a [FieldRef] { + self.input_fields } } diff --git a/datafusion/functions-window-common/src/field.rs b/datafusion/functions-window-common/src/field.rs index 03f88b0b95cc..8d22efa3bcf4 100644 --- a/datafusion/functions-window-common/src/field.rs +++ b/datafusion/functions-window-common/src/field.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::FieldRef; /// Metadata for defining the result field from evaluating a /// user-defined window function. pub struct WindowUDFFieldArgs<'a> { - /// The data types corresponding to the arguments to the + /// The fields corresponding to the arguments to the /// user-defined window function. - input_types: &'a [DataType], + input_fields: &'a [FieldRef], /// The display name of the user-defined window function. display_name: &'a str, } @@ -32,22 +32,22 @@ impl<'a> WindowUDFFieldArgs<'a> { /// /// # Arguments /// - /// * `input_types` - The data types corresponding to the + /// * `input_fields` - The fields corresponding to the /// arguments to the user-defined window function. /// * `function_name` - The qualified schema name of the /// user-defined window function expression. /// - pub fn new(input_types: &'a [DataType], display_name: &'a str) -> Self { + pub fn new(input_fields: &'a [FieldRef], display_name: &'a str) -> Self { WindowUDFFieldArgs { - input_types, + input_fields, display_name, } } - /// Returns the data type of input expressions passed as arguments + /// Returns the field of input expressions passed as arguments /// to the user-defined window function. - pub fn input_types(&self) -> &[DataType] { - self.input_types + pub fn input_fields(&self) -> &[FieldRef] { + self.input_fields } /// Returns the name for the field of the final result of evaluating @@ -56,9 +56,9 @@ impl<'a> WindowUDFFieldArgs<'a> { self.display_name } - /// Returns `Some(DataType)` of input expression at index, otherwise + /// Returns `Some(Field)` of input expression at index, otherwise /// returns `None` if the index is out of bounds. - pub fn get_input_type(&self, index: usize) -> Option { - self.input_types.get(index).cloned() + pub fn get_input_field(&self, index: usize) -> Option { + self.input_fields.get(index).cloned() } } diff --git a/datafusion/functions-window-common/src/partition.rs b/datafusion/functions-window-common/src/partition.rs index e853aa8fb05d..61125e596130 100644 --- a/datafusion/functions-window-common/src/partition.rs +++ b/datafusion/functions-window-common/src/partition.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
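// Illustrative sketch, not part of this patch, of a call site after this change:
// `PartitionEvaluatorArgs::new` now receives `&[FieldRef]` instead of `&[DataType]`.
// This assumes the `datafusion-physical-expr` Column expression and the constructor
// signature shown in this diff; the function name below is a placeholder.
fn partition_evaluator_args_sketch() {
    use std::sync::Arc;
    use arrow::datatypes::{DataType, Field, FieldRef};
    use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("c", 0))];
    let fields: Vec<FieldRef> = vec![Field::new("c", DataType::Int32, true).into()];
    let args = PartitionEvaluatorArgs::new(&exprs, &fields, false, false);
    assert_eq!(args.input_fields()[0].data_type(), &DataType::Int32);
}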
-use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::FieldRef; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; @@ -26,9 +26,9 @@ pub struct PartitionEvaluatorArgs<'a> { /// The expressions passed as arguments to the user-defined window /// function. input_exprs: &'a [Arc], - /// The corresponding data types of expressions passed as arguments + /// The corresponding fields of expressions passed as arguments /// to the user-defined window function. - input_types: &'a [DataType], + input_fields: &'a [FieldRef], /// Set to `true` if the user-defined window function is reversed. is_reversed: bool, /// Set to `true` if `IGNORE NULLS` is specified. @@ -51,13 +51,13 @@ impl<'a> PartitionEvaluatorArgs<'a> { /// pub fn new( input_exprs: &'a [Arc], - input_types: &'a [DataType], + input_fields: &'a [FieldRef], is_reversed: bool, ignore_nulls: bool, ) -> Self { Self { input_exprs, - input_types, + input_fields, is_reversed, ignore_nulls, } @@ -69,10 +69,10 @@ impl<'a> PartitionEvaluatorArgs<'a> { self.input_exprs } - /// Returns the [`DataType`]s corresponding to the input expressions + /// Returns the [`FieldRef`]s corresponding to the input expressions /// to the user-defined window function. - pub fn input_types(&self) -> &'a [DataType] { - self.input_types + pub fn input_fields(&self) -> &'a [FieldRef] { + self.input_fields } /// Returns `true` when the user-defined window function is diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml index e0c17c579b19..23ee608a8267 100644 --- a/datafusion/functions-window/Cargo.toml +++ b/datafusion/functions-window/Cargo.toml @@ -38,6 +38,7 @@ workspace = true name = "datafusion_functions_window" [dependencies] +arrow = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } datafusion-expr = { workspace = true } @@ -47,6 +48,3 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.15" - -[dev-dependencies] -arrow = { workspace = true } diff --git a/datafusion/functions-window/src/cume_dist.rs b/datafusion/functions-window/src/cume_dist.rs index d156416a82a4..ed8669948188 100644 --- a/datafusion/functions-window/src/cume_dist.rs +++ b/datafusion/functions-window/src/cume_dist.rs @@ -17,6 +17,7 @@ //! `cume_dist` window function implementation +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::{ArrayRef, Float64Array}; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; @@ -101,8 +102,8 @@ impl WindowUDFImpl for CumeDist { Ok(Box::::default()) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, false)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, false).into()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-window/src/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs index 5df20cf5b980..e2a755371ebc 100644 --- a/datafusion/functions-window/src/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -18,6 +18,7 @@ //! 
`lead` and `lag` window function implementations use crate::utils::{get_scalar_value_from_args, get_signed_integer}; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; @@ -157,6 +158,24 @@ static LAG_DOCUMENTATION: LazyLock = LazyLock::new(|| { the value of expression should be retrieved. Defaults to 1.") .with_argument("default", "The default value if the offset is \ not within the partition. Must be of the same type as expression.") + .with_sql_example(r#"```sql + --Example usage of the lag window function: + SELECT employee_id, + salary, + lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary + FROM employees; +``` + +```sql ++-------------+--------+-------------+ +| employee_id | salary | prev_salary | ++-------------+--------+-------------+ +| 1 | 30000 | 0 | +| 2 | 50000 | 30000 | +| 3 | 70000 | 50000 | +| 4 | 60000 | 70000 | ++-------------+--------+-------------+ +```"#) .build() }); @@ -175,6 +194,27 @@ static LEAD_DOCUMENTATION: LazyLock = LazyLock::new(|| { forward the value of expression should be retrieved. Defaults to 1.") .with_argument("default", "The default value if the offset is \ not within the partition. Must be of the same type as expression.") + .with_sql_example(r#"```sql +-- Example usage of lead() : +SELECT + employee_id, + department, + salary, + lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary +FROM employees; +``` + +```sql ++-------------+-------------+--------+--------------+ +| employee_id | department | salary | next_salary | ++-------------+-------------+--------+--------------+ +| 1 | Sales | 30000 | 50000 | +| 2 | Sales | 50000 | 70000 | +| 3 | Sales | 70000 | 0 | +| 4 | Engineering | 40000 | 60000 | +| 5 | Engineering | 60000 | 0 | ++-------------+-------------+--------+--------------+ +```"#) .build() }); @@ -201,7 +241,7 @@ impl WindowUDFImpl for WindowShift { /// /// For more details see: fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { - parse_expr(expr_args.input_exprs(), expr_args.input_types()) + parse_expr(expr_args.input_exprs(), expr_args.input_fields()) .into_iter() .collect::>() } @@ -224,7 +264,7 @@ impl WindowUDFImpl for WindowShift { })?; let default_value = parse_default_value( partition_evaluator_args.input_exprs(), - partition_evaluator_args.input_types(), + partition_evaluator_args.input_fields(), )?; Ok(Box::new(WindowShiftEvaluator { @@ -235,10 +275,14 @@ impl WindowUDFImpl for WindowShift { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - let return_type = parse_expr_type(field_args.input_types())?; + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_field = parse_expr_field(field_args.input_fields())?; - Ok(Field::new(field_args.name(), return_type, true)) + Ok(return_field + .as_ref() + .clone() + .with_name(field_args.name()) + .into()) } fn reverse_expr(&self) -> ReversedUDWF { @@ -270,16 +314,16 @@ impl WindowUDFImpl for WindowShift { /// For more details see: fn parse_expr( input_exprs: &[Arc], - input_types: &[DataType], + input_fields: &[FieldRef], ) -> Result> { assert!(!input_exprs.is_empty()); - assert!(!input_types.is_empty()); + assert!(!input_fields.is_empty()); let expr = Arc::clone(input_exprs.first().unwrap()); - let expr_type = input_types.first().unwrap(); + let expr_field = input_fields.first().unwrap(); // Handles the most common case where NULL is unexpected - if !expr_type.is_null() { + if 
!expr_field.data_type().is_null() { return Ok(expr); } @@ -292,36 +336,43 @@ fn parse_expr( }) } -/// Returns the data type of the default value(if provided) when the +static NULL_FIELD: LazyLock = + LazyLock::new(|| Field::new("value", DataType::Null, true).into()); + +/// Returns the field of the default value(if provided) when the /// expression is `NULL`. /// -/// Otherwise, returns the expression type unchanged. -fn parse_expr_type(input_types: &[DataType]) -> Result { - assert!(!input_types.is_empty()); - let expr_type = input_types.first().unwrap_or(&DataType::Null); +/// Otherwise, returns the expression field unchanged. +fn parse_expr_field(input_fields: &[FieldRef]) -> Result { + assert!(!input_fields.is_empty()); + let expr_field = input_fields.first().unwrap_or(&NULL_FIELD); // Handles the most common case where NULL is unexpected - if !expr_type.is_null() { - return Ok(expr_type.clone()); + if !expr_field.data_type().is_null() { + return Ok(expr_field.as_ref().clone().with_nullable(true).into()); } - let default_value_type = input_types.get(2).unwrap_or(&DataType::Null); - Ok(default_value_type.clone()) + let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD); + Ok(default_value_field + .as_ref() + .clone() + .with_nullable(true) + .into()) } /// Handles type coercion and null value refinement for default value /// argument depending on the data type of the input expression. fn parse_default_value( input_exprs: &[Arc], - input_types: &[DataType], + input_types: &[FieldRef], ) -> Result { - let expr_type = parse_expr_type(input_types)?; + let expr_field = parse_expr_field(input_types)?; let unparsed = get_scalar_value_from_args(input_exprs, 2)?; unparsed .filter(|v| !v.data_type().is_null()) - .map(|v| v.cast_to(&expr_type)) - .unwrap_or(ScalarValue::try_from(expr_type)) + .map(|v| v.cast_to(expr_field.data_type())) + .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type())) } #[derive(Debug)] @@ -666,7 +717,12 @@ mod tests { test_i32_result( WindowShift::lead(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), [ Some(-2), Some(3), @@ -688,7 +744,12 @@ mod tests { test_i32_result( WindowShift::lag(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), [ None, Some(1), @@ -713,12 +774,15 @@ mod tests { as Arc; let input_exprs = &[expr, shift_offset, default_value]; - let input_types: &[DataType] = - &[DataType::Int32, DataType::Int32, DataType::Int32]; + let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32] + .into_iter() + .map(|d| Field::new("f", d, true)) + .map(Arc::new) + .collect::>(); test_i32_result( WindowShift::lag(), - PartitionEvaluatorArgs::new(input_exprs, input_types, false, false), + PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false), [ Some(100), Some(1), diff --git a/datafusion/functions-window/src/macros.rs b/datafusion/functions-window/src/macros.rs index 2ef1eacba953..23414a7a7172 100644 --- a/datafusion/functions-window/src/macros.rs +++ b/datafusion/functions-window/src/macros.rs @@ -40,6 +40,7 @@ /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// 
# @@ -85,8 +86,8 @@ /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false).into()) /// # } /// # } /// # @@ -138,6 +139,7 @@ macro_rules! get_or_init_udwf { /// 1. With Zero Parameters /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf}; @@ -196,8 +198,8 @@ macro_rules! get_or_init_udwf { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) /// # } /// # } /// ``` @@ -205,6 +207,7 @@ macro_rules! get_or_init_udwf { /// 2. With Multiple Parameters /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -283,12 +286,12 @@ macro_rules! get_or_init_udwf { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` @@ -352,6 +355,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # @@ -404,8 +408,8 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false).into()) /// # } /// # } /// # @@ -415,6 +419,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; @@ -468,8 +473,8 @@ macro_rules! 
create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) /// # } /// # } /// ``` @@ -479,6 +484,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -554,12 +560,12 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` @@ -567,6 +573,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -643,12 +650,12 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index 36e6b83d61ce..0b83e1ff9f08 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -19,12 +19,7 @@ use crate::utils::{get_scalar_value_from_args, get_signed_integer}; -use std::any::Any; -use std::cmp::Ordering; -use std::fmt::Debug; -use std::ops::Range; -use std::sync::LazyLock; - +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; @@ -37,6 +32,11 @@ use datafusion_expr::{ use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use field::WindowUDFFieldArgs; +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::LazyLock; get_or_init_udwf!( First, @@ -135,6 +135,26 @@ static FIRST_VALUE_DOCUMENTATION: LazyLock = LazyLock::new(|| { "first_value(expression)", ) .with_argument("expression", "Expression to operate on") + .with_sql_example(r#"```sql + --Example usage of the first_value window function: + SELECT department, + employee_id, + salary, + first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary + FROM employees; +``` + +```sql ++-------------+-------------+--------+------------+ +| department | employee_id | salary | top_salary | ++-------------+-------------+--------+------------+ +| Sales | 1 | 70000 | 
70000 | +| Sales | 2 | 50000 | 70000 | +| Sales | 3 | 30000 | 70000 | +| Engineering | 4 | 90000 | 90000 | +| Engineering | 5 | 80000 | 90000 | ++-------------+-------------+--------+------------+ +```"#) .build() }); @@ -150,6 +170,26 @@ static LAST_VALUE_DOCUMENTATION: LazyLock = LazyLock::new(|| { "last_value(expression)", ) .with_argument("expression", "Expression to operate on") + .with_sql_example(r#"```sql +-- SQL example of last_value: +SELECT department, + employee_id, + salary, + last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary +FROM employees; +``` + +```sql ++-------------+-------------+--------+---------------------+ +| department | employee_id | salary | running_last_salary | ++-------------+-------------+--------+---------------------+ +| Sales | 1 | 30000 | 30000 | +| Sales | 2 | 50000 | 50000 | +| Sales | 3 | 70000 | 70000 | +| Engineering | 4 | 40000 | 40000 | +| Engineering | 5 | 60000 | 60000 | ++-------------+-------------+--------+---------------------+ +```"#) .build() }); @@ -269,11 +309,15 @@ impl WindowUDFImpl for NthValue { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - let nullable = true; - let return_type = field_args.input_types().first().unwrap_or(&DataType::Null); + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_type = field_args + .input_fields() + .first() + .map(|f| f.data_type()) + .cloned() + .unwrap_or(DataType::Null); - Ok(Field::new(field_args.name(), return_type.clone(), nullable)) + Ok(Field::new(field_args.name(), return_type, true).into()) } fn reverse_expr(&self) -> ReversedUDWF { @@ -511,7 +555,12 @@ mod tests { let expr = Arc::new(Column::new("c3", 0)) as Arc; test_i32_result( NthValue::first(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), Int32Array::from(vec![1; 8]).iter().collect::(), ) } @@ -521,7 +570,12 @@ mod tests { let expr = Arc::new(Column::new("c3", 0)) as Arc; test_i32_result( NthValue::last(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), Int32Array::from(vec![ Some(1), Some(-2), @@ -545,7 +599,7 @@ mod tests { NthValue::nth(), PartitionEvaluatorArgs::new( &[expr, n_value], - &[DataType::Int32], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), @@ -564,7 +618,7 @@ mod tests { NthValue::nth(), PartitionEvaluatorArgs::new( &[expr, n_value], - &[DataType::Int32], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), diff --git a/datafusion/functions-window/src/ntile.rs b/datafusion/functions-window/src/ntile.rs index 180f7ab02c03..6b4c0960e695 100644 --- a/datafusion/functions-window/src/ntile.rs +++ b/datafusion/functions-window/src/ntile.rs @@ -17,13 +17,10 @@ //! 
`ntile` window function implementation -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; - use crate::utils::{ get_scalar_value_from_args, get_signed_integer, get_unsigned_integer, }; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::{ArrayRef, UInt64Array}; use datafusion_common::arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, DataFusionError, Result}; @@ -34,6 +31,9 @@ use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_macros::user_doc; use field::WindowUDFFieldArgs; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; get_or_init_udwf!( Ntile, @@ -52,7 +52,29 @@ pub fn ntile(arg: Expr) -> Expr { argument( name = "expression", description = "An integer describing the number groups the partition should be split into" - ) + ), + sql_example = r#"```sql + --Example usage of the ntile window function: + SELECT employee_id, + salary, + ntile(4) OVER (ORDER BY salary DESC) AS quartile + FROM employees; +``` + +```sql ++-------------+--------+----------+ +| employee_id | salary | quartile | ++-------------+--------+----------+ +| 1 | 90000 | 1 | +| 2 | 85000 | 1 | +| 3 | 80000 | 2 | +| 4 | 70000 | 2 | +| 5 | 60000 | 3 | +| 6 | 50000 | 3 | +| 7 | 40000 | 4 | +| 8 | 30000 | 4 | ++-------------+--------+----------+ +```"# )] #[derive(Debug)] pub struct Ntile { @@ -127,10 +149,10 @@ impl WindowUDFImpl for Ntile { Ok(Box::new(NtileEvaluator { n: n as u64 })) } } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { let nullable = false; - Ok(Field::new(field_args.name(), DataType::UInt64, nullable)) + Ok(Field::new(field_args.name(), DataType::UInt64, nullable).into()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-window/src/planner.rs b/datafusion/functions-window/src/planner.rs index 1ddd8b27c420..091737bb9c15 100644 --- a/datafusion/functions-window/src/planner.rs +++ b/datafusion/functions-window/src/planner.rs @@ -43,7 +43,7 @@ impl ExprPlanner for WindowFunctionPlanner { null_treatment, } = raw_expr; - let origin_expr = Expr::WindowFunction(WindowFunction { + let origin_expr = Expr::from(WindowFunction { fun: func_def, params: WindowFunctionParams { args, @@ -56,7 +56,10 @@ impl ExprPlanner for WindowFunctionPlanner { let saved_name = NamePreserver::new_for_projection().save(&origin_expr); - let Expr::WindowFunction(WindowFunction { + let Expr::WindowFunction(window_fun) = origin_expr else { + unreachable!("") + }; + let WindowFunction { fun, params: WindowFunctionParams { @@ -66,10 +69,7 @@ impl ExprPlanner for WindowFunctionPlanner { window_frame, null_treatment, }, - }) = origin_expr - else { - unreachable!("") - }; + } = *window_fun; let raw_expr = RawWindowExpr { func_def: fun, args, @@ -95,9 +95,9 @@ impl ExprPlanner for WindowFunctionPlanner { null_treatment, } = raw_expr; - let new_expr = Expr::WindowFunction(WindowFunction::new( + let new_expr = Expr::from(WindowFunction::new( func_def, - vec![Expr::Literal(COUNT_STAR_EXPANSION)], + vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], )) .partition_by(partition_by) .order_by(order_by) diff --git a/datafusion/functions-window/src/rank.rs b/datafusion/functions-window/src/rank.rs index 2ff2c31d8c2a..969a957cddd9 100644 --- a/datafusion/functions-window/src/rank.rs +++ b/datafusion/functions-window/src/rank.rs @@ -18,13 +18,8 @@ //! 
Implementation of `rank`, `dense_rank`, and `percent_rank` window functions, //! which can be evaluated at runtime during query execution. -use std::any::Any; -use std::fmt::Debug; -use std::iter; -use std::ops::Range; -use std::sync::{Arc, LazyLock}; - use crate::define_udwf_and_expr; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::array::{Float64Array, UInt64Array}; use datafusion_common::arrow::compute::SortOptions; @@ -39,6 +34,11 @@ use datafusion_expr::{ use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use field::WindowUDFFieldArgs; +use std::any::Any; +use std::fmt::Debug; +use std::iter; +use std::ops::Range; +use std::sync::{Arc, LazyLock}; define_udwf_and_expr!( Rank, @@ -110,6 +110,26 @@ static RANK_DOCUMENTATION: LazyLock = LazyLock::new(|| { skips ranks for identical values.", "rank()") + .with_sql_example(r#"```sql + --Example usage of the rank window function: + SELECT department, + salary, + rank() OVER (PARTITION BY department ORDER BY salary DESC) AS rank + FROM employees; +``` + +```sql ++-------------+--------+------+ +| department | salary | rank | ++-------------+--------+------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 2 | +| Sales | 30000 | 4 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+------+ +```"#) .build() }); @@ -121,6 +141,26 @@ static DENSE_RANK_DOCUMENTATION: LazyLock = LazyLock::new(|| { Documentation::builder(DOC_SECTION_RANKING, "Returns the rank of the current row without gaps. This function ranks \ rows in a dense manner, meaning consecutive ranks are assigned even for identical \ values.", "dense_rank()") + .with_sql_example(r#"```sql + --Example usage of the dense_rank window function: + SELECT department, + salary, + dense_rank() OVER (PARTITION BY department ORDER BY salary DESC) AS dense_rank + FROM employees; +``` + +```sql ++-------------+--------+------------+ +| department | salary | dense_rank | ++-------------+--------+------------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 2 | +| Sales | 30000 | 3 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+------------+ +```"#) .build() }); @@ -131,6 +171,23 @@ fn get_dense_rank_doc() -> &'static Documentation { static PERCENT_RANK_DOCUMENTATION: LazyLock = LazyLock::new(|| { Documentation::builder(DOC_SECTION_RANKING, "Returns the percentage rank of the current row within its partition. 
\ The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`.", "percent_rank()") + .with_sql_example(r#"```sql + --Example usage of the percent_rank window function: + SELECT employee_id, + salary, + percent_rank() OVER (ORDER BY salary) AS percent_rank + FROM employees; +``` + +```sql ++-------------+--------+---------------+ +| employee_id | salary | percent_rank | ++-------------+--------+---------------+ +| 1 | 30000 | 0.00 | +| 2 | 50000 | 0.50 | +| 3 | 70000 | 1.00 | ++-------------+--------+---------------+ +```"#) .build() }); @@ -161,14 +218,14 @@ impl WindowUDFImpl for Rank { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { let return_type = match self.rank_type { RankType::Basic | RankType::Dense => DataType::UInt64, RankType::Percent => DataType::Float64, }; let nullable = false; - Ok(Field::new(field_args.name(), return_type, nullable)) + Ok(Field::new(field_args.name(), return_type, nullable).into()) } fn sort_options(&self) -> Option { diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs index 8f462528dbed..ba8627dd86d7 100644 --- a/datafusion/functions-window/src/row_number.rs +++ b/datafusion/functions-window/src/row_number.rs @@ -17,6 +17,7 @@ //! `row_number` window function implementation +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::array::UInt64Array; use datafusion_common::arrow::compute::SortOptions; @@ -44,7 +45,27 @@ define_udwf_and_expr!( #[user_doc( doc_section(label = "Ranking Functions"), description = "Number of the current row within its partition, counting from 1.", - syntax_example = "row_number()" + syntax_example = "row_number()", + sql_example = r"```sql + --Example usage of the row_number window function: + SELECT department, + salary, + row_number() OVER (PARTITION BY department ORDER BY salary DESC) AS row_num + FROM employees; +``` + +```sql ++-------------+--------+---------+ +| department | salary | row_num | ++-------------+--------+---------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 3 | +| Sales | 30000 | 4 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+---------+ +```#" )] #[derive(Debug)] pub struct RowNumber { @@ -86,8 +107,8 @@ impl WindowUDFImpl for RowNumber { Ok(Box::::default()) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::UInt64, false)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) } fn sort_options(&self) -> Option { diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 729770b8a65c..0c4280babc70 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -80,9 +80,9 @@ log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } rand = { workspace = true } regex = { workspace = true, optional = true } -sha2 = { version = "^0.10.1", optional = true } +sha2 = { version = "^0.10.9", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } -uuid = { version = "1.16", features = ["v4"], optional = true } +uuid = { version = "1.17", features = ["v4"], optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } @@ -90,6 +90,11 @@ criterion = { workspace = true } rand = { workspace = true } 
tokio = { workspace = true, features = ["macros", "rt", "sync"] } +[[bench]] +harness = false +name = "ascii" +required-features = ["string_expressions"] + [[bench]] harness = false name = "concat" diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs new file mode 100644 index 000000000000..1c7023f4497e --- /dev/null +++ b/datafusion/functions/benches/ascii.rs @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; +mod helper; + +use arrow::datatypes::{DataType, Field}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ScalarFunctionArgs; +use helper::gen_string_array; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let ascii = datafusion_functions::string::ascii(); + + // All benches are single batch run with 8192 rows + const N_ROWS: usize = 8192; + const STR_LEN: usize = 16; + const UTF8_DENSITY_OF_ALL_ASCII: f32 = 0.0; + const NORMAL_UTF8_DENSITY: f32 = 0.8; + + for null_density in [0.0, 0.5] { + // StringArray ASCII only + let args_string_ascii = gen_string_array( + N_ROWS, + STR_LEN, + null_density, + UTF8_DENSITY_OF_ALL_ASCII, + false, + ); + + let arg_fields = + vec![Field::new("a", args_string_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + + c.bench_function( + format!("ascii/string_ascii_only (null_density={null_density})").as_str(), + |b| { + b.iter(|| { + black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_ascii.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + })) + }) + }, + ); + + // StringArray UTF8 + let args_string_utf8 = + gen_string_array(N_ROWS, STR_LEN, null_density, NORMAL_UTF8_DENSITY, false); + let arg_fields = + vec![Field::new("a", args_string_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + c.bench_function( + format!("ascii/string_utf8 (null_density={null_density})").as_str(), + |b| { + b.iter(|| { + black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_utf8.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + })) + }) + }, + ); + + // StringViewArray ASCII only + let args_string_view_ascii = gen_string_array( + N_ROWS, + STR_LEN, + null_density, + UTF8_DENSITY_OF_ALL_ASCII, + true, + ); + let arg_fields = + vec![Field::new("a", args_string_view_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + c.bench_function( + format!("ascii/string_view_ascii_only (null_density={null_density})") + .as_str(), + |b| { + b.iter(|| { + 
black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_view_ascii.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + })) + }) + }, + ); + + // StringViewArray UTF8 + let args_string_view_utf8 = + gen_string_array(N_ROWS, STR_LEN, null_density, NORMAL_UTF8_DENSITY, true); + let arg_fields = + vec![Field::new("a", args_string_view_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + c.bench_function( + format!("ascii/string_view_utf8 (null_density={null_density})").as_str(), + |b| { + b.iter(|| { + black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_view_utf8.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + })) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index bbcfed021064..b4a9e917f416 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -17,10 +17,11 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; +use std::sync::Arc; mod helper; @@ -28,20 +29,28 @@ fn criterion_benchmark(c: &mut Criterion) { // All benches are single batch run with 8192 rows let character_length = datafusion_functions::unicode::character_length(); - let return_type = DataType::Utf8; + let return_field = Arc::new(Field::new("f", DataType::Utf8, true)); let n_rows = 8192; for str_len in [8, 32, 128, 4096] { // StringArray ASCII only let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); + let arg_fields = args_string_ascii + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringArray_ascii_str_len_{}", str_len), + &format!("character_length_StringArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), })) }) }, @@ -49,14 +58,22 @@ fn criterion_benchmark(c: &mut Criterion) { // StringArray UTF8 let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); + let arg_fields = args_string_utf8 + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringArray_utf8_str_len_{}", str_len), + &format!("character_length_StringArray_utf8_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), })) }) }, @@ -64,14 +81,22 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); + let arg_fields = args_string_view_ascii + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), 
arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringViewArray_ascii_str_len_{}", str_len), + &format!("character_length_StringViewArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), })) }) }, @@ -79,14 +104,22 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); + let arg_fields = args_string_view_utf8 + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringViewArray_utf8_str_len_{}", str_len), + &format!("character_length_StringViewArray_utf8_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), })) }) }, diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 8575809c21c8..6a956bb78812 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -23,7 +23,7 @@ use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::chr; use rand::{Rng, SeedableRng}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use rand::rngs::StdRng; use std::sync::Arc; @@ -37,27 +37,34 @@ fn criterion_benchmark(c: &mut Criterion) { let size = 1024; let input: PrimitiveArray = { let null_density = 0.2; - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); (0..size) .map(|_| { - if rng.gen::() < null_density { + if rng.random::() < null_density { None } else { - Some(rng.gen_range::(1i64..10_000)) + Some(rng.random_range::(1i64..10_000)) } }) .collect() }; let input = Arc::new(input); let args = vec![ColumnarValue::Array(input)]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + c.bench_function("chr", |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 45ca076e754f..d350c03c497b 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -16,7 +16,7 @@ // under the License. 
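// Illustrative sketch, not part of this patch, of the invocation pattern these
// benches migrate to: `ScalarFunctionArgs` now carries per-argument `arg_fields`
// and a `return_field` (both `FieldRef`) instead of `return_type: &DataType`.
// Field names and the function name below are placeholder assumptions.
fn invoke_with_fields_sketch() -> datafusion_common::Result<()> {
    use arrow::datatypes::{DataType, Field, FieldRef};
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

    let args = vec![
        ColumnarValue::Scalar(ScalarValue::from("foo")),
        ColumnarValue::Scalar(ScalarValue::from("bar")),
    ];
    let arg_fields: Vec<FieldRef> = args
        .iter()
        .enumerate()
        .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into())
        .collect();
    let result = datafusion_functions::string::concat().invoke_with_args(ScalarFunctionArgs {
        args,
        arg_fields,
        number_rows: 1,
        return_field: Field::new("f", DataType::Utf8, true).into(),
    })?;
    assert!(matches!(result, ColumnarValue::Scalar(_)));
    Ok(())
}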
use arrow::array::ArrayRef; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_common::ScalarValue; @@ -37,6 +37,14 @@ fn create_args(size: usize, str_len: usize) -> Vec { fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let args = create_args(size, 32); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let mut group = c.benchmark_group("concat function"); group.bench_function(BenchmarkId::new("concat", size), |b| { b.iter(|| { @@ -45,8 +53,9 @@ fn criterion_benchmark(c: &mut Criterion) { concat() .invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index b2a9ca0b9f47..a32e0d834672 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -25,7 +25,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::cot; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { @@ -33,14 +33,23 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("cot f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("cot f32 array: {size}"), |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }) .unwrap(), ) @@ -48,14 +57,24 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("cot f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Float64, true)); + + c.bench_function(&format!("cot f64 array: {size}"), |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 7ea5fdcb2be2..ac766a002576 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -20,6 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; +use arrow::datatypes::Field; use criterion::{black_box, criterion_group, 
criterion_main, Criterion}; use datafusion_common::ScalarValue; use rand::rngs::ThreadRng; @@ -31,7 +32,7 @@ use datafusion_functions::datetime::date_bin; fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { let mut seconds = vec![]; for _ in 0..1000 { - seconds.push(rng.gen_range(0..1_000_000)); + seconds.push(rng.random_range(0..1_000_000)); } TimestampSecondArray::from(seconds) @@ -39,7 +40,7 @@ fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("date_bin_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let timestamps_array = Arc::new(timestamps(&mut rng)) as ArrayRef; let batch_len = timestamps_array.len(); let interval = ColumnarValue::Scalar(ScalarValue::new_interval_dt(0, 1_000_000)); @@ -48,13 +49,19 @@ fn criterion_benchmark(c: &mut Criterion) { let return_type = udf .return_type(&[interval.data_type(), timestamps.data_type()]) .unwrap(); + let return_field = Arc::new(Field::new("f", return_type, true)); + let arg_fields = vec![ + Field::new("a", interval.data_type(), true).into(), + Field::new("b", timestamps.data_type(), true).into(), + ]; b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![interval.clone(), timestamps.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &return_type, + return_field: Arc::clone(&return_field), }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index e7e96fb7a9fa..ad4d0d0fbb79 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -20,6 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; +use arrow::datatypes::Field; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use rand::rngs::ThreadRng; @@ -31,7 +32,7 @@ use datafusion_functions::datetime::date_trunc; fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { let mut seconds = vec![]; for _ in 0..1000 { - seconds.push(rng.gen_range(0..1_000_000)); + seconds.push(rng.random_range(0..1_000_000)); } TimestampSecondArray::from(seconds) @@ -39,7 +40,7 @@ fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("date_trunc_minute_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let timestamps_array = Arc::new(timestamps(&mut rng)) as ArrayRef; let batch_len = timestamps_array.len(); let precision = @@ -47,15 +48,25 @@ fn criterion_benchmark(c: &mut Criterion) { let timestamps = ColumnarValue::Array(timestamps_array); let udf = date_trunc(); let args = vec![precision, timestamps]; - let return_type = &udf + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + let return_type = udf .return_type(&args.iter().map(|arg| arg.data_type()).collect::>()) .unwrap(); + let return_field = Arc::new(Field::new("f", return_type, true)); b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("date_trunc should work on valid values"), ) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs 
index cf8f8d2fd62c..830e0324766f 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -17,7 +17,8 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::array::Array; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -33,19 +34,29 @@ fn criterion_benchmark(c: &mut Criterion) { let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + arg_fields: vec![ + Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ], number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(); + let arg_fields = vec![ + Field::new("a", encoded.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ]; let args = vec![encoded, method]; + b.iter(|| { black_box( decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) @@ -54,22 +65,34 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function(&format!("hex_decode/{size}"), |b| { let method = ColumnarValue::Scalar("hex".into()); + let arg_fields = vec![ + Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ]; let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + arg_fields, number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(); + let arg_fields = vec![ + Field::new("a", encoded.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ]; + let return_field = Field::new("f", DataType::Utf8, true).into(); let args = vec![encoded, method]; + b.iter(|| { black_box( decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index 9307525482c2..bad540f049e2 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -18,14 +18,14 @@ extern crate criterion; use arrow::array::{StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; use std::sync::Arc; @@ -51,7 +51,7 @@ fn gen_args_array( let mut output_set_vec: Vec> = Vec::with_capacity(n_rows); let mut output_element_vec: Vec> = Vec::with_capacity(n_rows); for _ in 0..n_rows { - let rand_num = rng_ref.gen::(); 
// [0.0, 1.0) + let rand_num = rng_ref.random::(); // [0.0, 1.0) if rand_num < null_density { output_element_vec.push(None); output_set_vec.push(None); @@ -60,7 +60,7 @@ fn gen_args_array( let mut generated_string = String::with_capacity(str_len_chars); for i in 0..num_elements { for _ in 0..str_len_chars { - let idx = rng_ref.gen_range(0..corpus_char_count); + let idx = rng_ref.random_range(0..corpus_char_count); let char = utf8.chars().nth(idx).unwrap(); generated_string.push(char); } @@ -112,7 +112,7 @@ fn random_element_in_set(string: &str) -> String { } let mut rng = StdRng::seed_from_u64(44); - let random_index = rng.gen_range(0..elements.len()); + let random_index = rng.random_range(0..elements.len()); elements[random_index].to_string() } @@ -153,23 +153,35 @@ fn criterion_benchmark(c: &mut Criterion) { group.measurement_time(Duration::from_secs(10)); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, false); - group.bench_function(format!("string_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Field::new("f", DataType::Int32, true).into(); + group.bench_function(format!("string_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, true); - group.bench_function(format!("string_view_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); + group.bench_function(format!("string_view_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }); @@ -179,23 +191,35 @@ fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("find_in_set_scalar"); let args = gen_args_scalar(n_rows, str_len, 0.1, false); - group.bench_function(format!("string_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); + group.bench_function(format!("string_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }); let args = gen_args_scalar(n_rows, str_len, 0.1, true); - group.bench_function(format!("string_view_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); + group.bench_function(format!("string_view_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index 
f8c855c82ad4..f700d31123a9 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -17,6 +17,7 @@ extern crate criterion; +use arrow::datatypes::Field; use arrow::{ array::{ArrayRef, Int64Array}, datatypes::DataType, @@ -29,9 +30,9 @@ use rand::Rng; use std::sync::Arc; fn generate_i64_array(n_rows: usize) -> ArrayRef { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let values = (0..n_rows) - .map(|_| rng.gen_range(0..1000)) + .map(|_| rng.random_range(0..1000)) .collect::<Vec<_>>(); Arc::new(Int64Array::from(values)) as ArrayRef } @@ -47,8 +48,12 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), array_b.clone()], + arg_fields: vec![ + Field::new("a", array_a.data_type(), true).into(), + Field::new("b", array_b.data_type(), true).into(), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), }) .expect("gcd should work on valid values"), ) @@ -63,8 +68,12 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), scalar_b.clone()], + arg_fields: vec![ + Field::new("a", array_a.data_type(), true).into(), + Field::new("b", scalar_b.data_type(), true).into(), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), }) .expect("gcd should work on valid values"), ) @@ -79,8 +88,12 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![scalar_a.clone(), scalar_b.clone()], + arg_fields: vec![ + Field::new("a", scalar_a.data_type(), true).into(), + Field::new("b", scalar_b.data_type(), true).into(), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), }) .expect("gcd should work on valid values"), ) diff --git a/datafusion/functions/benches/helper.rs b/datafusion/functions/benches/helper.rs index 0dbb4b0027d4..a2b110ae4d63 100644 --- a/datafusion/functions/benches/helper.rs +++ b/datafusion/functions/benches/helper.rs @@ -17,7 +17,7 @@ use arrow::array::{StringArray, StringViewArray}; use datafusion_expr::ColumnarValue; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::{rngs::StdRng, Rng, SeedableRng}; use std::sync::Arc; @@ -39,14 +39,14 @@ pub fn gen_string_array( let mut output_string_vec: Vec<Option<String>> = Vec::with_capacity(n_rows); for _ in 0..n_rows { - let rand_num = rng_ref.gen::<f32>(); // [0.0, 1.0) + let rand_num = rng_ref.random::<f32>(); // [0.0, 1.0) if rand_num < null_density { output_string_vec.push(None); } else if rand_num < null_density + utf8_density { // Generate random UTF8 string let mut generated_string = String::with_capacity(str_len_chars); for _ in 0..str_len_chars { - let char = corpus[rng_ref.gen_range(0..corpus.len())]; + let char = corpus[rng_ref.random_range(0..corpus.len())]; generated_string.push(char); } output_string_vec.push(Some(generated_string)); diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 97c76831b33c..f89b11dff8fb 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::OffsetSizeTrait; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len,
create_string_view_array_with_len, }; @@ -49,14 +49,23 @@ fn criterion_benchmark(c: &mut Criterion) { let initcap = unicode::initcap(); for size in [1024, 4096] { let args = create_args::(size, 8, true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + c.bench_function( - format!("initcap string view shorter than 12 [size={}]", size).as_str(), + format!("initcap string view shorter than 12 [size={size}]").as_str(), |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8View, + return_field: Field::new("f", DataType::Utf8View, true).into(), })) }) }, @@ -64,25 +73,27 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args::(size, 16, true); c.bench_function( - format!("initcap string view longer than 12 [size={}]", size).as_str(), + format!("initcap string view longer than 12 [size={size}]").as_str(), |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8View, + return_field: Field::new("f", DataType::Utf8View, true).into(), })) }) }, ); let args = create_args::(size, 16, false); - c.bench_function(format!("initcap string [size={}]", size).as_str(), |b| { + c.bench_function(format!("initcap string [size={size}]").as_str(), |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index 42004cc24f69..49d0a9e326dd 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -32,14 +32,23 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("isnan f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("isnan f32 array: {size}"), |b| { b.iter(|| { black_box( isnan .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Boolean, + return_field: Field::new("f", DataType::Boolean, true).into(), }) .unwrap(), ) @@ -47,14 +56,22 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("isnan f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + c.bench_function(&format!("isnan f64 array: {size}"), |b| { b.iter(|| { black_box( isnan .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: 
arg_fields.clone(), number_rows: size, - return_type: &DataType::Boolean, + return_field: Field::new("f", DataType::Boolean, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 9e5f6a84804b..6d1d34c7a832 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -33,14 +33,24 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("iszero f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Boolean, true)); + + c.bench_function(&format!("iszero f32 array: {size}"), |b| { b.iter(|| { black_box( iszero .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Boolean, + return_field: Arc::clone(&return_field), }) .unwrap(), ) @@ -49,14 +59,24 @@ fn criterion_benchmark(c: &mut Criterion) { let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("iszero f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Boolean, true)); + + c.bench_function(&format!("iszero f64 array: {size}"), |b| { b.iter(|| { black_box( iszero .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Boolean, + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 534e5739225d..cdf1529c108c 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{ArrayRef, StringArray, StringViewBuilder}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -44,7 +44,7 @@ fn create_args2(size: usize) -> Vec { let mut items = Vec::with_capacity(size); items.push("农历新年".to_string()); for i in 1..size { - items.push(format!("DATAFUSION {}", i)); + items.push(format!("DATAFUSION {i}")); } let array = Arc::new(StringArray::from(items)) as ArrayRef; vec![ColumnarValue::Array(array)] @@ -58,11 +58,11 @@ fn create_args3(size: usize) -> Vec { let mut items = Vec::with_capacity(size); let half = size / 2; for i in 0..half { - items.push(format!("DATAFUSION {}", i)); + items.push(format!("DATAFUSION {i}")); } items.push("Ⱦ".to_string()); for i in half + 1..size { - items.push(format!("DATAFUSION {}", i)); + items.push(format!("DATAFUSION {i}")); } let array = Arc::new(StringArray::from(items)) as ArrayRef; vec![ColumnarValue::Array(array)] @@ -124,42 +124,66 @@ fn 
criterion_benchmark(c: &mut Criterion) { let lower = string::lower(); for size in [1024, 4096, 8192] { let args = create_args1(size, 32); - c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("lower_all_values_are_ascii: {size}"), |b| { b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); let args = create_args2(size); - c.bench_function( - &format!("lower_the_first_value_is_nonascii: {}", size), - |b| { - b.iter(|| { - let args_cloned = args.clone(); - black_box(lower.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: size, - return_type: &DataType::Utf8, - })) - }) - }, - ); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("lower_the_first_value_is_nonascii: {size}"), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + })) + }) + }); let args = create_args3(size); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + c.bench_function( - &format!("lower_the_middle_value_is_nonascii: {}", size), + &format!("lower_the_middle_value_is_nonascii: {size}"), |b| { b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -176,29 +200,37 @@ fn criterion_benchmark(c: &mut Criterion) { for &str_len in &str_lens { for &size in &sizes { let args = create_args4(size, str_len, *null_density, mixed); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + c.bench_function( - &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", - size, str_len, null_density, mixed), + &format!("lower_all_values_are_ascii_string_views: size: {size}, str_len: {str_len}, null_density: {null_density}, mixed: {mixed}"), |b| b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }), ); let args = create_args4(size, str_len, *null_density, mixed); c.bench_function( - &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", - size, str_len, null_density, mixed), + &format!("lower_all_values_are_ascii_string_views: size: {size}, str_len: {str_len}, null_density: {null_density}, mixed: {mixed}"), |b| b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: 
arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }), ); @@ -211,8 +243,9 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }), ); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 457fb499f5a1..7a44f40a689a 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{ black_box, criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, Criterion, SamplingMode, @@ -26,13 +26,9 @@ use criterion::{ use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF}; use datafusion_functions::string; -use rand::{distributions::Alphanumeric, rngs::StdRng, Rng, SeedableRng}; +use rand::{distr::Alphanumeric, rngs::StdRng, Rng, SeedableRng}; use std::{fmt, sync::Arc}; -pub fn seedable_rng() -> StdRng { - StdRng::seed_from_u64(42) -} - #[derive(Clone, Copy)] pub enum StringArrayType { Utf8View, @@ -58,14 +54,14 @@ pub fn create_string_array_and_characters( remaining_len: usize, string_array_type: StringArrayType, ) -> (ArrayRef, ScalarValue) { - let rng = &mut seedable_rng(); + let rng = &mut StdRng::seed_from_u64(42); // Create `size` rows: // - 10% rows will be `None` // - Other 90% will be strings with same `remaining_len` lengths // We will build the string array on it later. 
let string_iter = (0..size).map(|_| { - if rng.gen::() < 0.1 { + if rng.random::() < 0.1 { None } else { let mut value = trimmed.as_bytes().to_vec(); @@ -136,6 +132,11 @@ fn run_with_string_type( string_type: StringArrayType, ) { let args = create_args(size, characters, trimmed, remaining_len, string_type); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); group.bench_function( format!( "{string_type} [size={size}, len_before={len}, len_after={remaining_len}]", @@ -145,8 +146,9 @@ fn run_with_string_type( let args_cloned = args.clone(); black_box(ltrim.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 8dd7a7a59773..e1f609fbb35c 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -20,7 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int32Array}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::rngs::ThreadRng; use rand::Rng; @@ -32,7 +32,7 @@ use datafusion_functions::datetime::make_date; fn years(rng: &mut ThreadRng) -> Int32Array { let mut years = vec![]; for _ in 0..1000 { - years.push(rng.gen_range(1900..2050)); + years.push(rng.random_range(1900..2050)); } Int32Array::from(years) @@ -41,7 +41,7 @@ fn years(rng: &mut ThreadRng) -> Int32Array { fn months(rng: &mut ThreadRng) -> Int32Array { let mut months = vec![]; for _ in 0..1000 { - months.push(rng.gen_range(1..13)); + months.push(rng.random_range(1..13)); } Int32Array::from(months) @@ -50,27 +50,34 @@ fn months(rng: &mut ThreadRng) -> Int32Array { fn days(rng: &mut ThreadRng) -> Int32Array { let mut days = vec![]; for _ in 0..1000 { - days.push(rng.gen_range(1..29)); + days.push(rng.random_range(1..29)); } Int32Array::from(days) } fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_date_col_col_col_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let years_array = Arc::new(years(&mut rng)) as ArrayRef; let batch_len = years_array.len(); let years = ColumnarValue::Array(years_array); let months = ColumnarValue::Array(Arc::new(months(&mut rng)) as ArrayRef); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); + let arg_fields = vec![ + Field::new("a", years.data_type(), true).into(), + Field::new("a", months.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![years.clone(), months.clone(), days.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) @@ -78,20 +85,26 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("make_date_scalar_col_col_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let months_arr = Arc::new(months(&mut rng)) as ArrayRef; let batch_len = months_arr.len(); 
let months = ColumnarValue::Array(months_arr); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); - + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", months.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), months.clone(), days.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) @@ -99,20 +112,26 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("make_date_scalar_scalar_col_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); let day_arr = Arc::new(days(&mut rng)); let batch_len = day_arr.len(); let days = ColumnarValue::Array(day_arr); - + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", month.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), days.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) @@ -123,14 +142,21 @@ fn criterion_benchmark(c: &mut Criterion) { let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); let day = ColumnarValue::Scalar(ScalarValue::Int32(Some(26))); + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", month.data_type(), true).into(), + Field::new("a", day.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), day.clone()], + arg_fields: arg_fields.clone(), number_rows: 1, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index 9096c976bf31..4ac977af9d42 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; @@ -33,14 +33,23 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("abcd".to_string()))), ColumnarValue::Array(array), ]; - c.bench_function(&format!("nullif scalar array: {}", size), |b| { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("nullif scalar array: {size}"), |b| { b.iter(|| { black_box( nullif .invoke_with_args(ScalarFunctionArgs { 
args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index f78a53fbee19..d954ff452ed5 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -16,14 +16,15 @@ // under the License. use arrow::array::{ArrayRef, ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; -use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{DataType, Field, Int64Type}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode::{lpad, rpad}; -use rand::distributions::{Distribution, Uniform}; +use rand::distr::{Distribution, Uniform}; use rand::Rng; use std::sync::Arc; @@ -52,13 +53,13 @@ where dist: Uniform::new_inclusive::(0, len as i64), }; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); (0..size) .map(|_| { - if rng.gen::() < null_density { + if rng.random::() < null_density { None } else { - Some(rng.sample(&dist)) + Some(rng.sample(dist.dist.unwrap())) } }) .collect() @@ -95,21 +96,41 @@ fn create_args( } } +fn invoke_pad_with_args( + args: Vec, + number_rows: usize, + left_pad: bool, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + let scalar_args = ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Utf8, true).into(), + }; + + if left_pad { + lpad().invoke_with_args(scalar_args) + } else { + rpad().invoke_with_args(scalar_args) + } +} + fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 2048] { let mut group = c.benchmark_group("lpad function"); let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("utf8 type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -118,13 +139,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::LargeUtf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -133,13 +148,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("stringview type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -152,13 +161,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("utf8 type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, 
false).unwrap(), ) }) }); @@ -167,13 +170,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::LargeUtf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); @@ -183,13 +180,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("stringview type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 78ebf23e02e0..dc1e280b93b1 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -17,14 +17,16 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use datafusion_functions::math::random::RandomFunc; +use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let random_func = RandomFunc::new(); + let return_field = Field::new("f", DataType::Float64, true).into(); // Benchmark to evaluate 1M rows in batch size 8192 let iterations = 1_000_000 / 8192; // Calculate how many iterations are needed to reach approximately 1M rows c.bench_function("random_1M_rows_batch_8192", |b| { @@ -34,8 +36,9 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 8192, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), }) .unwrap(), ); @@ -43,6 +46,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + let return_field = Field::new("f", DataType::Float64, true).into(); // Benchmark to evaluate 1M rows in batch size 128 let iterations_128 = 1_000_000 / 128; // Calculate how many iterations are needed to reach approximately 1M rows with batch size 128 c.bench_function("random_1M_rows_batch_128", |b| { @@ -52,8 +56,9 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 128, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), }) .unwrap(), ); diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 3a1a6a71173e..c0b50ad62f64 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -26,9 +26,9 @@ use datafusion_functions::regex::regexpcount::regexp_count_func; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; +use rand::prelude::IndexedRandom; use rand::rngs::ThreadRng; -use rand::seq::SliceRandom; use rand::Rng; use std::iter; use std::sync::Arc; @@ -65,7 +65,7 @@ fn regex(rng: &mut ThreadRng) -> StringArray { fn start(rng: &mut ThreadRng) -> Int64Array { let mut data: Vec = vec![]; for _ in 0..1000 { - data.push(rng.gen_range(1..5)); + data.push(rng.random_range(1..5)); } Int64Array::from(data) @@ 
-88,7 +88,7 @@ fn flags(rng: &mut ThreadRng) -> StringArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("regexp_count_1000 string", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let start = Arc::new(start(&mut rng)) as ArrayRef; @@ -108,7 +108,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_count_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); let start = Arc::new(start(&mut rng)) as ArrayRef; @@ -128,7 +128,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_like_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let flags = Arc::new(flags(&mut rng)) as ArrayRef; @@ -142,7 +142,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_like_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); @@ -156,7 +156,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_match_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let flags = Arc::new(flags(&mut rng)) as ArrayRef; @@ -174,7 +174,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_match_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); @@ -192,7 +192,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_replace_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let flags = Arc::new(flags(&mut rng)) as ArrayRef; @@ -214,7 +214,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_replace_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); // flags are not allowed to be utf8view according to the function diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 5cc6a177d9d9..175933f5f745 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use 
datafusion_functions::string; use std::sync::Arc; @@ -56,66 +57,62 @@ fn create_args( } } +fn invoke_repeat_with_args( + args: Vec, + repeat_times: i64, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + string::repeat().invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: repeat_times as usize, + return_field: Field::new("f", DataType::Utf8, true).into(), + }) +} + fn criterion_benchmark(c: &mut Criterion) { - let repeat = string::repeat(); for size in [1024, 4096] { // REPEAT 3 TIMES let repeat_times = 3; - let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + let mut group = c.benchmark_group(format!("repeat {repeat_times} times")); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); group.measurement_time(Duration::from_secs(10)); let args = create_args::(size, 32, repeat_times, true); group.bench_function( - format!( - "repeat_string_view [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string_view [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_string [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_large_string [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_large_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -124,61 +121,40 @@ fn criterion_benchmark(c: &mut Criterion) { // REPEAT 30 TIMES let repeat_times = 30; - let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + let mut group = c.benchmark_group(format!("repeat {repeat_times} times")); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); group.measurement_time(Duration::from_secs(10)); let args = create_args::(size, 32, repeat_times, true); group.bench_function( - format!( - "repeat_string_view [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string_view [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_string [size={}, repeat_times={}]", - size, repeat_times - ), + 
format!("repeat_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_large_string [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_large_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -187,25 +163,18 @@ fn criterion_benchmark(c: &mut Criterion) { // REPEAT overflow let repeat_times = 1073741824; - let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + let mut group = c.benchmark_group(format!("repeat {repeat_times} times")); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); group.measurement_time(Duration::from_secs(10)); let args = create_args::(size, 2, repeat_times, false); group.bench_function( - format!( - "repeat_string overflow [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string overflow [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index d61f8fb80517..640366011305 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -18,7 +18,7 @@ extern crate criterion; mod helper; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; @@ -41,13 +41,18 @@ fn criterion_benchmark(c: &mut Criterion) { false, ); c.bench_function( - &format!("reverse_StringArray_ascii_str_len_{}", str_len), + &format!("reverse_StringArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: vec![Field::new( + "a", + args_string_ascii[0].data_type(), + true, + ).into()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -58,15 +63,17 @@ fn criterion_benchmark(c: &mut Criterion) { gen_string_array(N_ROWS, str_len, NULL_DENSITY, NORMAL_UTF8_DENSITY, false); c.bench_function( &format!( - "reverse_StringArray_utf8_density_{}_str_len_{}", - NORMAL_UTF8_DENSITY, str_len + "reverse_StringArray_utf8_density_{NORMAL_UTF8_DENSITY}_str_len_{str_len}" ), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_fields: vec![ + Field::new("a", args_string_utf8[0].data_type(), true).into(), + ], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -81,13 +88,18 @@ fn criterion_benchmark(c: &mut Criterion) { true, ); 
c.bench_function( - &format!("reverse_StringViewArray_ascii_str_len_{}", str_len), + &format!("reverse_StringViewArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: vec![Field::new( + "a", + args_string_view_ascii[0].data_type(), + true, + ).into()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -98,15 +110,19 @@ fn criterion_benchmark(c: &mut Criterion) { gen_string_array(N_ROWS, str_len, NULL_DENSITY, NORMAL_UTF8_DENSITY, true); c.bench_function( &format!( - "reverse_StringViewArray_utf8_density_{}_str_len_{}", - NORMAL_UTF8_DENSITY, str_len + "reverse_StringViewArray_utf8_density_{NORMAL_UTF8_DENSITY}_str_len_{str_len}" ), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: vec![Field::new( + "a", + args_string_view_utf8[0].data_type(), + true, + ).into()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 01939fad5f34..10079bcc81c7 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -19,7 +19,7 @@ extern crate criterion; use arrow::datatypes::DataType; use arrow::{ - datatypes::{Float32Type, Float64Type}, + datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -33,14 +33,24 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("signum f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Field::new("f", DataType::Float32, true).into(); + + c.bench_function(&format!("signum f32 array: {size}"), |b| { b.iter(|| { black_box( signum .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Float32, + return_field: Arc::clone(&return_field), }) .unwrap(), ) @@ -50,14 +60,24 @@ fn criterion_benchmark(c: &mut Criterion) { let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("signum f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Field::new("f", DataType::Float64, true).into(); + + c.bench_function(&format!("signum f64 array: {size}"), |b| { b.iter(|| { black_box( signum .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index df57c229e0ad..df32db1182f1 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -18,10 +18,10 @@ extern crate criterion; use arrow::array::{StringArray, 
StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; use std::str::Chars; @@ -46,7 +46,7 @@ fn gen_string_array( let mut output_string_vec: Vec> = Vec::with_capacity(n_rows); let mut output_sub_string_vec: Vec> = Vec::with_capacity(n_rows); for _ in 0..n_rows { - let rand_num = rng_ref.gen::(); // [0.0, 1.0) + let rand_num = rng_ref.random::(); // [0.0, 1.0) if rand_num < null_density { output_sub_string_vec.push(None); output_string_vec.push(None); @@ -54,7 +54,7 @@ fn gen_string_array( // Generate random UTF8 string let mut generated_string = String::with_capacity(str_len_chars); for _ in 0..str_len_chars { - let idx = rng_ref.gen_range(0..corpus_char_count); + let idx = rng_ref.random_range(0..corpus_char_count); let char = utf8.chars().nth(idx).unwrap(); generated_string.push(char); } @@ -94,8 +94,8 @@ fn random_substring(chars: Chars) -> String { // get the substring of a random length from the input string by byte unit let mut rng = StdRng::seed_from_u64(44); let count = chars.clone().count(); - let start = rng.gen_range(0..count - 1); - let end = rng.gen_range(start + 1..count); + let start = rng.random_range(0..count - 1); + let end = rng.random_range(start + 1..count); chars .enumerate() .filter(|(i, _)| *i >= start && *i < end) @@ -111,14 +111,18 @@ fn criterion_benchmark(c: &mut Criterion) { for str_len in [8, 32, 128, 4096] { // StringArray ASCII only let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); + let arg_fields = + vec![Field::new("a", args_string_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function( - &format!("strpos_StringArray_ascii_str_len_{}", str_len), + &format!("strpos_StringArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }, @@ -126,29 +130,34 @@ fn criterion_benchmark(c: &mut Criterion) { // StringArray UTF8 let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); - c.bench_function( - &format!("strpos_StringArray_utf8_str_len_{}", str_len), - |b| { - b.iter(|| { - black_box(strpos.invoke_with_args(ScalarFunctionArgs { - args: args_string_utf8.clone(), - number_rows: n_rows, - return_type: &DataType::Int32, - })) - }) - }, - ); + let arg_fields = + vec![Field::new("a", args_string_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); + c.bench_function(&format!("strpos_StringArray_utf8_str_len_{str_len}"), |b| { + b.iter(|| { + black_box(strpos.invoke_with_args(ScalarFunctionArgs { + args: args_string_utf8.clone(), + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + })) + }) + }); // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); + let arg_fields = + vec![Field::new("a", args_string_view_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function( - &format!("strpos_StringViewArray_ascii_str_len_{}", str_len), + 
&format!("strpos_StringViewArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }, @@ -156,14 +165,18 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); + let arg_fields = + vec![Field::new("a", args_string_view_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function( - &format!("strpos_StringViewArray_utf8_str_len_{}", str_len), + &format!("strpos_StringViewArray_utf8_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 80ab70ef71b0..342e18b0d9a2 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::sync::Arc; @@ -96,8 +97,25 @@ fn create_args_with_count( } } +fn invoke_substr_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + unicode::substr().invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Utf8View, true).into(), + }) +} + fn criterion_benchmark(c: &mut Criterion) { - let substr = unicode::substr(); for size in [1024, 4096] { // string_len = 12, substring_len=6 (see `create_args_without_count`) let len = 12; @@ -107,44 +125,19 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_without_count::(size, len, true, true); group.bench_function( - format!("substr_string_view [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string_view [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_without_count::(size, len, false, false); - group.bench_function( - format!("substr_string [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, - ); + group.bench_function(format!("substr_string [size={size}, strlen={len}]"), |b| { + b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))) + }); let args = create_args_without_count::(size, len, true, false); group.bench_function( - 
format!("substr_large_string [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_large_string [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); @@ -158,53 +151,20 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_with_count::(size, len, count, true); group.bench_function( - format!( - "substr_string_view [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_large_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); @@ -218,53 +178,20 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_with_count::(size, len, count, true); group.bench_function( - format!( - "substr_string_view [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_large_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); diff 
--git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index b1c1c3c34a95..e772fb38fc40 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -20,9 +20,9 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{ArrayRef, Int64Array, StringArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rand::distributions::{Alphanumeric, Uniform}; +use rand::distr::{Alphanumeric, Uniform}; use rand::prelude::Distribution; use rand::Rng; @@ -54,21 +54,21 @@ fn data() -> (StringArray, StringArray, Int64Array) { dist: Uniform::new(-4, 5), test: |x: &i64| x != &0, }; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut strings: Vec = vec![]; let mut delimiters: Vec = vec![]; let mut counts: Vec = vec![]; for _ in 0..1000 { - let length = rng.gen_range(20..50); + let length = rng.random_range(20..50); let text: String = (&mut rng) .sample_iter(&Alphanumeric) .take(length) .map(char::from) .collect(); - let char = rng.gen_range(0..text.len()); + let char = rng.random_range(0..text.len()); let delimiter = &text.chars().nth(char).unwrap(); - let count = rng.sample(&dist); + let count = rng.sample(dist.dist.unwrap()); strings.push(text); delimiters.push(delimiter.to_string()); @@ -91,13 +91,22 @@ fn criterion_benchmark(c: &mut Criterion) { let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); let args = vec![strings, delimiters, counts]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + b.iter(|| { black_box( substr_index() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("substr_index should work on valid values"), ) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 6f20a20dc219..d19714ce6166 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -20,12 +20,12 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{ArrayRef, Date32Array, StringArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use chrono::prelude::*; use chrono::TimeDelta; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::IndexedRandom; use rand::rngs::ThreadRng; -use rand::seq::SliceRandom; use rand::Rng; use datafusion_common::ScalarValue; @@ -39,7 +39,7 @@ fn random_date_in_range( end_date: NaiveDate, ) -> NaiveDate { let days_in_range = (end_date - start_date).num_days(); - let random_days: i64 = rng.gen_range(0..days_in_range); + let random_days: i64 = rng.random_range(0..days_in_range); start_date + TimeDelta::try_days(random_days).unwrap() } @@ -82,7 +82,7 @@ fn patterns(rng: &mut ThreadRng) -> StringArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_char_array_array_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data_arr = data(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); @@ -93,8 +93,12 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), 
patterns.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), + ], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("to_char should work on valid values"), ) @@ -102,7 +106,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("to_char_array_scalar_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data_arr = data(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); @@ -114,8 +118,12 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), + ], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("to_char should work on valid values"), ) @@ -141,8 +149,12 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), pattern.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", pattern.data_type(), true).into(), + ], number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("to_char should work on valid values"), ) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index a45d936c0a52..4a02b74ca42d 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::{DataType, Int32Type, Int64Type}; +use arrow::datatypes::{DataType, Field, Int32Type, Int64Type}; use arrow::util::bench_util::create_primitive_array; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -30,14 +30,15 @@ fn criterion_benchmark(c: &mut Criterion) { let i32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = i32_array.len(); let i32_args = vec![ColumnarValue::Array(i32_array)]; - c.bench_function(&format!("to_hex i32 array: {}", size), |b| { + c.bench_function(&format!("to_hex i32 array: {size}"), |b| { b.iter(|| { let args_cloned = i32_args.clone(); black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Int32, false).into()], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) @@ -46,14 +47,15 @@ fn criterion_benchmark(c: &mut Criterion) { let i64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = i64_array.len(); let i64_args = vec![ColumnarValue::Array(i64_array)]; - c.bench_function(&format!("to_hex i64 array: {}", size), |b| { + c.bench_function(&format!("to_hex i64 array: {size}"), |b| { b.iter(|| { let args_cloned = i64_args.clone(); black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Int64, false).into()], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index 
aec56697691f..d89811348489 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::array::builder::StringBuilder; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::compute::cast; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, Field, TimeUnit}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -109,7 +109,10 @@ fn data_with_formats() -> (StringArray, StringArray, StringArray, StringArray) { ) } fn criterion_benchmark(c: &mut Criterion) { - let return_type = &DataType::Timestamp(TimeUnit::Nanosecond, None); + let return_field = + Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true).into(); + let arg_field = Field::new("a", DataType::Utf8, false).into(); + let arg_fields = vec![arg_field]; c.bench_function("to_timestamp_no_formats_utf8", |b| { let arr_data = data(); let batch_len = arr_data.len(); @@ -120,8 +123,9 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -138,8 +142,9 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -156,8 +161,9 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -174,13 +180,22 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Array(Arc::new(format2) as ArrayRef), ColumnarValue::Array(Arc::new(format3) as ArrayRef), ]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -205,13 +220,22 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::LargeUtf8).unwrap()) as ArrayRef ), ]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -237,13 +261,22 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::Utf8View).unwrap()) as ArrayRef ), ]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + b.iter(|| { black_box( to_timestamp() 
.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 7fc93921d2e7..897e21c1e1d9 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::{ - datatypes::{Float32Type, Float64Type}, + datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -33,14 +33,17 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("trunc f32 array: {}", size), |b| { + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let return_field = Field::new("f", DataType::Float32, true).into(); + c.bench_function(&format!("trunc f32 array: {size}"), |b| { b.iter(|| { black_box( trunc .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float32, + return_field: Arc::clone(&return_field), }) .unwrap(), ) @@ -48,14 +51,17 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("trunc f64 array: {}", size), |b| { + let arg_fields = vec![Field::new("a", DataType::Float64, true).into()]; + let return_field = Field::new("f", DataType::Float64, true).into(); + c.bench_function(&format!("trunc f64 array: {size}"), |b| { b.iter(|| { black_box( trunc .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index f0bee89c7d37..bf2c4161001e 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -42,8 +42,9 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(upper.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Utf8, true).into()], number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 7b8d156fec21..942af122562a 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use datafusion_functions::string; @@ -28,8 +28,9 @@ fn criterion_benchmark(c: &mut Criterion) { 
b.iter(|| { black_box(uuid.invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 1024, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 2686dbf8be3c..e9dee09e74bf 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -17,7 +17,7 @@ //! [`ArrowCastFunc`]: Implementation of the `arrow_cast` -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::error::ArrowError; use datafusion_common::{ arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, @@ -29,7 +29,7 @@ use std::any::Any; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -113,11 +113,11 @@ impl ScalarUDFImpl for ArrowCastFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let nullable = args.nullables.iter().any(|&nullable| nullable); + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?; @@ -131,7 +131,7 @@ impl ScalarUDFImpl for ArrowCastFunc { ) }, |casted_type| match casted_type.parse::() { - Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), + Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable).into()), Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), Err(e) => Err(arrow_datafusion_err!(e)), }, @@ -177,7 +177,7 @@ impl ScalarUDFImpl for ArrowCastFunc { fn data_type_from_args(args: &[Expr]) -> Result { let [_, type_arg] = take_function_args("arrow_cast", args)?; - let Expr::Literal(ScalarValue::Utf8(Some(val))) = type_arg else { + let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else { return exec_err!( "arrow_cast requires its second argument to be a constant string, got {:?}", type_arg diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index ba20c23828eb..12a4bef24739 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -18,11 +18,11 @@ use arrow::array::{new_null_array, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::binary::try_type_union_resolution; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -79,19 +79,20 @@ impl ScalarUDFImpl for CoalesceFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + 
internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // If any the arguments in coalesce is non-null, the result is non-null - let nullable = args.nullables.iter().all(|&nullable| nullable); + let nullable = args.arg_fields.iter().all(|f| f.is_nullable()); let return_type = args - .arg_types + .arg_fields .iter() + .map(|f| f.data_type()) .find_or_first(|d| !d.is_null()) .unwrap() .clone(); - Ok(ReturnInfo::new(return_type, nullable)) + Ok(Field::new(self.name(), return_type, nullable).into()) } /// coalesce evaluates to the first value which is not NULL diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 3ac26b98359b..2f39132871bb 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -20,7 +20,7 @@ use arrow::array::{ Scalar, }; use arrow::compute::SortOptions; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use arrow_buffer::NullBuffer; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ @@ -28,7 +28,7 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -108,7 +108,7 @@ impl ScalarUDFImpl for GetFieldFunc { let [base, field_name] = take_function_args(self.name(), args)?; let name = match field_name { - Expr::Literal(name) => name, + Expr::Literal(name, _) => name, other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; @@ -118,7 +118,7 @@ impl ScalarUDFImpl for GetFieldFunc { fn schema_name(&self, args: &[Expr]) -> Result { let [base, field_name] = take_function_args(self.name(), args)?; let name = match field_name { - Expr::Literal(name) => name, + Expr::Literal(name, _) => name, other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; @@ -130,14 +130,14 @@ impl ScalarUDFImpl for GetFieldFunc { } fn return_type(&self, _: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert_eq!(args.scalar_arguments.len(), 2); - match (&args.arg_types[0], args.scalar_arguments[1].as_ref()) { + match (&args.arg_fields[0].data_type(), args.scalar_arguments[1].as_ref()) { (DataType::Map(fields, _), _) => { match fields.data_type() { DataType::Struct(fields) if fields.len() == 2 => { @@ -146,7 +146,8 @@ impl ScalarUDFImpl for GetFieldFunc { // instead, we assume that the second column is the "value" column both here and in // execution. 
let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(ReturnInfo::new_nullable(value_field.data_type().clone())) + + Ok(value_field.as_ref().clone().with_nullable(true).into()) }, _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } @@ -158,10 +159,20 @@ impl ScalarUDFImpl for GetFieldFunc { |field_name| { fields.iter().find(|f| f.name() == field_name) .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) - .map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) + .map(|f| { + let mut child_field = f.as_ref().clone(); + + // If the parent is nullable, then getting the child must be nullable, + // so potentially override the return value + + if args.arg_fields[0].is_nullable() { + child_field = child_field.with_nullable(true); + } + Arc::new(child_field) + }) }) }, - (DataType::Null, _) => Ok(ReturnInfo::new_nullable(DataType::Null)), + (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true).into()), (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } } diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index c6329b1ee0af..db080cd62847 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -36,6 +36,7 @@ pub mod overlay; pub mod planner; pub mod r#struct; pub mod union_extract; +pub mod union_tag; pub mod version; // create UDFs @@ -52,6 +53,7 @@ make_udf_function!(coalesce::CoalesceFunc, coalesce); make_udf_function!(greatest::GreatestFunc, greatest); make_udf_function!(least::LeastFunc, least); make_udf_function!(union_extract::UnionExtractFun, union_extract); +make_udf_function!(union_tag::UnionTagFunc, union_tag); make_udf_function!(version::VersionFunc, version); pub mod expr_fn { @@ -101,6 +103,10 @@ pub mod expr_fn { least, "Returns `least(args...)`, which evaluates to the smallest value in the list of expressions or NULL if all the expressions are NULL", args, + ),( + union_tag, + "Returns the name of the currently selected field in the union", + arg1 )); #[doc = "Returns the value of the field with the given name from the struct"] @@ -136,6 +142,7 @@ pub fn functions() -> Vec> { greatest(), least(), union_extract(), + union_tag(), version(), r#struct(), ] diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index bba884d96483..115f4a8aba22 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -16,10 +16,10 @@ // under the License. 
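For reference, the `return_type_from_args` → `return_field_from_args` migration shown above for `arrow_cast`, `coalesce` and `get_field` repeats in the UDFs that follow. Below is a minimal self-contained sketch of the new trait surface, assuming the `datafusion_expr` API exactly as changed in this diff; `ExampleFunc` and its behaviour are hypothetical and not taken from the patch:

```rust
use std::any::Any;

use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{internal_err, Result};
use datafusion_expr::{
    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
};

/// Hypothetical UDF used only to show the migrated method signatures.
#[derive(Debug)]
struct ExampleFunc {
    signature: Signature,
}

impl ScalarUDFImpl for ExampleFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "example"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    // The DataType-only hook becomes a stub, mirroring arrow_cast/coalesce above.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be called instead")
    }
    // New hook: receives the argument `Field`s rather than bare `DataType`s and
    // returns the full output `FieldRef`, so nullability can be propagated.
    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
        Ok(Field::new(self.name(), DataType::Utf8, nullable).into())
    }
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        // `args.return_field` replaces the old `return_type: &DataType`, and
        // `args.arg_fields` carries one FieldRef per argument; this sketch just
        // echoes its first argument.
        let _output_type = args.return_field.data_type();
        Ok(args.args[0].clone())
    }
}
```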
use arrow::array::StructArray; -use arrow::datatypes::{DataType, Field, Fields}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -91,10 +91,12 @@ impl ScalarUDFImpl for NamedStructFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("named_struct: return_type called instead of return_type_from_args") + internal_err!( + "named_struct: return_type called instead of return_field_from_args" + ) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // do not accept 0 arguments. if args.scalar_arguments.is_empty() { return exec_err!( @@ -126,7 +128,13 @@ impl ScalarUDFImpl for NamedStructFunc { ) ) .collect::>>()?; - let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); + let types = args + .arg_fields + .iter() + .skip(1) + .step_by(2) + .map(|f| f.data_type()) + .collect::>(); let return_fields = names .into_iter() @@ -134,13 +142,16 @@ impl ScalarUDFImpl for NamedStructFunc { .map(|(name, data_type)| Ok(Field::new(name, data_type.to_owned(), true))) .collect::>>()?; - Ok(ReturnInfo::new_nullable(DataType::Struct(Fields::from( - return_fields, - )))) + Ok(Field::new( + self.name(), + DataType::Struct(Fields::from(return_fields)), + true, + ) + .into()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let DataType::Struct(fields) = args.return_type else { + let DataType::Struct(fields) = args.return_type() else { return internal_err!("incorrect named_struct return type"); }; diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index 8792bf1bd1b9..f068fc18a8b0 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -117,7 +117,7 @@ impl ScalarUDFImpl for StructFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let DataType::Struct(fields) = args.return_type else { + let DataType::Struct(fields) = args.return_type() else { return internal_err!("incorrect struct return type"); }; diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 420eeed42cc3..be49f8226712 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -16,14 +16,14 @@ // under the License. 
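For reference, the `union_extract` tests below, like the benchmark helpers earlier in this diff, all rebuild `ScalarFunctionArgs` the same way. Condensed into one hedged sketch; the helper name and placeholder field names are illustrative, not taken from the patch:

```rust
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

/// Build the new-style argument struct from a list of ColumnarValues:
/// one Field per argument plus an owned return Field, replacing the old
/// borrowed `return_type: &DataType`.
fn make_scalar_args(
    args: Vec<ColumnarValue>,
    number_rows: usize,
    return_type: DataType,
) -> ScalarFunctionArgs {
    let arg_fields = args
        .iter()
        .enumerate()
        .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into())
        .collect::<Vec<FieldRef>>();

    ScalarFunctionArgs {
        args,
        arg_fields,
        number_rows,
        return_field: Field::new("f", return_type, true).into(),
    }
}
```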
use arrow::array::Array; -use arrow::datatypes::{DataType, FieldRef, UnionFields}; +use arrow::datatypes::{DataType, Field, FieldRef, UnionFields}; use datafusion_common::cast::as_union_array; use datafusion_common::utils::take_function_args; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; use datafusion_doc::Documentation; -use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs}; +use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -82,35 +82,35 @@ impl ScalarUDFImpl for UnionExtractFun { } fn return_type(&self, _: &[DataType]) -> Result { - // should be using return_type_from_args and not calling the default implementation + // should be using return_field_from_args and not calling the default implementation internal_err!("union_extract should return type from args") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if args.arg_types.len() != 2 { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() != 2 { return exec_err!( "union_extract expects 2 arguments, got {} instead", - args.arg_types.len() + args.arg_fields.len() ); } - let DataType::Union(fields, _) = &args.arg_types[0] else { + let DataType::Union(fields, _) = &args.arg_fields[0].data_type() else { return exec_err!( "union_extract first argument must be a union, got {} instead", - args.arg_types[0] + args.arg_fields[0].data_type() ); }; let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else { return exec_err!( "union_extract second argument must be a non-null string literal, got {} instead", - args.arg_types[1] + args.arg_fields[1].data_type() ); }; let field = find_field(fields, field_name)?.1; - Ok(ReturnInfo::new_nullable(field.data_type().clone())) + Ok(Field::new(self.name(), field.data_type().clone(), true).into()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -189,47 +189,67 @@ mod tests { ], ); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + None, + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - None, - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })?; assert_scalar(result, ScalarValue::Utf8(None)); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((3, Box::new(ScalarValue::Int32(Some(42))))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - Some((3, Box::new(ScalarValue::Int32(Some(42))))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), 
})?; assert_scalar(result, ScalarValue::Utf8(None)); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((1, Box::new(ScalarValue::new_utf8("42")))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - Some((1, Box::new(ScalarValue::new_utf8("42")))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })?; assert_scalar(result, ScalarValue::new_utf8("42")); diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs new file mode 100644 index 000000000000..3a4d96de2bc0 --- /dev/null +++ b/datafusion/functions/src/core/union_tag.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
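For reference, in the new `union_tag` kernel below, union type IDs need not start at zero or be contiguous, which is why the dictionary's values vector is sized to the highest type ID plus one so that each type ID indexes its field name directly. A standalone illustration of that mapping, with made-up field names and IDs (not part of the patch):

```rust
use std::sync::Arc;

use arrow::array::{DictionaryArray, Int8Array, StringArray};
use arrow::datatypes::Int8Type;

fn demo_union_tags() -> DictionaryArray<Int8Type> {
    // Suppose the union declares field "a" with type id 5 and "b" with type id 9.
    // Size the values vector to highest id + 1; unused slots stay empty.
    let mut values = vec![""; 10];
    values[5] = "a";
    values[9] = "b";
    let values = Arc::new(StringArray::from(values));

    // The per-row union type_ids become the dictionary keys.
    let keys = Int8Array::from(vec![5i8, 9, 5]);

    // union_tag itself uses `DictionaryArray::new_unchecked` after establishing
    // that every key is a valid index; the checked constructor is enough here.
    DictionaryArray::try_new(keys, values).expect("keys index into values")
}
```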
+ +use arrow::array::{Array, AsArray, DictionaryArray, Int8Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; +use datafusion_doc::Documentation; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +#[user_doc( + doc_section(label = "Union Functions"), + description = "Returns the name of the currently selected field in the union", + syntax_example = "union_tag(union_expression)", + sql_example = r#"```sql +❯ select union_column, union_tag(union_column) from table_with_union; ++--------------+-------------------------+ +| union_column | union_tag(union_column) | ++--------------+-------------------------+ +| {a=1} | a | +| {b=3.0} | b | +| {a=4} | a | +| {b=} | b | +| {a=} | a | ++--------------+-------------------------+ +```"#, + standard_argument(name = "union", prefix = "Union") +)] +#[derive(Debug)] +pub struct UnionTagFunc { + signature: Signature, +} + +impl Default for UnionTagFunc { + fn default() -> Self { + Self::new() + } +} + +impl UnionTagFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for UnionTagFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "union_tag" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + )) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [union_] = take_function_args("union_tag", args.args)?; + + match union_ { + ColumnarValue::Array(array) + if matches!(array.data_type(), DataType::Union(_, _)) => + { + let union_array = array.as_union(); + + let keys = Int8Array::try_new(union_array.type_ids().clone(), None)?; + + let fields = match union_array.data_type() { + DataType::Union(fields, _) => fields, + _ => unreachable!(), + }; + + // Union fields type IDs only constraints are being unique and in the 0..128 range: + // They may not start at 0, be sequential, or even contiguous. + // Therefore, we allocate a values vector with a length equal to the highest type ID plus one, + // ensuring that each field's name can be placed at the index corresponding to its type ID. + let values_len = fields + .iter() + .map(|(type_id, _)| type_id + 1) + .max() + .unwrap_or_default() as usize; + + let mut values = vec![""; values_len]; + + for (type_id, field) in fields.iter() { + values[type_id as usize] = field.name().as_str() + } + + let values = Arc::new(StringArray::from(values)); + + // SAFETY: union type_ids are validated to not be smaller than zero. + // values len is the union biggest type id plus one. 
+ // keys is built from the union type_ids, which contains only valid type ids + // therefore, `keys[i] >= values.len() || keys[i] < 0` never occurs + let dict = unsafe { DictionaryArray::new_unchecked(keys, values) }; + + Ok(ColumnarValue::Array(Arc::new(dict))) + } + ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => match value { + Some((value_type_id, _)) => fields + .iter() + .find(|(type_id, _)| value_type_id == *type_id) + .map(|(_, field)| { + ColumnarValue::Scalar(ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(field.name().as_str().into()), + )) + }) + .ok_or_else(|| { + exec_datafusion_err!( + "union_tag: union scalar with unknow type_id {value_type_id}" + ) + }), + None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + args.return_field.data_type(), + )?)), + }, + v => exec_err!("union_tag only support unions, got {:?}", v.data_type()), + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[cfg(test)] +mod tests { + use super::UnionTagFunc; + use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; + + // when it becomes possible to construct union scalars in SQL, this should go to sqllogictests + #[test] + fn union_scalar() { + let fields = [(0, Arc::new(Field::new("a", DataType::UInt32, false)))] + .into_iter() + .collect(); + + let scalar = ScalarValue::Union( + Some((0, Box::new(ScalarValue::UInt32(Some(0))))), + fields, + UnionMode::Dense, + ); + + let return_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + + let result = UnionTagFunc::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(scalar)], + number_rows: 1, + return_field: Field::new("res", return_type, true).into(), + arg_fields: vec![], + }) + .unwrap(); + + assert_scalar( + result, + ScalarValue::Dictionary(Box::new(DataType::Int8), Box::new("a".into())), + ); + } + + #[test] + fn union_scalar_empty() { + let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense); + + let return_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + + let result = UnionTagFunc::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(scalar)], + number_rows: 1, + return_field: Field::new("res", return_type, true).into(), + arg_fields: vec![], + }) + .unwrap(); + + assert_scalar( + result, + ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Utf8(None)), + ), + ); + } + + fn assert_scalar(value: ColumnarValue, expected: ScalarValue) { + match value { + ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"), + ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected), + } + } +} diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 34038022f2dc..b3abe246b4b3 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -97,6 +97,7 @@ impl ScalarUDFImpl for VersionFunc { #[cfg(test)] mod test { use super::*; + use arrow::datatypes::Field; use datafusion_expr::ScalarUDF; #[tokio::test] @@ -105,8 +106,9 @@ mod test { let version = version_udf .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 0, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(); diff --git 
a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 9998e7d3758e..2bda1f262abe 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -108,6 +108,7 @@ impl ScalarUDFImpl for CurrentDateFunc { ); Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Date32(days), + None, ))) } diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index c416d0240b13..9b9d3997e9d7 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -96,6 +96,7 @@ impl ScalarUDFImpl for CurrentTimeFunc { let nano = now_ts.timestamp_nanos_opt().map(|ts| ts % 86400000000000); Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Time64Nanosecond(nano), + None, ))) } diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 5ffae46dde48..1c801dfead72 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -505,85 +505,88 @@ mod tests { use arrow::array::types::TimestampNanosecondType; use arrow::array::{Array, IntervalDayTimeArray, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; - use datafusion_common::ScalarValue; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use chrono::TimeDelta; + fn invoke_date_bin_with_args( + args: Vec, + number_rows: usize, + return_field: &FieldRef, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Arc::clone(return_field), + }; + DateBinFunc::new().invoke_with_args(args) + } + #[test] fn test_date_bin() { - let mut args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + let return_field = &Arc::new(Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + )); + + let mut args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); let timestamps = Arc::new((1..6).map(Some).collect::()); let batch_len = timestamps.len(); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Array(timestamps), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: batch_len, - 
return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Array(timestamps), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, batch_len, return_field); assert!(res.is_ok()); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // stride supports month-day-nano - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano { - months: 0, - days: 0, - nanoseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 1, + }, + ))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // @@ -591,33 +594,25 @@ mod tests { // // invalid number of arguments - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - )))], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, + )))]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expected two or three arguments" ); // stride: invalid type - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( 
res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects stride argument to be an INTERVAL but got Interval(YearMonth)" @@ -625,113 +620,83 @@ mod tests { // stride: invalid value - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 0, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 0, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; - let res = DateBinFunc::new().invoke_with_args(args); + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride must be non-zero" ); // stride: overflow of day-time interval - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime::MAX, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime::MAX, + ))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: overflow of month-day-nano interval - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: month intervals - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + 
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN stride does not support combination of month, day and nanosecond intervals" ); // origin: invalid type - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision but got Timestamp(Microsecond, None)" ); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // unsupported array type for stride @@ -745,16 +710,12 @@ mod tests { }) .collect::(), ); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(intervals), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Array(intervals), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the stride argument, not arrays" @@ -763,21 +724,15 @@ mod tests { // unsupported array type for origin let timestamps = Arc::new((1..6).map(Some).collect::()); let batch_len = timestamps.len(); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Array(timestamps), - 
], - number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Array(timestamps), + ]; + let res = invoke_date_bin_with_args(args, batch_len, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the origin argument, not arrays" @@ -893,22 +848,22 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), - ColumnarValue::Array(Arc::new(input)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(string_to_timestamp_nanos(origin).unwrap()), - tz_opt.clone(), - )), - ], - number_rows: batch_len, - return_type: &DataType::Timestamp( - TimeUnit::Nanosecond, + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), + ColumnarValue::Array(Arc::new(input)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos(origin).unwrap()), tz_opt.clone(), - ), - }; - let result = DateBinFunc::new().invoke_with_args(args).unwrap(); + )), + ]; + let return_field = &Arc::new(Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + )); + let result = + invoke_date_bin_with_args(args, batch_len, return_field).unwrap(); + if let ColumnarValue::Array(result) = result { assert_eq!( result.data_type(), diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index bfd06b39d206..021000dc100b 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -26,7 +26,7 @@ use arrow::datatypes::DataType::{ Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; use datafusion_common::types::{logical_date, NativeType}; use datafusion_common::{ @@ -42,7 +42,7 @@ use datafusion_common::{ Result, ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; @@ -142,10 +142,10 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let [field, _] = take_function_args(self.name(), args.scalar_arguments)?; field @@ -155,12 +155,13 @@ impl ScalarUDFImpl for DatePartFunc { .filter(|s| !s.is_empty()) .map(|part| { if is_epoch(part) { - ReturnInfo::new_nullable(DataType::Float64) + Field::new(self.name(), DataType::Float64, true) } else { - ReturnInfo::new_nullable(DataType::Int32) + Field::new(self.name(), DataType::Int32, true) } }) }) 
+ .map(Arc::new) .map_or_else( || exec_err!("{} requires non-empty constant string", self.name()), Ok, diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index ed3eb228bf03..8963ef77a53b 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -471,7 +471,7 @@ fn parse_tz(tz: &Option>) -> Result> { tz.as_ref() .map(|tz| { Tz::from_str(tz).map_err(|op| { - DataFusionError::Execution(format!("failed on timezone {tz}: {:?}", op)) + DataFusionError::Execution(format!("failed on timezone {tz}: {op:?}")) }) }) .transpose() @@ -487,7 +487,7 @@ mod tests { use arrow::array::types::TimestampNanosecondType; use arrow::array::{Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -726,13 +726,23 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", input.data_type().clone(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + return_field: Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ) + .into(), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { @@ -888,13 +898,23 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", input.data_type().clone(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("hour")), ColumnarValue::Array(Arc::new(input)), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + return_field: Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ) + .into(), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index ed8181452dbd..c1497040261c 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -18,20 +18,19 @@ use std::any::Any; use std::sync::Arc; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, - Volatility, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Converts an integer to RFC3339 timestamp format 
(`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp.", + description = "Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp.", syntax_example = "from_unixtime(expression[, timezone])", sql_example = r#"```sql > select from_unixtime(1599572549, 'America/New_York'); @@ -82,12 +81,12 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert!(matches!(args.scalar_arguments.len(), 1 | 2)); if args.scalar_arguments.len() == 1 { - Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) + Ok(Field::new(self.name(), Timestamp(Second, None), true).into()) } else { args.scalar_arguments[1] .and_then(|sv| { @@ -95,12 +94,14 @@ impl ScalarUDFImpl for FromUnixtimeFunc { .flatten() .filter(|s| !s.is_empty()) .map(|tz| { - ReturnInfo::new_nullable(Timestamp( - Second, - Some(Arc::from(tz.to_string())), - )) + Field::new( + self.name(), + Timestamp(Second, Some(Arc::from(tz.to_string()))), + true, + ) }) }) + .map(Arc::new) .map_or_else( || { exec_err!( @@ -114,7 +115,7 @@ impl ScalarUDFImpl for FromUnixtimeFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("call return_type_from_args instead") + internal_err!("call return_field_from_args instead") } fn invoke_with_args( @@ -161,8 +162,8 @@ impl ScalarUDFImpl for FromUnixtimeFunc { #[cfg(test)] mod test { use crate::datetime::from_unixtime::FromUnixtimeFunc; - use arrow::datatypes::DataType; use arrow::datatypes::TimeUnit::Second; + use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_common::ScalarValue::Int64; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -170,10 +171,12 @@ mod test { #[test] fn test_without_timezone() { + let arg_field = Arc::new(Field::new("a", DataType::Int64, true)); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(Int64(Some(1729900800)))], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Timestamp(Second, None), + return_field: Field::new("f", DataType::Timestamp(Second, None), true).into(), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); @@ -187,6 +190,10 @@ mod test { #[test] fn test_with_timezone() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true).into(), + Field::new("a", DataType::Utf8, true).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(Int64(Some(1729900800))), @@ -194,11 +201,14 @@ mod test { "America/New_York".to_string(), ))), ], + arg_fields, number_rows: 2, - return_type: &DataType::Timestamp( - Second, - Some(Arc::from("America/New_York")), - ), + return_field: Field::new( + "f", + DataType::Timestamp(Second, Some(Arc::from("America/New_York"))), + true, + ) + .into(), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 929fa601f107..daa9bd83971f 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -223,25 
+223,39 @@ fn make_date_inner( mod tests { use crate::datetime::make_date::MakeDateFunc; use arrow::array::{Array, Date32Array, Int32Array, Int64Array, UInt32Array}; - use arrow::datatypes::DataType; - use datafusion_common::ScalarValue; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; + fn invoke_make_date_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Date32, true).into(), + }; + MakeDateFunc::new().invoke_with_args(args) + } + #[test] fn test_make_date() { - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))), ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -249,18 +263,15 @@ mod tests { panic!("Expected a scalar value") } - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -268,18 +279,15 @@ mod tests { panic!("Expected a scalar value") } - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))), ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -291,18 +299,15 @@ mod tests { let months = Arc::new((1..5).map(Some).collect::()); let days = Arc::new((11..15).map(Some).collect::()); let batch_len = years.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Array(years), ColumnarValue::Array(months), ColumnarValue::Array(days), ], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + batch_len, + ) + .unwrap(); if let ColumnarValue::Array(array) = res { assert_eq!(array.len(), 4); @@ -321,60 +326,52 @@ mod tests { 
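The `invoke_make_date_with_args` helper above shows the shape `ScalarFunctionArgs` takes after this change: per-argument `arg_fields` plus an owned `return_field` in place of the old `return_type` reference. The same pattern generalizes to any scalar UDF; a minimal sketch, assuming the `ScalarFunctionArgs` layout introduced in this PR (the `invoke_udf_with_args` name and the placeholder field names "a"/"f" are illustrative, not part of the patch):

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::Result;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

// Build `ScalarFunctionArgs` with the new `arg_fields` / `return_field`
// members and invoke a scalar UDF, mirroring the test helpers in this diff.
fn invoke_udf_with_args(
    udf: &dyn ScalarUDFImpl,
    args: Vec<ColumnarValue>,
    number_rows: usize,
    return_type: DataType,
) -> Result<ColumnarValue> {
    // One `FieldRef` per argument; "a" and "f" are placeholder field names.
    let arg_fields: Vec<FieldRef> = args
        .iter()
        .map(|arg| Arc::new(Field::new("a", arg.data_type(), true)))
        .collect();
    udf.invoke_with_args(ScalarFunctionArgs {
        args,
        arg_fields,
        number_rows,
        return_field: Arc::new(Field::new("f", return_type, true)),
    })
}
```
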
// // invalid number of arguments - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + let res = invoke_make_date_with_args( + vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: make_date function requires 3 arguments, got 1" ); // invalid type - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Casting from Interval(YearMonth) to Int32 not supported" ); // overflow of month - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 18446744073709551615 to type Int32" ); // overflow of day - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 4294967295 to type Int32" diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index b26dc52cee4d..ffb3aed5a960 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. 
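The datetime functions in this file set (`date_part` and `from_unixtime` above, `now` below) all follow the same migration: `return_type_from_args`/`ReturnInfo` is replaced by `return_field_from_args`, which returns a full `FieldRef`, so nullability and any timezone travel together with the data type. A minimal sketch of the new hook, assuming the `ReturnFieldArgs` API introduced here (the function and field names are illustrative):

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit};
use datafusion_common::Result;
use datafusion_expr::ReturnFieldArgs;

// Mirrors the `ScalarUDFImpl::return_field_from_args` signature: instead of a
// `ReturnInfo`, the UDF now hands back a complete `FieldRef` (type + nullability).
// Like `from_unixtime` above, the timezone is read from the second argument
// when it is available as a constant.
fn example_return_field(args: ReturnFieldArgs) -> Result<FieldRef> {
    let tz = args.scalar_arguments.get(1).and_then(|sv| {
        sv.and_then(|sv| sv.try_as_str().flatten().map(|s| s.to_string()))
    });
    Ok(Arc::new(Field::new(
        "example", // placeholder field name
        DataType::Timestamp(TimeUnit::Second, tz.map(Arc::from)),
        true, // nullable result
    )))
}
```
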
-use arrow::datatypes::DataType; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; +use arrow::datatypes::{DataType, Field, FieldRef}; use std::any::Any; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, - Signature, Volatility, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -77,15 +77,17 @@ impl ScalarUDFImpl for NowFunc { &self.signature } - fn return_type_from_args(&self, _args: ReturnTypeArgs) -> Result { - Ok(ReturnInfo::new_non_nullable(Timestamp( - Nanosecond, - Some("+00:00".into()), - ))) + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new( + self.name(), + Timestamp(Nanosecond, Some("+00:00".into())), + false, + ) + .into()) } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } fn invoke_with_args( @@ -106,6 +108,7 @@ impl ScalarUDFImpl for NowFunc { .timestamp_nanos_opt(); Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::TimestampNanosecond(now_ts, Some("+00:00".into())), + None, ))) } diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 8b2e5ad87471..3e89242aba26 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -303,7 +303,7 @@ mod tests { TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::{NaiveDateTime, Timelike}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -385,10 +385,15 @@ mod tests { ]; for (value, format, expected) in scalar_data { + let arg_fields = vec![ + Field::new("a", value.data_type(), false).into(), + Field::new("a", format.data_type(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)], + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -465,13 +470,18 @@ mod tests { for (value, format, expected) in scalar_array_data { let batch_len = format.len(); + let arg_fields = vec![ + Field::new("a", value.data_type(), false).into(), + Field::new("a", format.data_type().to_owned(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -596,13 +606,18 @@ mod tests { for (value, format, expected) in array_scalar_data { let batch_len = value.len(); + let arg_fields = vec![ + Field::new("a", value.data_type().clone(), false).into(), + Field::new("a", format.data_type(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value as ArrayRef), ColumnarValue::Scalar(format), ], + arg_fields, number_rows: 
batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -618,13 +633,18 @@ mod tests { for (value, format, expected) in array_array_data { let batch_len = value.len(); + let arg_fields = vec![ + Field::new("a", value.data_type().clone(), false).into(), + Field::new("a", format.data_type().clone(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -643,10 +663,12 @@ mod tests { // // invalid number of arguments + let arg_field = Field::new("a", DataType::Int32, true).into(); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( @@ -655,13 +677,18 @@ mod tests { ); // invalid type + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("a", DataType::Timestamp(TimeUnit::Nanosecond, None), true).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 91740b2c31c1..c9fd17dbef11 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -163,14 +163,32 @@ impl ScalarUDFImpl for ToDateFunc { #[cfg(test)] mod tests { use arrow::array::{Array, Date32Array, GenericStringArray, StringViewArray}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; - use datafusion_common::ScalarValue; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; use super::ToDateFunc; + fn invoke_to_date_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Date32, true).into(), + }; + ToDateFunc::new().invoke_with_args(args) + } + #[test] fn test_to_date_without_format() { struct TestCase { @@ -208,12 +226,8 @@ mod tests { } fn test_scalar(sv: ScalarValue, tc: &TestCase) { - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(sv)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(sv)], 1); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -234,12 +248,10 @@ mod tests { { let date_array = 
A::from(vec![tc.date_str]); let batch_len = date_array.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new(date_array))], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = invoke_to_date_with_args( + vec![ColumnarValue::Array(Arc::new(date_array))], + batch_len, + ); match to_date_result { Ok(ColumnarValue::Array(a)) => { @@ -328,15 +340,13 @@ mod tests { fn test_scalar(sv: ScalarValue, tc: &TestCase) { let format_scalar = ScalarValue::Utf8(Some(tc.format_str.to_string())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Scalar(sv), ColumnarValue::Scalar(format_scalar), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -358,15 +368,13 @@ mod tests { let format_array = A::from(vec![tc.format_str]); let batch_len = date_array.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Array(Arc::new(date_array)), ColumnarValue::Array(Arc::new(format_array)), ], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + batch_len, + ); match to_date_result { Ok(ColumnarValue::Array(a)) => { @@ -398,16 +406,14 @@ mod tests { let format1_scalar = ScalarValue::Utf8(Some("%Y-%m-%d".into())); let format2_scalar = ScalarValue::Utf8(Some("%Y/%m/%d".into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Scalar(formatted_date_scalar), ColumnarValue::Scalar(format1_scalar), ColumnarValue::Scalar(format2_scalar), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -431,19 +437,17 @@ mod tests { for date_str in test_cases { let formatted_date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(formatted_date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = invoke_to_date_with_args( + vec![ColumnarValue::Scalar(formatted_date_scalar)], + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { let expected = Date32Type::parse_formatted("2020-09-08", "%Y-%m-%d"); assert_eq!(date_val, expected, "to_date created wrong value"); } - _ => panic!("Conversion of {} failed", date_str), + _ => panic!("Conversion of {date_str} failed"), } } } @@ -453,23 +457,18 @@ mod tests { let date_str = "20241231"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(date_scalar)], 1); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { let expected = 
Date32Type::parse_formatted("2024-12-31", "%Y-%m-%d"); assert_eq!( date_val, expected, - "to_date created wrong value for {}", - date_str + "to_date created wrong value for {date_str}" ); } - _ => panic!("Conversion of {} failed", date_str), + _ => panic!("Conversion of {date_str} failed"), } } @@ -478,18 +477,11 @@ mod tests { let date_str = "202412311"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(date_scalar)], 1); if let Ok(ColumnarValue::Scalar(ScalarValue::Date32(_))) = to_date_result { - panic!( - "Conversion of {} succeeded, but should have failed, ", - date_str - ); + panic!("Conversion of {date_str} succeeded, but should have failed. "); } } } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 8dbef90cdc3f..b9ebe537d459 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -407,9 +407,9 @@ impl ScalarUDFImpl for ToLocalTimeFunc { mod tests { use std::sync::Arc; - use arrow::array::{types::TimestampNanosecondType, TimestampNanosecondArray}; + use arrow::array::{types::TimestampNanosecondType, Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::NaiveDateTime; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -538,11 +538,13 @@ mod tests { } fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { + let arg_field = Field::new("a", input.data_type(), true).into(); let res = ToLocalTimeFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(input)], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &expected.data_type(), + return_field: Field::new("f", expected.data_type(), true).into(), }) .unwrap(); match res { @@ -602,10 +604,17 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::(); let batch_size = input.len(); + let arg_field = Field::new("a", input.data_type().clone(), true).into(); let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::new(input))], + arg_fields: vec![arg_field], number_rows: batch_size, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ) + .into(), }; let result = ToLocalTimeFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 52c86733f332..8b26a1c25950 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -639,7 +639,7 @@ mod tests { TimestampNanosecondArray, TimestampSecondArray, }; use arrow::array::{ArrayRef, Int64Array, StringBuilder}; - use arrow::datatypes::TimeUnit; + use arrow::datatypes::{Field, TimeUnit}; use chrono::Utc; use datafusion_common::{assert_contains, DataFusionError, ScalarValue}; use datafusion_expr::ScalarFunctionImplementation; @@ 
-1012,11 +1012,13 @@ mod tests { for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); + let arg_field = Field::new("arg", array.data_type().clone(), true).into(); assert!(matches!(rt, Timestamp(_, Some(_)))); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], + arg_fields: vec![arg_field], number_rows: 4, - return_type: &rt, + return_field: Field::new("f", rt, true).into(), }; let res = udf .invoke_with_args(args) @@ -1060,10 +1062,12 @@ mod tests { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); assert!(matches!(rt, Timestamp(_, None))); + let arg_field = Field::new("arg", array.data_type().clone(), true).into(); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], + arg_fields: vec![arg_field], number_rows: 5, - return_type: &rt, + return_field: Field::new("f", rt, true).into(), }; let res = udf .invoke_with_args(args) diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index 51e8c6968866..9a7b49105743 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -310,7 +310,7 @@ fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result { let out_len = input.len() / 2; let buf = &mut buf[..out_len]; hex::decode_to_slice(input, buf).map_err(|e| { - DataFusionError::Internal(format!("Failed to decode from hex: {}", e)) + DataFusionError::Internal(format!("Failed to decode from hex: {e}")) })?; Ok(out_len) } @@ -319,7 +319,7 @@ fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result { general_purpose::STANDARD_NO_PAD .decode_slice(input, buf) .map_err(|e| { - DataFusionError::Internal(format!("Failed to decode from base64: {}", e)) + DataFusionError::Internal(format!("Failed to decode from base64: {e}")) }) } @@ -419,15 +419,13 @@ impl Encoding { .decode(value) .map_err(|e| { DataFusionError::Internal(format!( - "Failed to decode value using base64: {}", - e + "Failed to decode value using base64: {e}" )) })? } Self::Hex => hex::decode(value).map_err(|e| { DataFusionError::Internal(format!( - "Failed to decode value using hex: {}", - e + "Failed to decode value using hex: {e}" )) })?, }; @@ -447,15 +445,13 @@ impl Encoding { .decode(value) .map_err(|e| { DataFusionError::Internal(format!( - "Failed to decode value using base64: {}", - e + "Failed to decode value using base64: {e}" )) })? } Self::Hex => hex::decode(value).map_err(|e| { DataFusionError::Internal(format!( - "Failed to decode value using hex: {}", - e + "Failed to decode value using hex: {e}" )) })?, }; diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index b65c4c543242..51cd5df8060d 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -209,8 +209,7 @@ mod tests { for alias in func.aliases() { assert!( names.insert(alias.to_string().to_lowercase()), - "duplicate function name: {}", - alias + "duplicate function name: {alias}" ); } } diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index d2849c3abba0..30ebf8654ea0 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -40,6 +40,7 @@ /// Exported functions accept: /// - `Vec` argument (single argument followed by a comma) /// - Variable number of `Expr` arguments (zero or more arguments, must be without commas) +#[macro_export] macro_rules! 
export_functions { ($(($FUNC:ident, $DOC:expr, $($arg:tt)*)),*) => { $( @@ -69,6 +70,7 @@ macro_rules! export_functions { /// named `$NAME` which returns that singleton. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. +#[macro_export] macro_rules! make_udf_function { ($UDF:ty, $NAME:ident) => { #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index fd135f4c5ec0..23e267a323b9 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -210,7 +210,9 @@ impl ScalarUDFImpl for LogFunc { }; match number { - Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => { + Expr::Literal(value, _) + if value == ScalarValue::new_one(&number_datatype)? => + { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero( &info.get_data_type(&base)?, )?))) @@ -256,6 +258,7 @@ mod tests { use arrow::array::{Float32Array, Float64Array, Int64Array}; use arrow::compute::SortOptions; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; use datafusion_expr::execution_props::ExecutionProps; @@ -264,6 +267,10 @@ mod tests { #[test] #[should_panic] fn test_log_invalid_base_type() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Int64, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -271,20 +278,23 @@ mod tests { ]))), // num ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ], + arg_fields, number_rows: 4, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let _ = LogFunc::new().invoke_with_args(args); } #[test] fn test_log_invalid_value() { + let arg_field = Field::new("a", DataType::Int64, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num ], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new().invoke_with_args(args); @@ -293,12 +303,14 @@ mod tests { #[test] fn test_log_scalar_f32_unary() { + let arg_field = Field::new("a", DataType::Float32, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num ], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -320,12 +332,14 @@ mod tests { #[test] fn test_log_scalar_f64_unary() { + let arg_field = Field::new("a", DataType::Float64, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num ], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -347,13 +361,18 @@ mod tests { #[test] fn test_log_scalar_f32() { + let arg_fields = vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("a", DataType::Float32, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ 
ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num ], + arg_fields, number_rows: 1, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -375,13 +394,18 @@ mod tests { #[test] fn test_log_scalar_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Float64, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num ], + arg_fields, number_rows: 1, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -403,14 +427,16 @@ mod tests { #[test] fn test_log_f64_unary() { + let arg_field = Field::new("a", DataType::Float64, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], + arg_fields: vec![arg_field], number_rows: 4, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -435,14 +461,16 @@ mod tests { #[test] fn test_log_f32_unary() { + let arg_field = Field::new("a", DataType::Float32, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float32Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], + arg_fields: vec![arg_field], number_rows: 4, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -467,6 +495,10 @@ mod tests { #[test] fn test_log_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Float64, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -476,8 +508,9 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], + arg_fields, number_rows: 4, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -502,6 +535,10 @@ mod tests { #[test] fn test_log_f32() { + let arg_fields = vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("a", DataType::Float32, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float32Array::from(vec![ @@ -511,8 +548,9 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], + arg_fields, number_rows: 4, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 028ec2fef793..465844704f59 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -156,12 +156,15 @@ impl ScalarUDFImpl for PowerFunc { let exponent_type = info.get_data_type(&exponent)?; match exponent { - Expr::Literal(value) if value == ScalarValue::new_zero(&exponent_type)? => { + Expr::Literal(value, _) + if value == ScalarValue::new_zero(&exponent_type)? 
=> + { Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::new_one(&info.get_data_type(&base)?)?, + None, ))) } - Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => { + Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => { Ok(ExprSimplifyResult::Simplified(base)) } Expr::ScalarFunction(ScalarFunction { func, mut args }) @@ -187,12 +190,17 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { use arrow::array::Float64Array; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int64_array}; use super::*; #[test] fn test_power_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("a", DataType::Float64, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -202,8 +210,9 @@ mod tests { 3.0, 2.0, 4.0, 4.0, ]))), // exponent ], + arg_fields, number_rows: 4, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = PowerFunc::new() .invoke_with_args(args) @@ -227,13 +236,18 @@ mod tests { #[test] fn test_power_i64() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true).into(), + Field::new("a", DataType::Int64, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent ], + arg_fields, number_rows: 4, - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), }; let result = PowerFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs index 607f9fb09f2a..92b6ed1895ed 100644 --- a/datafusion/functions/src/math/random.rs +++ b/datafusion/functions/src/math/random.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use arrow::array::Float64Array; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use datafusion_common::{internal_err, Result}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -74,9 +74,9 @@ impl ScalarUDFImpl for RandomFunc { if !args.args.is_empty() { return internal_err!("{} function does not accept arguments", self.name()); } - let mut rng = thread_rng(); + let mut rng = rng(); let mut values = vec![0.0; args.number_rows]; - // Equivalent to set each element with rng.gen_range(0.0..1.0), but more efficient + // Equivalent to set each element with rng.random_range(0.0..1.0), but more efficient rng.fill(&mut values[..]); let array = Float64Array::from(values); diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index ba5422afa768..ec6ef5a78c6a 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -138,7 +138,7 @@ mod test { use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -157,10 +157,12 @@ mod test { f32::INFINITY, f32::NEG_INFINITY, ])); + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, 
number_rows: array.len(), - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = SignumFunc::new() .invoke_with_args(args) @@ -201,10 +203,12 @@ mod test { f64::INFINITY, f64::NEG_INFINITY, ])); + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, number_rows: array.len(), - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = SignumFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8cb1a4ff3d60..52ab3d489ee3 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -577,15 +577,12 @@ fn compile_regex(regex: &str, flags: Option<&str>) -> Result "regexp_count() does not support global flag".to_string(), )); } - format!("(?{}){}", flags, regex) + format!("(?{flags}){regex}") } }; Regex::new(&pattern).map_err(|_| { - ArrowError::ComputeError(format!( - "Regular expression did not compile: {}", - pattern - )) + ArrowError::ComputeError(format!("Regular expression did not compile: {pattern}")) }) } @@ -619,6 +616,7 @@ fn count_matches( mod tests { use super::*; use arrow::array::{GenericStringArray, StringViewArray}; + use arrow::datatypes::Field; use datafusion_expr::ScalarFunctionArgs; #[test] @@ -647,6 +645,26 @@ mod tests { test_case_regexp_count_cache_check::>(); } + fn regexp_count_with_scalar_values(args: &[ScalarValue]) -> Result { + let args_values = args + .iter() + .map(|sv| ColumnarValue::Scalar(sv.clone())) + .collect(); + + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true).into()) + .collect::>(); + + RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: args_values, + arg_fields, + number_rows: args.len(), + return_field: Field::new("f", Int64, true).into(), + }) + } + fn test_case_sensitive_regexp_count_scalar() { let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; let regex = "abc"; @@ -657,11 +675,7 @@ mod tests { let v_sv = ScalarValue::Utf8(Some(v.to_string())); let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -672,11 +686,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -687,11 +697,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = 
RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -713,15 +719,7 @@ mod tests { let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let start_sv = ScalarValue::Int64(Some(start)); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -732,15 +730,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -751,15 +741,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -783,16 +765,13 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -804,16 +783,13 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + 
flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -825,16 +801,13 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -907,16 +880,12 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -928,16 +897,12 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -949,16 +914,12 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 006492a0e07a..63c987906b0f 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -16,7 +16,7 @@ // under the License. 
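A recurring mechanical change in the simplify code above (`log`, `power`) and in the string functions below (`concat`, `concat_ws`, `starts_with`, `contains`) is that `Expr::Literal` now carries a second, optional metadata field: pattern matches ignore it with `_`, and constructions pass `None`. A minimal sketch of the pattern, assuming the two-field variant; the simplification itself is only for illustration:

```rust
use datafusion_common::ScalarValue;
use datafusion_expr::Expr;

// Rewrite a literal zero exponent to a literal one: match the two-field
// literal variant and build a new literal with no metadata attached.
fn simplify_zero_exponent(exponent: &Expr) -> Option<Expr> {
    match exponent {
        // The second field is the optional literal metadata; ignore it here.
        Expr::Literal(ScalarValue::Int64(Some(0)), _) => {
            Some(Expr::Literal(ScalarValue::Int64(Some(1)), None))
        }
        _ => None,
    }
}
```
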
use crate::utils::make_scalar_function; -use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; +use arrow::array::{ArrayRef, AsArray, Int32Array, StringArrayType}; use arrow::datatypes::DataType; use arrow::error::ArrowError; use datafusion_common::types::logical_string; @@ -103,19 +103,22 @@ impl ScalarUDFImpl for AsciiFunc { fn calculate_ascii<'a, V>(array: V) -> Result where - V: ArrayAccessor, + V: StringArrayType<'a, Item = &'a str>, { - let iter = ArrayIter::new(array); - let result = iter - .map(|string| { - string.map(|s| { - let mut chars = s.chars(); - chars.next().map_or(0, |v| v as i32) - }) + let values: Vec<_> = (0..array.len()) + .map(|i| { + if array.is_null(i) { + 0 + } else { + let s = array.value(i); + s.chars().next().map_or(0, |c| c as i32) + } }) - .collect::(); + .collect(); - Ok(Arc::new(result) as ArrayRef) + let array = Int32Array::new(values.into(), array.nulls().cloned()); + + Ok(Arc::new(array)) } /// Returns the numeric code of the first character of the argument. @@ -182,6 +185,7 @@ mod tests { test_ascii!(Some(String::from("x")), Ok(Some(120))); test_ascii!(Some(String::from("a")), Ok(Some(97))); test_ascii!(Some(String::from("")), Ok(Some(0))); + test_ascii!(Some(String::from("🚀")), Ok(Some(128640))); test_ascii!(None, Ok(None)); Ok(()) } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index c47d08d579e4..64a527eac198 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -295,7 +295,7 @@ pub fn simplify_concat(args: Vec) -> Result { let data_types: Vec<_> = args .iter() .filter_map(|expr| match expr { - Expr::Literal(l) => Some(l.data_type()), + Expr::Literal(l, _) => Some(l.data_type()), _ => None, }) .collect(); @@ -304,25 +304,25 @@ pub fn simplify_concat(args: Vec) -> Result { for arg in args.clone() { match arg { - Expr::Literal(ScalarValue::Utf8(None)) => {} - Expr::Literal(ScalarValue::LargeUtf8(None)) => { + Expr::Literal(ScalarValue::Utf8(None), _) => {} + Expr::Literal(ScalarValue::LargeUtf8(None), _) => { } - Expr::Literal(ScalarValue::Utf8View(None)) => { } + Expr::Literal(ScalarValue::Utf8View(None), _) => { } // filter out `null` args // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. - Expr::Literal(ScalarValue::Utf8(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => { + Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(ScalarValue::Utf8View(Some(v))) => { + Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(x) => { + Expr::Literal(x, _) => { return internal_err!( "The scalar {x} should be casted to string type during the type coercion." 
) @@ -376,6 +376,7 @@ mod tests { use crate::utils::test::test_function; use arrow::array::{Array, LargeStringArray, StringViewArray}; use arrow::array::{ArrayRef, StringArray}; + use arrow::datatypes::Field; use DataType::*; #[test] @@ -468,11 +469,22 @@ mod tests { None, Some("b"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8View, true), + Field::new("a", Utf8View, true), + ] + .into_iter() + .map(Arc::new) + .collect::>(); let args = ScalarFunctionArgs { args: vec![c0, c1, c2, c3, c4], + arg_fields, number_rows: 3, - return_type: &Utf8, + return_field: Field::new("f", Utf8, true).into(), }; let result = ConcatFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index c2bad206db15..1f45f8501e1f 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -312,6 +312,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { // when the delimiter is an empty string, @@ -336,8 +337,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None), _) => {} + Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), _) => { match contiguous_scalar { None => contiguous_scalar = Some(v.to_string()), Some(mut pre) => { @@ -347,7 +348,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result return internal_err!("The scalar {s} should be casted to string type during the type coercion."), + Expr::Literal(s, _) => return internal_err!("The scalar {s} should be casted to string type during the type coercion."), // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` and reset it to None. // Then pushing this arg to the `new_args`. @@ -374,10 +375,11 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Utf8(None), + None, ))), } } - Expr::Literal(d) => internal_err!( + Expr::Literal(d, _) => internal_err!( "The scalar {d} should be casted to string type during the type coercion." 
), _ => { @@ -394,7 +396,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v, _) => v.is_null(), _ => false, } } @@ -403,10 +405,10 @@ fn is_null(expr: &Expr) -> bool { mod tests { use std::sync::Arc; + use crate::string::concat_ws::ConcatWsFunc; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::datatypes::DataType::Utf8; - - use crate::string::concat_ws::ConcatWsFunc; + use arrow::datatypes::Field; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -481,10 +483,16 @@ mod tests { Some("z"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], + arg_fields, number_rows: 3, - return_type: &Utf8, + return_field: Field::new("f", Utf8, true).into(), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; @@ -511,10 +519,16 @@ mod tests { Some("z"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], + arg_fields, number_rows: 3, - return_type: &Utf8, + return_field: Field::new("f", Utf8, true).into(), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 05a3edf61c5a..215f8f7a25b9 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -150,10 +150,11 @@ fn contains(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { use super::ContainsFunc; + use crate::expr_fn::contains; use arrow::array::{BooleanArray, StringArray}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; #[test] @@ -164,11 +165,16 @@ mod test { Some("yyy?()"), ]))); let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("a", DataType::Utf8, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![array, scalar], + arg_fields, number_rows: 2, - return_type: &DataType::Boolean, + return_field: Field::new("f", DataType::Boolean, true).into(), }; let actual = udf.invoke_with_args(args).unwrap(); @@ -181,4 +187,19 @@ mod test { *expect.into_array(2).unwrap() ); } + + #[test] + fn test_contains_api() { + let expr = contains( + Expr::Literal( + ScalarValue::Utf8(Some("the quick brown fox".to_string())), + None, + ), + Expr::Literal(ScalarValue::Utf8(Some("row".to_string())), None), + ); + assert_eq!( + expr.to_string(), + "contains(Utf8(\"the quick brown fox\"), Utf8(\"row\"))" + ); + } } diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 226275b13999..536c29a7cb25 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -98,15 +98,19 @@ impl ScalarUDFImpl for LowerFunc { mod tests { use super::*; use arrow::array::{Array, ArrayRef, StringArray}; + use arrow::datatypes::DataType::Utf8; + use arrow::datatypes::Field; use 
std::sync::Arc; fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = LowerFunc::new(); + let arg_fields = vec![Field::new("a", input.data_type().clone(), true).into()]; let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - return_type: &DataType::Utf8, + arg_fields, + return_field: Field::new("f", Utf8, true).into(), }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 4c59e2644456..b4a026db9f89 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -140,7 +140,8 @@ pub mod expr_fn { "returns uuid v4 as a string value", ), ( contains, - "Return true if search_string is found within string.", + "Return true if `search_string` is found within `string`.", + string search_string )); #[doc = "Removes all characters, spaces by default, from both sides of a string"] diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 71df83352f96..ecab1af132e0 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -130,7 +130,7 @@ impl ScalarUDFImpl for StartsWithFunc { args: Vec, info: &dyn SimplifyInfo, ) -> Result { - if let Expr::Literal(scalar_value) = &args[1] { + if let Expr::Literal(scalar_value, _) = &args[1] { // Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping // Example: starts_with(col, 'ja%') -> col LIKE 'ja\%%' // 1. 'ja%' (input pattern) @@ -141,8 +141,8 @@ impl ScalarUDFImpl for StartsWithFunc { | ScalarValue::LargeUtf8(Some(pattern)) | ScalarValue::Utf8View(Some(pattern)) => { let escaped_pattern = pattern.replace("%", "\\%"); - let like_pattern = format!("{}%", escaped_pattern); - Expr::Literal(ScalarValue::Utf8(Some(like_pattern))) + let like_pattern = format!("{escaped_pattern}%"); + Expr::Literal(ScalarValue::Utf8(Some(like_pattern)), None) } _ => return Ok(ExprSimplifyResult::Original(args)), }; diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 2fec7305d183..882fb45eda4a 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -97,15 +97,19 @@ impl ScalarUDFImpl for UpperFunc { mod tests { use super::*; use arrow::array::{Array, ArrayRef, StringArray}; + use arrow::datatypes::DataType::Utf8; + use arrow::datatypes::Field; use std::sync::Arc; fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = UpperFunc::new(); + let arg_field = Field::new("a", input.data_type().clone(), true).into(); let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - return_type: &DataType::Utf8, + arg_fields: vec![arg_field], + return_field: Field::new("f", Utf8, true).into(), }; let result = match func.invoke_with_args(args)? 
{ diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index d1f43d548066..29415a9b2080 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -86,7 +86,7 @@ impl ScalarUDFImpl for UuidFunc { } // Generate random u128 values - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut randoms = vec![0u128; args.number_rows]; rng.fill(&mut randoms[..]); diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index c2db253dc741..4ee5995f0a6b 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -17,7 +17,7 @@ use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveBuilder, + Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray, StringArrayType, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; @@ -131,46 +131,64 @@ where T::Native: OffsetSizeTrait, V: StringArrayType<'a>, { - let mut builder = PrimitiveBuilder::::with_capacity(array.len()); - // String characters are variable length encoded in UTF-8, counting the // number of chars requires expensive decoding, however checking if the // string is ASCII only is relatively cheap. // If strings are ASCII only, count bytes instead. let is_array_ascii_only = array.is_ascii(); - if array.null_count() == 0 { + let array = if array.null_count() == 0 { if is_array_ascii_only { - for i in 0..array.len() { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.len())); - } + let values: Vec<_> = (0..array.len()) + .map(|i| { + let value = array.value(i); + T::Native::usize_as(value.len()) + }) + .collect(); + PrimitiveArray::::new(values.into(), None) } else { - for i in 0..array.len() { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.chars().count())); - } + let values: Vec<_> = (0..array.len()) + .map(|i| { + let value = array.value(i); + if value.is_ascii() { + T::Native::usize_as(value.len()) + } else { + T::Native::usize_as(value.chars().count()) + } + }) + .collect(); + PrimitiveArray::::new(values.into(), None) } } else if is_array_ascii_only { - for i in 0..array.len() { - if array.is_null(i) { - builder.append_null(); - } else { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.len())); - } - } + let values: Vec<_> = (0..array.len()) + .map(|i| { + if array.is_null(i) { + T::default_value() + } else { + let value = array.value(i); + T::Native::usize_as(value.len()) + } + }) + .collect(); + PrimitiveArray::::new(values.into(), array.nulls().cloned()) } else { - for i in 0..array.len() { - if array.is_null(i) { - builder.append_null(); - } else { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.chars().count())); - } - } - } + let values: Vec<_> = (0..array.len()) + .map(|i| { + if array.is_null(i) { + T::default_value() + } else { + let value = array.value(i); + if value.is_ascii() { + T::Native::usize_as(value.len()) + } else { + T::Native::usize_as(value.chars().count()) + } + } + }) + .collect(); + PrimitiveArray::::new(values.into(), array.nulls().cloned()) + }; - Ok(Arc::new(builder.finish()) as ArrayRef) + Ok(Arc::new(array)) } #[cfg(test)] diff --git a/datafusion/functions/src/unicode/find_in_set.rs 
b/datafusion/functions/src/unicode/find_in_set.rs index c4a9f067e9f4..8b00c7be1ccf 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -348,7 +348,7 @@ mod tests { use crate::unicode::find_in_set::FindInSetFunc; use crate::utils::test::test_function; use arrow::array::{Array, Int32Array, StringArray}; - use arrow::datatypes::DataType::Int32; + use arrow::datatypes::{DataType::Int32, Field}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; @@ -471,10 +471,18 @@ mod tests { }) .unwrap_or(1); let return_type = fis.return_type(&type_array)?; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, a)| { + Field::new(format!("arg_{idx}"), a.data_type(), true).into() + }) + .collect::>(); let result = fis.invoke_with_args(ScalarFunctionArgs { args, + arg_fields, number_rows: cardinality, - return_type: &return_type, + return_field: Field::new("f", return_type, true).into(), }); assert!(result.is_ok()); diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index b3bc73a29585..1c81b46ec78e 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -22,7 +22,9 @@ use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType, }; -use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type, +}; use datafusion_common::types::logical_string; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ @@ -88,16 +90,23 @@ impl ScalarUDFImpl for StrposFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be used instead") + internal_err!("return_field_from_args should be used instead") } - fn return_type_from_args( + fn return_field_from_args( &self, - args: datafusion_expr::ReturnTypeArgs, - ) -> Result { - utf8_to_int_type(&args.arg_types[0], "strpos/instr/position").map(|data_type| { - datafusion_expr::ReturnInfo::new(data_type, args.nullables.iter().any(|x| *x)) - }) + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map( + |data_type| { + Field::new( + self.name(), + data_type, + args.arg_fields.iter().any(|x| x.is_nullable()), + ) + .into() + }, + ) } fn invoke_with_args( @@ -228,7 +237,7 @@ mod tests { use arrow::array::{Array, Int32Array, Int64Array}; use arrow::datatypes::DataType::{Int32, Int64}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -321,15 +330,15 @@ mod tests { fn nullable_return_type() { fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool { let strpos = StrposFunc::new(); - let args = datafusion_expr::ReturnTypeArgs { - arg_types: &[DataType::Utf8, DataType::Utf8], - nullables: &[string_array_nullable, substring_nullable], + let args = datafusion_expr::ReturnFieldArgs { + arg_fields: &[ + Field::new("f1", DataType::Utf8, string_array_nullable).into(), + Field::new("f2", DataType::Utf8, substring_nullable).into(), + ], scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], }; - let (_, nullable) = 
strpos.return_type_from_args(args).unwrap().into_parts(); - - nullable + strpos.return_field_from_args(args).unwrap().is_nullable() } assert!(!get_nullable(false, false)); diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 47f3121ba2ce..583ff48bff39 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -75,7 +75,7 @@ get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); /// Creates a scalar function implementation for the given function. /// * `inner` - the function to be executed /// * `hints` - hints to be used when expanding scalars to arrays -pub(super) fn make_scalar_function( +pub fn make_scalar_function( inner: F, hints: Vec, ) -> impl Fn(&[ColumnarValue]) -> Result @@ -133,7 +133,7 @@ pub mod test { let expected: Result> = $EXPECTED; let func = $FUNC; - let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); let cardinality = $ARGS .iter() .fold(Option::::None, |acc, arg| match arg { @@ -153,19 +153,28 @@ pub mod test { ColumnarValue::Array(a) => a.null_count() > 0, }).collect::>(); - let return_info = func.return_type_from_args(datafusion_expr::ReturnTypeArgs { - arg_types: &type_array, + let field_array = data_array.into_iter().zip(nullables).enumerate() + .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable)) + .map(std::sync::Arc::new) + .collect::>(); + + let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { + arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, - nullables: &nullables }); + let arg_fields = $ARGS.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .collect::>(); match expected { Ok(expected) => { - assert_eq!(return_info.is_ok(), true); - let (return_type, _nullable) = return_info.unwrap().into_parts(); - assert_eq!(return_type, $EXPECTED_DATA_TYPE); + assert_eq!(return_field.is_ok(), true); + let return_field = return_field.unwrap(); + let return_type = return_field.data_type(); + assert_eq!(return_type, &$EXPECTED_DATA_TYPE); - let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}); + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); @@ -179,17 +188,17 @@ pub mod test { }; } Err(expected_error) => { - if return_info.is_err() { - match return_info { + if return_field.is_err() { + match return_field { Ok(_) => assert!(false, "expected error"), Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } } } else { - let (return_type, _nullable) = return_info.unwrap().into_parts(); + let return_field = return_field.unwrap(); // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}) { + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field}) { Ok(_) => assert!(false, "expected 
error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); diff --git a/datafusion/macros/src/user_doc.rs b/datafusion/macros/src/user_doc.rs index c6510c156423..31cf9bb1b750 100644 --- a/datafusion/macros/src/user_doc.rs +++ b/datafusion/macros/src/user_doc.rs @@ -206,7 +206,7 @@ pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { }; let doc_section_description = doc_section_desc .map(|desc| quote! { Some(#desc)}) - .unwrap_or(quote! { None }); + .unwrap_or_else(|| quote! { None }); let sql_example = sql_example.map(|ex| { quote! { diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 61d101aab3f8..60358d20e2a1 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -55,6 +55,7 @@ regex-syntax = "0.8.0" [dev-dependencies] async-trait = { workspace = true } +criterion = { workspace = true } ctor = { workspace = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window = { workspace = true } @@ -62,3 +63,7 @@ datafusion-functions-window-common = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } insta = { workspace = true } + +[[bench]] +name = "projection_unnecessary" +harness = false diff --git a/datafusion/optimizer/benches/projection_unnecessary.rs b/datafusion/optimizer/benches/projection_unnecessary.rs new file mode 100644 index 000000000000..c9f248fe49b5 --- /dev/null +++ b/datafusion/optimizer/benches/projection_unnecessary.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::ToDFSchema; +use datafusion_common::{Column, TableReference}; +use datafusion_expr::{logical_plan::LogicalPlan, projection_schema, Expr}; +use datafusion_optimizer::optimize_projections::is_projection_unnecessary; +use std::sync::Arc; + +fn is_projection_unnecessary_old( + input: &LogicalPlan, + proj_exprs: &[Expr], +) -> datafusion_common::Result<bool> { + // First check if all expressions are trivial (cheaper operation than `projection_schema`) + if !proj_exprs + .iter() + .all(|expr| matches!(expr, Expr::Column(_) | Expr::Literal(_, _))) + { + return Ok(false); + } + let proj_schema = projection_schema(input, proj_exprs)?; + Ok(&proj_schema == input.schema()) +} + +fn create_plan_with_many_exprs(num_exprs: usize) -> (LogicalPlan, Vec<Expr>) { + // Create schema with many fields + let fields = (0..num_exprs) + .map(|i| Field::new(format!("col{i}"), DataType::Int32, false)) + .collect::<Vec<_>>(); + let schema = Schema::new(fields); + + // Create table scan + let table_scan = LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation { + produce_one_row: true, + schema: Arc::new(schema.clone().to_dfschema().unwrap()), + }); + + // Create projection expressions (just column references) + let exprs = (0..num_exprs) + .map(|i| Expr::Column(Column::new(None::<TableReference>, format!("col{i}")))) + .collect(); + + (table_scan, exprs) +} + +fn benchmark_is_projection_unnecessary(c: &mut Criterion) { + let (plan, exprs) = create_plan_with_many_exprs(1000); + + let mut group = c.benchmark_group("projection_unnecessary_comparison"); + + group.bench_function("is_projection_unnecessary_new", |b| { + b.iter(|| black_box(is_projection_unnecessary(&plan, &exprs).unwrap())) + }); + + group.bench_function("is_projection_unnecessary_old", |b| { + b.iter(|| black_box(is_projection_unnecessary_old(&plan, &exprs).unwrap())) + }); + + group.finish(); +} + +criterion_group!(benches, benchmark_is_projection_unnecessary); +criterion_main!(benches); diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index f8a818563609..fa7ff1b8b19d 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -189,19 +189,19 @@ fn grouping_function_on_id( // Postgres allows grouping function for group by without grouping sets, the result is then // always 0 if !is_grouping_set { - return Ok(Expr::Literal(ScalarValue::from(0i32))); + return Ok(Expr::Literal(ScalarValue::from(0i32), None)); } let group_by_expr_count = group_by_expr.len(); let literal = |value: usize| { if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8)) + Expr::Literal(ScalarValue::from(value as u8), None) } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16)) + Expr::Literal(ScalarValue::from(value as u16), None) } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32)) + Expr::Literal(ScalarValue::from(value as u32), None) } else { - Expr::Literal(ScalarValue::from(value as u64)) + Expr::Literal(ScalarValue::from(value as u64), None) } }; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index d47f7ea6ce68..b5a3e9a2d585 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ 
b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -41,7 +41,7 @@ use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion}; use datafusion_expr::type_coercion::functions::{ - data_types_with_aggregate_udf, data_types_with_scalar_udf, + data_types_with_scalar_udf, fields_with_aggregate_udf, }; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, @@ -539,17 +539,18 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { ), ))) } - Expr::WindowFunction(WindowFunction { - fun, - params: - expr::WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + expr::WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = *window_fun; let window_frame = coerce_window_frame(window_frame, self.schema, &order_by)?; @@ -565,7 +566,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }; Ok(Transformed::yes( - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::from(WindowFunction::new(fun, args)) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) @@ -578,7 +579,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { Expr::Alias(_) | Expr::Column(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::SimilarTo(_) | Expr::IsNotNull(_) | Expr::IsNull(_) @@ -718,6 +719,9 @@ fn coerce_frame_bound( fn extract_window_frame_target_type(col_type: &DataType) -> Result { if col_type.is_numeric() || is_utf8_or_utf8view_or_large_utf8(col_type) + || matches!(col_type, DataType::List(_)) + || matches!(col_type, DataType::LargeList(_)) + || matches!(col_type, DataType::FixedSizeList(_, _)) || matches!(col_type, DataType::Null) || matches!(col_type, DataType::Boolean) { @@ -808,12 +812,15 @@ fn coerce_arguments_for_signature_with_aggregate_udf( return Ok(expressions); } - let current_types = expressions + let current_fields = expressions .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - let new_types = data_types_with_aggregate_udf(¤t_types, func)?; + let new_types = fields_with_aggregate_udf(¤t_fields, func)? + .into_iter() + .map(|f| f.data_type().clone()) + .collect::>(); expressions .into_iter() @@ -1055,12 +1062,13 @@ mod test { use arrow::datatypes::DataType::Utf8; use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit}; + use insta::assert_snapshot; use crate::analyzer::type_coercion::{ coerce_case_expression, TypeCoercion, TypeCoercionRewriter, }; use crate::analyzer::Analyzer; - use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq}; + use crate::assert_analyzed_plan_with_config_eq_snapshot; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; @@ -1096,13 +1104,80 @@ mod test { })) } + macro_rules! assert_analyzed_plan_eq { + ( + $plan: expr, + @ $expected: literal $(,)? + ) => {{ + let options = ConfigOptions::default(); + let rule = Arc::new(TypeCoercion::new()); + assert_analyzed_plan_with_config_eq_snapshot!( + options, + rule, + $plan, + @ $expected, + ) + }}; + } + + macro_rules! 
coerce_on_output_if_viewtype { + ( + $is_viewtype: expr, + $plan: expr, + @ $expected: literal $(,)? + ) => {{ + let mut options = ConfigOptions::default(); + // coerce on output + if $is_viewtype {options.optimizer.expand_views_at_output = true;} + let rule = Arc::new(TypeCoercion::new()); + + assert_analyzed_plan_with_config_eq_snapshot!( + options, + rule, + $plan, + @ $expected, + ) + }}; + } + + fn assert_type_coercion_error( + plan: LogicalPlan, + expected_substr: &str, + ) -> Result<()> { + let options = ConfigOptions::default(); + let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]); + + match analyzer.execute_and_check(plan, &options, |_, _| {}) { + Ok(succeeded_plan) => { + panic!( + "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}" + ); + } + Err(e) => { + let msg = e.to_string(); + assert!( + msg.contains(expected_substr), + "Error did not contain expected substring.\n expected to find: `{expected_substr}`\n actual error: `{msg}`" + ); + } + } + + Ok(()) + } + #[test] fn simple_case() -> Result<()> { let expr = col("a").lt(lit(2_u32)); let empty = empty_with_type(DataType::Float64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a < CAST(UInt32(2) AS Float64) + EmptyRelation + " + ) } #[test] @@ -1137,28 +1212,15 @@ mod test { Arc::new(analyzed_union), )?); - let expected = "Projection: a\n Union\n Projection: CAST(datafusion.test.foo.a AS Int64) AS a\n EmptyRelation\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), top_level_plan, expected) - } - - fn coerce_on_output_if_viewtype(plan: LogicalPlan, expected: &str) -> Result<()> { - let mut options = ConfigOptions::default(); - options.optimizer.expand_views_at_output = true; - - assert_analyzed_plan_with_config_eq( - options, - Arc::new(TypeCoercion::new()), - plan.clone(), - expected, - ) - } - - fn do_not_coerce_on_output(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_analyzed_plan_with_config_eq( - ConfigOptions::default(), - Arc::new(TypeCoercion::new()), - plan.clone(), - expected, + assert_analyzed_plan_eq!( + top_level_plan, + @r" + Projection: a + Union + Projection: CAST(datafusion.test.foo.a AS Int64) AS a + EmptyRelation + EmptyRelation + " ) } @@ -1172,12 +1234,26 @@ mod test { vec![expr.clone()], Arc::clone(&empty), )?); + // Plan A: no coerce - let if_not_coerced = "Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + EmptyRelation + " + )?; + // Plan A: coerce requested: Utf8View => LargeUtf8 - let if_coerced = "Projection: CAST(a AS LargeUtf8)\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + EmptyRelation + " + )?; // Plan B // scenario: outermost bool projection @@ -1187,12 +1263,33 @@ mod test { Arc::clone(&empty), )?); // Plan B: no coerce - let if_not_coerced = - "Projection: a < CAST(Utf8(\"foo\") AS Utf8View)\n EmptyRelation"; - do_not_coerce_on_output(bool_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + bool_plan.clone(), + @r#" + Projection: a < CAST(Utf8("foo") AS Utf8View) + EmptyRelation + "# + )?; + + 
coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + EmptyRelation + " + )?; + // Plan B: coerce requested: no coercion applied - let if_coerced = if_not_coerced; - coerce_on_output_if_viewtype(bool_plan, if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + EmptyRelation + " + )?; // Plan C // scenario: with a non-projection root logical plan node @@ -1202,13 +1299,29 @@ mod test { input: Arc::new(plan), fetch: None, }); + // Plan C: no coerce - let if_not_coerced = - "Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(sort_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + sort_plan.clone(), + @r" + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; + // Plan C: coerce requested: Utf8View => LargeUtf8 - let if_coerced = "Projection: CAST(a AS LargeUtf8)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(sort_plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + sort_plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; // Plan D // scenario: two layers of projections with view types @@ -1217,11 +1330,27 @@ mod test { Arc::new(sort_plan), )?); // Plan D: no coerce - let if_not_coerced = "Projection: a\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; // Plan B: coerce requested: Utf8View => LargeUtf8 only on outermost - let if_coerced = "Projection: CAST(a AS LargeUtf8)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; Ok(()) } @@ -1236,12 +1365,26 @@ mod test { vec![expr.clone()], Arc::clone(&empty), )?); + // Plan A: no coerce - let if_not_coerced = "Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + EmptyRelation + " + )?; + // Plan A: coerce requested: BinaryView => LargeBinary - let if_coerced = "Projection: CAST(a AS LargeBinary)\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeBinary) + EmptyRelation + " + )?; // Plan B // scenario: outermost bool projection @@ -1250,13 +1393,26 @@ mod test { vec![bool_expr], Arc::clone(&empty), )?); + // Plan B: no coerce - let if_not_coerced = - "Projection: a < CAST(Binary(\"8,1,8,1\") AS BinaryView)\n EmptyRelation"; - do_not_coerce_on_output(bool_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + bool_plan.clone(), + @r#" + Projection: a < CAST(Binary("8,1,8,1") AS BinaryView) + EmptyRelation + "# + )?; + // Plan B: coerce requested: no coercion applied - let if_coerced = if_not_coerced; - coerce_on_output_if_viewtype(bool_plan, if_coerced)?; + coerce_on_output_if_viewtype!( + true, + bool_plan.clone(), + @r#" + Projection: a < CAST(Binary("8,1,8,1") AS BinaryView) + EmptyRelation + "# + )?; // Plan C // scenario: with a non-projection root logical plan node 
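[Editorial aside, not part of the patch] The rewrites in this test module replace hard-coded expected plan strings with insta inline snapshots: the local `assert_analyzed_plan_eq!` and `coerce_on_output_if_viewtype!` macros run the TypeCoercion analyzer and compare the rendered plan against the text after `@`, which `cargo insta review` can rewrite in place. A self-contained sketch of the underlying pattern, assuming only the `insta` crate; the value below is a stand-in for the formatted plan:

    /// Stand-in for `format!("{optimized_plan}")` in the real tests.
    fn rendered_plan() -> String {
        "Projection: a IS TRUE".to_string()
    }

    #[test]
    fn snapshot_style_assertion() {
        // On mismatch, insta reports a diff and `cargo insta review`
        // offers to update the inline snapshot after `@`.
        insta::assert_snapshot!(rendered_plan(), @"Projection: a IS TRUE");
    }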
@@ -1266,13 +1422,28 @@ mod test { input: Arc::new(plan), fetch: None, }); + // Plan C: no coerce - let if_not_coerced = - "Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(sort_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + sort_plan.clone(), + @r" + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; // Plan C: coerce requested: BinaryView => LargeBinary - let if_coerced = "Projection: CAST(a AS LargeBinary)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(sort_plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + sort_plan.clone(), + @r" + Projection: CAST(a AS LargeBinary) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; // Plan D // scenario: two layers of projections with view types @@ -1280,12 +1451,30 @@ mod test { vec![col("a")], Arc::new(sort_plan), )?); + // Plan D: no coerce - let if_not_coerced = "Projection: a\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; + // Plan B: coerce requested: BinaryView => LargeBinary only on outermost - let if_coerced = "Projection: CAST(a AS LargeBinary)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeBinary) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; Ok(()) } @@ -1299,9 +1488,14 @@ mod test { vec![expr.clone().or(expr)], empty, )?); - let expected = "Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64) + EmptyRelation + " + ) } #[derive(Debug, Clone)] @@ -1340,9 +1534,14 @@ mod test { }) .call(vec![lit(123_i32)]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); - let expected = - "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: TestScalarUDF(CAST(Int32(123) AS Float32)) + EmptyRelation + " + ) } #[test] @@ -1372,9 +1571,14 @@ mod test { vec![scalar_function_expr], empty, )?); - let expected = - "Projection: TestScalarUDF(CAST(Int64(10) AS Float32))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: TestScalarUDF(CAST(Int64(10) AS Float32)) + EmptyRelation + " + ) } #[test] @@ -1397,8 +1601,14 @@ mod test { None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); - let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: MY_AVG(CAST(Int64(10) AS Float64)) + EmptyRelation + " + ) } #[test] @@ -1413,8 +1623,8 @@ mod test { return_type, accumulator, vec![ - Field::new("count", DataType::UInt64, true), - Field::new("avg", DataType::Float64, true), + Field::new("count", DataType::UInt64, true).into(), + 
Field::new("avg", DataType::Float64, true).into(), ], )); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( @@ -1445,8 +1655,14 @@ mod test { None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: avg(Float64(12))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: avg(Float64(12)) + EmptyRelation + " + )?; let empty = empty_with_type(DataType::Int32); let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( @@ -1458,9 +1674,14 @@ mod test { None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: avg(CAST(a AS Float64))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: avg(CAST(a AS Float64)) + EmptyRelation + " + ) } #[test] @@ -1489,10 +1710,14 @@ mod test { + lit(ScalarValue::new_interval_dt(123, 456)); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"IntervalDayTime { days: 123, milliseconds: 456 }\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }") + EmptyRelation + "# + ) } #[test] @@ -1501,8 +1726,12 @@ mod test { let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) + EmptyRelation + ")?; // a in (1,4,8), a is decimal let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); @@ -1514,8 +1743,12 @@ mod test { )?), })); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + assert_analyzed_plan_eq!( + plan, + @r" + Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) + EmptyRelation + ") } #[test] @@ -1528,10 +1761,14 @@ mod test { ); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); - let expected = - "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) AND CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\")\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r#" + Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") + EmptyRelation + "# + ) } #[test] @@ -1544,11 +1781,15 
@@ mod test { ); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); + // TODO: we should cast col(a). - let expected = - "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AND CAST(Utf8(\"2002-12-08\") AS Date32)\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + assert_analyzed_plan_eq!( + plan, + @r#" + Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32) + EmptyRelation + "# + ) } #[test] @@ -1556,10 +1797,14 @@ mod test { let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64)); let empty = empty(); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); - let expected = - "Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2) + EmptyRelation + " + ) } #[test] @@ -1569,37 +1814,60 @@ mod test { let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); - let expected = "Projection: a IS TRUE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS TRUE + EmptyRelation + " + )?; let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, ""); - let err = ret.unwrap_err().to_string(); - assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}"); + assert_type_coercion_error( + plan, + "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean" + )?; // is not true let expr = col("a").is_not_true(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS NOT TRUE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS NOT TRUE + EmptyRelation + " + )?; // is false let expr = col("a").is_false(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS FALSE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS FALSE + EmptyRelation + " + )?; // is not false let expr = col("a").is_not_false(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS NOT FALSE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS NOT FALSE + EmptyRelation + " + ) } #[test] @@ -1610,27 +1878,38 @@ mod test { let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); - let expected = "Projection: a LIKE 
Utf8(\"abc\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: a LIKE Utf8("abc") + EmptyRelation + "# + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); - let expected = "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a LIKE CAST(NULL AS Utf8) + EmptyRelation + " + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); - assert!(err.is_err()); - assert!(err.unwrap_err().to_string().contains( - "There isn't a common type to coerce Int64 and Utf8 in LIKE expression" - )); + assert_type_coercion_error( + plan, + "There isn't a common type to coerce Int64 and Utf8 in LIKE expression", + )?; // ilike let expr = Box::new(col("a")); @@ -1638,27 +1917,39 @@ mod test { let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); - let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: a ILIKE Utf8("abc") + EmptyRelation + "# + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); - let expected = "Projection: a ILIKE CAST(NULL AS Utf8)\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a ILIKE CAST(NULL AS Utf8) + EmptyRelation + " + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); - assert!(err.is_err()); - assert!(err.unwrap_err().to_string().contains( - "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression" - )); + assert_type_coercion_error( + plan, + "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression", + )?; + Ok(()) } @@ -1669,23 +1960,34 @@ mod test { let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); - let expected = "Projection: a IS UNKNOWN\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS UNKNOWN + EmptyRelation + " + )?; let empty = empty_with_type(Utf8); let plan = 
LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); - let err = ret.unwrap_err().to_string(); - assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}"); + assert_type_coercion_error( + plan, + "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean" + )?; // is not unknown let expr = col("a").is_not_unknown(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS NOT UNKNOWN\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS NOT UNKNOWN + EmptyRelation + " + ) } #[test] @@ -1694,21 +1996,19 @@ mod test { let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)]; // concat-type signature - { - let expr = ScalarUDF::new_from_impl(TestScalarUDF { - signature: Signature::variadic(vec![Utf8], Volatility::Immutable), - }) - .call(args.to_vec()); - let plan = LogicalPlan::Projection(Projection::try_new( - vec![expr], - Arc::clone(&empty), - )?); - let expected = - "Projection: TestScalarUDF(a, Utf8(\"b\"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - } - - Ok(()) + let expr = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + }) + .call(args.to_vec()); + let plan = + LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?); + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8)) + EmptyRelation + "# + ) } #[test] @@ -1758,10 +2058,14 @@ mod test { .eq(cast(lit("1998-03-18"), DataType::Date32)); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8(\"1998-03-18\") AS Date32) AS Timestamp(Nanosecond, None))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(Nanosecond, None)) + EmptyRelation + "# + ) } fn cast_if_not_same_type( @@ -1882,12 +2186,9 @@ mod test { else_expr: Some(Box::new(col("string"))), }; let err = coerce_case_expression(case, &schema).unwrap_err(); - assert_eq!( + assert_snapshot!( err.strip_backtrace(), - "Error during planning: \ - Failed to coerce case (Interval(MonthDayNano)) and \ - when ([Float32, Binary, Utf8]) to common types in \ - CASE WHEN expression" + @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when ([Float32, Binary, Utf8]) to common types in CASE WHEN expression" ); let case = Case { @@ -1900,12 +2201,9 @@ mod test { else_expr: Some(Box::new(col("timestamp"))), }; let err = coerce_case_expression(case, &schema).unwrap_err(); - assert_eq!( + assert_snapshot!( err.strip_backtrace(), - "Error during planning: \ - Failed to coerce then ([Date32, Float32, Binary]) and \ - else (Some(Timestamp(Nanosecond, None))) to common types \ - in 
CASE WHEN expression" + @"Error during planning: Failed to coerce then ([Date32, Float32, Binary]) and else (Some(Timestamp(Nanosecond, None))) to common types in CASE WHEN expression" ); Ok(()) @@ -2108,12 +2406,14 @@ mod test { let expr = col("a").eq(cast(col("a"), may_type_cutsom)); let empty = empty_with_type(map_type_entries); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a = CAST(CAST(a AS Map(Field { name: \"key_value\", data_type: Struct([Field { name: \"key\", data_type: Utf8, \ - nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"value\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), \ - nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: \"entries\", data_type: Struct([Field { name: \"key\", data_type: Utf8, nullable: false, \ - dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"value\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false))\n \ - EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: a = CAST(CAST(a AS Map(Field { name: "key_value", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) + EmptyRelation + "# + ) } #[test] @@ -2129,9 +2429,14 @@ mod test { )); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: IntervalYearMonth(\"12\") + CAST(Utf8(\"2000-01-01T00:00:00\") AS Timestamp(Nanosecond, None))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(Nanosecond, None)) + EmptyRelation + "# + ) } #[test] @@ -2149,10 +2454,14 @@ mod test { )); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) - CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) - CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) + EmptyRelation + "# + ) } #[test] @@ -2171,14 +2480,17 @@ mod test { )); let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?); // add cast for subquery - let expected = "\ - Filter: a IN ()\ - \n Subquery:\ - \n Projection: CAST(a AS Int64)\ - \n EmptyRelation\ - \n EmptyRelation"; - 
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r" + Filter: a IN () + Subquery: + Projection: CAST(a AS Int64) + EmptyRelation + EmptyRelation + " + ) } #[test] @@ -2196,14 +2508,17 @@ mod test { false, )); let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?); + // add cast for subquery - let expected = "\ - Filter: CAST(a AS Int64) IN ()\ - \n Subquery:\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Filter: CAST(a AS Int64) IN () + Subquery: + EmptyRelation + EmptyRelation + " + ) } #[test] @@ -2221,13 +2536,17 @@ mod test { false, )); let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?); + // add cast for subquery - let expected = "Filter: CAST(a AS Decimal128(13, 8)) IN ()\ - \n Subquery:\ - \n Projection: CAST(a AS Decimal128(13, 8))\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Filter: CAST(a AS Decimal128(13, 8)) IN () + Subquery: + Projection: CAST(a AS Decimal128(13, 8)) + EmptyRelation + EmptyRelation + " + ) } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 69b5fbb9f8c0..6a49e5d22087 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -803,23 +803,39 @@ mod test { use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::OptimizerContext; use crate::test::*; - use crate::Optimizer; use datafusion_expr::test::function_stub::{avg, sum}; - fn assert_optimized_plan_eq( - expected: &str, - plan: LogicalPlan, - config: Option<&dyn OptimizerConfig>, - ) { - let optimizer = - Optimizer::with_rules(vec![Arc::new(CommonSubexprEliminate::new())]); - let default_config = OptimizerContext::new(); - let config = config.unwrap_or(&default_config); - let optimized_plan = optimizer.optimize(plan, config, |_, _| ()).unwrap(); - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(expected, formatted_plan); + macro_rules! assert_optimized_plan_equal { + ( + $config:expr, + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(CommonSubexprEliminate::new())]; + assert_optimized_plan_eq_snapshot!( + $config, + rules, + $plan, + @ $expected, + ) + }}; + + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(CommonSubexprEliminate::new())]; + let optimizer_ctx = OptimizerContext::new(); + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -844,13 +860,14 @@ mod test { )? 
.build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\ - \n Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]] + Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -864,13 +881,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -886,7 +904,7 @@ mod test { Signature::exact(vec![DataType::UInt32], Volatility::Stable), return_type.clone(), Arc::clone(&accumulator), - vec![Field::new("value", DataType::UInt32, true)], + vec![Field::new("value", DataType::UInt32, true).into()], ))), vec![inner], false, @@ -917,11 +935,14 @@ mod test { )? .build()?; - let expected = "Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)\ - \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c) + Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]] + TableScan: test + " + )?; // test: trafo after aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -936,11 +957,14 @@ mod test { )? 
.build()?; - let expected = "Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)\ - \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a) + Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]] + TableScan: test + " + )?; // test: transformation before aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -953,11 +977,14 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // test: common between agg and group let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -970,11 +997,14 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // test: all mixed let plan = LogicalPlanBuilder::from(table_scan) @@ -991,14 +1021,15 @@ mod test { )? 
.build()?; - let expected = "Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)\ - \n Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a) + Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1018,14 +1049,15 @@ mod test { )? .build()?; - let expected = "Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)\ - \n Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]\ - \n Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a\ - \n TableScan: table.test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a) + Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]] + Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a + TableScan: table.test + " + ) } #[test] @@ -1039,13 +1071,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS first, __common_expr_1 AS second\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS first, __common_expr_1 AS second + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1056,13 +1089,14 @@ mod test { .project(vec![lit(1) + col("a"), col("a") + lit(1)])? 
.build()?; - let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1) + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1074,12 +1108,14 @@ mod test { .project(vec![lit(1) + col("a")])? .build()?; - let expected = "Projection: Int32(1) + test.a\ - \n Projection: Int32(1) + test.a, test.a\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: Int32(1) + test.a + Projection: Int32(1) + test.a, test.a + TableScan: test + " + ) } #[test] @@ -1193,14 +1229,15 @@ mod test { .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))? .build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 - Int32(10) > __common_expr_1\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 - Int32(10) > __common_expr_1 + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1226,7 +1263,7 @@ mod test { fn test_alias_collision() -> Result<()> { let table_scan = test_table_scan()?; - let config = &OptimizerContext::new(); + let config = OptimizerContext::new(); let common_expr_1 = config.alias_generator().next(CSE_PREFIX); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![ @@ -1241,14 +1278,18 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4\ - \n Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c\ - \n Projection: test.a + test.b AS __common_expr_1, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, Some(config)); - - let config = &OptimizerContext::new(); + assert_optimized_plan_equal!( + config, + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4 + Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c + Projection: test.a + test.b AS __common_expr_1, test.c + TableScan: test + " + )?; + + let config = OptimizerContext::new(); let _common_expr_1 = config.alias_generator().next(CSE_PREFIX); let common_expr_2 = config.alias_generator().next(CSE_PREFIX); let plan = LogicalPlanBuilder::from(table_scan) @@ -1264,12 +1305,16 @@ mod test { ])? 
.build()?; - let expected = "Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4\ - \n Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c\ - \n Projection: test.a + test.b AS __common_expr_2, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, Some(config)); + assert_optimized_plan_equal!( + config, + plan, + @ r" + Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4 + Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c + Projection: test.a + test.b AS __common_expr_2, test.c + TableScan: test + " + )?; Ok(()) } @@ -1308,13 +1353,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5\ - \n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5 + Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1331,13 +1377,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2 + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1360,13 +1407,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\ - \n Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4 + Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1382,14 +1430,15 @@ mod test { .project(vec![col("c1"), col("c2")])? 
.build()?; - let expected = "Projection: c1, c2\ - \n Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: c1, c2 + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1405,14 +1454,15 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ - \n Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c\ - \n Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 + Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c + Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1422,13 +1472,15 @@ mod test { let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 * __common_expr_1 = Int32(30) + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1438,13 +1490,15 @@ mod test { let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1454,13 +1508,15 @@ mod test { let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1470,13 +1526,15 @@ mod test { let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n 
Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1486,13 +1544,15 @@ mod test { let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1502,13 +1562,15 @@ mod test { let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a"))); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 AND __common_expr_1\ - \n Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 AND __common_expr_1 + Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1518,13 +1580,15 @@ mod test { let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a"))); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 AND __common_expr_1\ - \n Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 AND __common_expr_1 + Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1535,11 +1599,15 @@ mod test { .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 - __common_expr_1 = Int32(30)\ - \n Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 - __common_expr_1 = Int32(30) + Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1)) let table_scan = test_table_scan()?; @@ -1548,11 +1616,16 @@ mod test { + col("a")) .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)\ - \n Projection: (test.a + 
test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30) + Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // c2 / (c1 + c3) <=> c2 / (c3 + c1) let table_scan = test_table_scan()?; @@ -1560,11 +1633,15 @@ mod test { * (col("b") / (col("c") + col("a")))) .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ - \n Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 * __common_expr_1 = Int32(30) + Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; Ok(()) } @@ -1612,10 +1689,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a\ - \n Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a + Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // is_null(a == b) <=> is_null(b == a) let table_scan = test_table_scan()?; @@ -1624,10 +1705,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL\ - \n Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL + Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // a + b between 0 and 10 <=> b + a between 0 and 10 let table_scan = test_table_scan()?; @@ -1636,10 +1721,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? 
.build()?; - let expected = "Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)\ - \n Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10) + Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // c between a + b and 10 <=> c between b + a and 10 let table_scan = test_table_scan()?; @@ -1648,10 +1737,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)\ - \n Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10) + Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // function call with argument <=> function call with argument let udf = ScalarUDF::from(TestUdf::new()); @@ -1661,11 +1754,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)\ - \n Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a) + Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } /// returns a "random" function that is marked volatile (aka each invocation diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 418619c8399e..63236787743a 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -71,6 +71,9 @@ pub struct PullUpCorrelatedExpr { pub collected_count_expr_map: HashMap, /// pull up having expr, which must be evaluated after the Join pub pull_up_having_expr: Option, + /// whether we have converted a scalar aggregation into a group aggregation. When unnesting + /// lateral joins, we need to produce a left outer join in such cases. + pub pulled_up_scalar_agg: bool, } impl Default for PullUpCorrelatedExpr { @@ -91,6 +94,7 @@ impl PullUpCorrelatedExpr { need_handle_count_bug: false, collected_count_expr_map: HashMap::new(), pull_up_having_expr: None, + pulled_up_scalar_agg: false, } } @@ -313,6 +317,11 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { missing_exprs.push(un_matched_row); } } + if aggregate.group_expr.is_empty() { + // TODO: how do we handle the case where we have pulled multiple aggregations? For example, + // a group agg with a scalar agg as child. 
+ self.pulled_up_scalar_agg = true; + } let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone()) .aggregate(missing_exprs, aggregate.aggr_expr.to_vec())? .build()?; @@ -485,9 +494,12 @@ fn agg_exprs_evaluation_result_on_empty_batch( let new_expr = match expr { Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { if func.name() == "count" { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0)))) + Transformed::yes(Expr::Literal( + ScalarValue::Int64(Some(0)), + None, + )) } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null, None)) } } _ => Transformed::no(expr), @@ -578,10 +590,10 @@ fn filter_exprs_evaluation_result_on_empty_batch( let result_expr = simplifier.simplify(result_expr)?; match &result_expr { // evaluate to false or null on empty batch, no need to pull up - Expr::Literal(ScalarValue::Null) - | Expr::Literal(ScalarValue::Boolean(Some(false))) => None, + Expr::Literal(ScalarValue::Null, _) + | Expr::Literal(ScalarValue::Boolean(Some(false)), _) => None, // evaluate to true on empty batch, need to pull up the expr - Expr::Literal(ScalarValue::Boolean(Some(true))) => { + Expr::Literal(ScalarValue::Boolean(Some(true)), _) => { for (name, exprs) in input_expr_result_map_for_count_bug { expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); } @@ -596,7 +608,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( Box::new(result_expr.clone()), Box::new(input_expr.clone()), )], - else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), + else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null, None))), }); let expr_key = new_expr.schema_name().to_string(); expr_result_map_for_count_bug.insert(expr_key, new_expr); diff --git a/datafusion/optimizer/src/decorrelate_lateral_join.rs b/datafusion/optimizer/src/decorrelate_lateral_join.rs new file mode 100644 index 000000000000..7d2072ad1ce9 --- /dev/null +++ b/datafusion/optimizer/src/decorrelate_lateral_join.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`DecorrelateLateralJoin`] decorrelates logical plans produced by lateral joins. 
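A minimal sketch of the rewrite this rule performs, assuming hypothetical tables t0(v0) and t1(v0, v1); the plan shapes below are approximate and only illustrate the scalar-aggregate case that forces a Left join (see the pulled_up_scalar_agg flag above and the SQL comments further down).
// SELECT * FROM t0, LATERAL (SELECT sum(t1.v1) FROM t1 WHERE t0.v0 = t1.v0)
//
// Before (apply form, right side holds outer references to t0):
//   Inner Join
//     TableScan: t0
//     Subquery:
//       Aggregate: groupBy=[[]], aggr=[[sum(t1.v1)]]
//         Filter: outer_ref(t0.v0) = t1.v0
//           TableScan: t1
//
// After decorrelation (correlated predicate becomes the join filter; the
// scalar aggregate was pulled up, so the join becomes a Left join):
//   Left Join: Filter: t0.v0 = t1.v0
//     TableScan: t0
//     Aggregate: groupBy=[[t1.v0]], aggr=[[sum(t1.v1)]]
//       TableScan: t1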
+ +use std::collections::BTreeSet; + +use crate::decorrelate::PullUpCorrelatedExpr; +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_expr::{lit, Join}; + +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion_common::Result; +use datafusion_expr::logical_plan::JoinType; +use datafusion_expr::utils::conjunction; +use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; + +/// Optimizer rule for rewriting lateral joins to joins +#[derive(Default, Debug)] +pub struct DecorrelateLateralJoin {} + +impl DecorrelateLateralJoin { + #[allow(missing_docs)] + pub fn new() -> Self { + Self::default() + } +} + +impl OptimizerRule for DecorrelateLateralJoin { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + // Find cross joins with outer column references on the right side (i.e., the apply operator). + let LogicalPlan::Join(join) = plan else { + return Ok(Transformed::no(plan)); + }; + + rewrite_internal(join) + } + + fn name(&self) -> &str { + "decorrelate_lateral_join" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } +} + +// Build the decorrelated join based on the original lateral join query. For now, we only support cross/inner +// lateral joins. +fn rewrite_internal(join: Join) -> Result> { + if join.join_type != JoinType::Inner { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + + match join.right.apply_with_subqueries(|p| { + // TODO: support outer joins + if p.contains_outer_reference() { + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + })? { + TreeNodeRecursion::Stop => {} + TreeNodeRecursion::Continue => { + // The left side contains outer references, we need to decorrelate it. + return Ok(Transformed::new( + LogicalPlan::Join(join), + false, + TreeNodeRecursion::Jump, + )); + } + TreeNodeRecursion::Jump => { + unreachable!("") + } + } + + let LogicalPlan::Subquery(subquery) = join.right.as_ref() else { + return Ok(Transformed::no(LogicalPlan::Join(join))); + }; + + if join.join_type != JoinType::Inner { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + let subquery_plan = subquery.subquery.as_ref(); + let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true); + let rewritten_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?; + if !pull_up.can_pull_up { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + + let mut all_correlated_cols = BTreeSet::new(); + pull_up + .correlated_subquery_cols_map + .values() + .for_each(|cols| all_correlated_cols.extend(cols.clone())); + let join_filter_opt = conjunction(pull_up.join_filters); + let join_filter = match join_filter_opt { + Some(join_filter) => join_filter, + None => lit(true), + }; + // -- inner join but the right side always has one row, we need to rewrite it to a left join + // SELECT * FROM t0, LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0); + // -- inner join but the right side number of rows is related to the filter (join) condition, so keep inner join. + // SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0); + let new_plan = LogicalPlanBuilder::from(join.left) + .join_on( + rewritten_subquery, + if pull_up.pulled_up_scalar_agg { + JoinType::Left + } else { + JoinType::Inner + }, + Some(join_filter), + )? 
+ .build()?; + // TODO: handle count(*) bug + Ok(Transformed::new(new_plan, true, TreeNodeRecursion::Jump)) +} diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index c18c48251daa..a72657bf689d 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -427,17 +427,23 @@ mod tests { use super::*; use crate::test::*; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::builder::table_source; use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan}; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), - plan, - expected, - ); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(DecorrelatePredicateSubquery::new()); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } fn test_subquery_with_name(name: &str) -> Result> { @@ -461,17 +467,21 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq_1.c [c:UInt32]\ - \n TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ - \n Projection: sq_2.c [c:UInt32]\ - \n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq_1.c [c:UInt32] + TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [c:UInt32] + Projection: sq_2.c [c:UInt32] + TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for IN subquery with additional AND filter @@ -489,15 +499,18 @@ mod tests { .project(vec![col("test.b")])? 
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for nested IN subqueries @@ -515,18 +528,21 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [a:UInt32]\ - \n Projection: sq.a [a:UInt32]\ - \n LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq_nested.c [c:UInt32]\ - \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [a:UInt32] + Projection: sq.a [a:UInt32] + LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq_nested.c [c:UInt32] + TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test multiple correlated subqueries @@ -551,23 +567,21 @@ mod tests { .build()?; debug!("plan to optimize:\n{}", plan.display_indent()); - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), - plan, - expected, - ); - Ok(()) + assert_optimized_plan_equal!( + plan, + @r###" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: 
customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + "### + ) } /// Test recursive correlated subqueries @@ -601,23 +615,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ - \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64] + Projection: lineitem.l_orderkey [l_orderkey:Int64] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64] + " + ) } /// Test for correlated IN subquery filter with additional subquery filters @@ -639,20 +651,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery with no columns in schema @@ -673,19 +683,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for IN subquery with both columns in schema @@ -703,20 +711,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery not equal @@ -737,19 +743,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery less than @@ -770,19 +774,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery filter with subquery disjunction @@ -804,20 +806,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\ - \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64] + Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN without projection @@ -861,19 +860,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN expressions @@ -894,19 +891,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery multiple projected columns @@ -959,20 +954,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery filter @@ -990,19 +983,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Projection: sq.c, sq.a [c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single IN subquery filter @@ -1014,19 +1005,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single NOT IN subquery filter @@ -1038,19 +1027,17 @@ mod tests { .project(vec![col("test.b")])? 
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1061,19 +1048,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1087,19 +1072,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1116,19 +1099,17 @@ mod tests { .project(vec![col("test.b")])? 
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32]\ - \n Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32] + Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1150,20 +1131,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32, a:UInt32]\ - \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32] + Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32, a:UInt32] + Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1186,20 +1165,18 @@ mod tests { .project(vec![col("test.b")])? 
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ - \n Projection: sq.c * UInt32(2), sq.a, sq.b [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ - \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32] + Projection: sq.c * UInt32(2), sq.a, sq.b [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32] + Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1228,24 +1205,22 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq1.c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq1.c * UInt32(2), sq1.a [sq1.c * UInt32(2):UInt32, a:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [sq2.c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq2.c * UInt32(2), sq2.a [sq2.c * UInt32(2):UInt32, a:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq1.c * UInt32(2):UInt32, a:UInt32] + Projection: sq1.c * UInt32(2), sq1.a [sq1.c * UInt32(2):UInt32, a:UInt32] + TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [sq2.c * UInt32(2):UInt32, a:UInt32] + Projection: sq2.c * UInt32(2), sq2.a [sq2.c * UInt32(2):UInt32, a:UInt32] + TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1263,20 +1238,18 @@ mod tests { .build()?; // Subquery and outer query refer to the same table. 
- let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: test.c [c:UInt32]\ - \n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: test.c [c:UInt32] + Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for multiple exists subqueries in the same filter expression @@ -1297,17 +1270,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test recursive correlated subqueries @@ -1340,17 +1317,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ - \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64] + Projection: lineitem.l_orderkey [l_orderkey:Int64] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64] + " + ) } /// Test for correlated exists subquery filter with additional subquery filters @@ -1372,15 +1353,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -1398,14 +1382,17 @@ mod tests { .build()?; // Other rule will pushdown `customer.c_custkey = 1`, - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for exists subquery with both columns in schema @@ -1423,14 +1410,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery not equal @@ -1451,14 +1442,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery less than @@ -1479,14 +1473,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery filter with subquery disjunction @@ -1508,14 +1505,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\ - \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64] + Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists without projection @@ -1535,13 +1535,16 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists expressions @@ -1562,14 +1565,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery filter with additional filters @@ -1589,15 +1595,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery filter with disjunctions @@ -1615,16 +1624,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]\ - \n LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean] + LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated EXISTS subquery filter @@ -1642,14 +1654,17 @@ mod tests { .project(vec![col("test.c")])? .build()?; - let expected = "Projection: test.c [c:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c [c:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Projection: sq.c, sq.a [c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single exists subquery filter @@ -1661,13 +1676,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single NOT exists subquery filter @@ -1679,13 +1698,17 @@ mod tests { .project(vec![col("test.b")])? 
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftAnti Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftAnti Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1712,19 +1735,22 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Projection: sq1.c, sq1.a [c:UInt32, a:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [c:UInt32, a:UInt32]\ - \n Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Projection: sq1.c, sq1.a [c:UInt32, a:UInt32] + TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [c:UInt32, a:UInt32] + Projection: sq2.c, sq2.a [c:UInt32, a:UInt32] + TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1743,14 +1769,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, a:UInt32]\ - \n Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, a:UInt32] + Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1768,15 +1797,18 @@ mod tests { .build()?; // Subquery and outer query refer to the same table. 
- let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: test.c [c:UInt32]\ - \n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: test.c [c:UInt32] + Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1796,15 +1828,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Distinct: [c:UInt32, a:UInt32]\ - \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Distinct: [c:UInt32, a:UInt32] + Projection: sq.c, sq.a [c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1824,15 +1859,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.b + sq.c:UInt32, a:UInt32]\ - \n Distinct: [sq.b + sq.c:UInt32, a:UInt32]\ - \n Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.b + sq.c:UInt32, a:UInt32] + Distinct: [sq.b + sq.c:UInt32, a:UInt32] + Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1852,15 +1890,18 @@ mod tests { .project(vec![col("test.b")])? 
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, c:UInt32, a:UInt32]\ - \n Distinct: [UInt32(1):UInt32, c:UInt32, a:UInt32]\ - \n Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, c:UInt32, a:UInt32] + Distinct: [UInt32(1):UInt32, c:UInt32, a:UInt32] + Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1884,13 +1925,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [arr:Int32;N]\ - \n Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N]\ - \n TableScan: sq [arr:List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [arr:Int32;N] + Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N] + TableScan: sq [arr:List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N] + "# + ) } #[test] @@ -1915,14 +1960,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32;N]\ - \n Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N]\ - \n TableScan: sq [a:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [a:UInt32;N] + Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N] + TableScan: sq [a:List(Field { name: "item", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N] + "# + ) } #[test] @@ -1946,13 +1994,16 @@ mod tests { .project(vec![col("\"TEST_A\".\"B\"")])? 
.build()?; - let expected = "Projection: TEST_A.B [B:UInt32]\ - \n LeftSemi Join: Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32]\ - \n TableScan: TEST_A [A:UInt32, B:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, A:UInt32]\ - \n Projection: Int32(1), TEST_B.A [Int32(1):Int32, A:UInt32]\ - \n TableScan: TEST_B [A:UInt32, B:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: TEST_A.B [B:UInt32] + LeftSemi Join: Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32] + TableScan: TEST_A [A:UInt32, B:UInt32] + SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, A:UInt32] + Projection: Int32(1), TEST_B.A [Int32(1):Int32, A:UInt32] + TableScan: TEST_B [A:UInt32, B:UInt32] + " + ) } } diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index d35572e6d34a..d465faf0c5e8 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -440,22 +440,28 @@ mod tests { logical_plan::builder::LogicalPlanBuilder, Operator::{And, Or}, }; + use insta::assert_snapshot; + + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let starting_schema = Arc::clone($plan.schema()); + let rule = EliminateCrossJoin::new(); + let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap(); + let formatted_plan = optimized_plan.display_indent_schema(); + // Ensure the rule was actually applied + assert!(is_plan_transformed, "failed to optimize plan"); + // Verify the schema remains unchanged + assert_eq!(&starting_schema, optimized_plan.schema()); + assert_snapshot!( + formatted_plan, + @ $expected, + ); - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: Vec<&str>) { - let starting_schema = Arc::clone(plan.schema()); - let rule = EliminateCrossJoin::new(); - let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); - assert!(transformed_plan.transformed, "failed to optimize plan"); - let optimized_plan = transformed_plan.data; - let formatted = optimized_plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - assert_eq!(&starting_schema, optimized_plan.schema()) + Ok(()) + }}; } #[test] @@ -473,16 +479,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -501,16 +506,15 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -528,16 +532,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -559,15 +562,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -589,15 +592,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -615,15 +618,15 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -644,19 +647,18 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? .build()?; - let expected = vec![ - "Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]" - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -691,19 +693,18 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -765,22 +766,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ 
-840,22 +840,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -915,22 +914,21 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -994,22 +992,21 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1083,21 +1080,20 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1177,20 +1173,19 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1208,15 +1203,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1235,16 +1230,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1263,16 +1257,15 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1291,16 +1284,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1328,18 +1320,17 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 466950092095..a6651df938a7 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ 
b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -118,16 +118,26 @@ impl OptimizerRule for EliminateDuplicatedExpr { #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - crate::test::assert_optimized_plan_eq( - Arc::new(EliminateDuplicatedExpr::new()), - plan, - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = vec![Arc::new(EliminateDuplicatedExpr::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -137,10 +147,12 @@ mod tests { .sort_by(vec![col("a"), col("a"), col("b"), col("c")])? .limit(5, Some(10))? .build()?; - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST + TableScan: test + ") } #[test] @@ -156,9 +168,11 @@ mod tests { .sort(sort_exprs)? .limit(5, Some(10))? .build()?; - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST + TableScan: test + ") } } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 4ed2ac8ba1a4..e28771be548b 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -60,7 +60,7 @@ impl OptimizerRule for EliminateFilter { ) -> Result<Transformed<LogicalPlan>> { match plan { LogicalPlan::Filter(Filter { - predicate: Expr::Literal(ScalarValue::Boolean(v)), + predicate: Expr::Literal(ScalarValue::Boolean(v), _), input, .. }) => match v { @@ -81,17 +81,29 @@ impl OptimizerRule for EliminateFilter { mod tests { use std::sync::Arc; + use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ - col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan, - }; + use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder, Expr}; use crate::eliminate_filter::EliminateFilter; use crate::test::*; use datafusion_expr::test::function_stub::sum; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)?
+ ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateFilter::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -105,13 +117,12 @@ mod tests { .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation") } #[test] fn filter_null() -> Result<()> { - let filter_expr = Expr::Literal(ScalarValue::Boolean(None)); + let filter_expr = Expr::Literal(ScalarValue::Boolean(None), None); let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) @@ -120,8 +131,7 @@ mod tests { .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation") } #[test] @@ -139,11 +149,12 @@ mod tests { .build()?; // Left side is removed - let expected = "Union\ - \n EmptyRelation\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + EmptyRelation + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -156,9 +167,10 @@ mod tests { .filter(filter_expr)? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -176,12 +188,13 @@ mod tests { .build()?; // Filter is removed - let expected = "Union\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -202,8 +215,9 @@ mod tests { .build()?; // Filter is removed - let expected = "Projection: test.a\ - \n EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Projection: test.a + EmptyRelation + ") } } diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 7e252d6dcea0..9c47ce024f91 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -101,7 +101,7 @@ fn is_constant_expression(expr: &Expr) -> bool { Expr::BinaryExpr(e) => { is_constant_expression(&e.left) && is_constant_expression(&e.right) } - Expr::Literal(_) => true, + Expr::Literal(_, _) => true, Expr::ScalarFunction(e) => { matches!( e.func.signature().volatility, @@ -115,7 +115,9 @@ fn is_constant_expression(expr: &Expr) -> bool { #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -129,6 +131,22 @@ mod tests { use std::sync::Arc; + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateGroupByConstant::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; + } + #[derive(Debug)] struct ScalarUDFMock { signature: Signature, @@ -167,17 +185,11 @@ mod tests { .aggregate(vec![col("a"), lit(1u32)], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: test.a, UInt32(1), count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: test.a, UInt32(1), count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -187,17 +199,11 @@ mod tests { .aggregate(vec![lit("test"), lit(123u32)], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: Utf8(\"test\"), UInt32(123), count(test.c)\ - \n Aggregate: groupBy=[[]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r#" + Projection: Utf8("test"), UInt32(123), count(test.c) + Aggregate: groupBy=[[]], aggr=[[count(test.c)]] + TableScan: test + "#) } #[test] @@ -207,16 +213,10 @@ mod tests { .aggregate(vec![col("a"), col("b")], vec![count(col("c"))])? .build()?; - let expected = "\ - Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -226,16 +226,10 @@ mod tests { .aggregate(vec![lit(123u32)], Vec::::new())? .build()?; - let expected = "\ - Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[UInt32(123)]], aggr=[[]] + TableScan: test + ") } #[test] @@ -248,17 +242,11 @@ mod tests { )? .build()?; - let expected = "\ - Projection: UInt32(123) AS const, test.a, count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: UInt32(123) AS const, test.a, count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -273,17 +261,11 @@ mod tests { .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -298,15 +280,9 @@ mod tests { .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? 
.build()?; - let expected = "\ - Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } } diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 789235595dab..dfc3a220d0f9 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -54,7 +54,7 @@ impl OptimizerRule for EliminateJoin { match plan { LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok( + Some(Expr::Literal(ScalarValue::Boolean(Some(false)), _)) => Ok( Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: join.schema, @@ -74,15 +74,28 @@ impl OptimizerRule for EliminateJoin { #[cfg(test)] mod tests { + use crate::assert_optimized_plan_eq_snapshot; use crate::eliminate_join::EliminateJoin; - use crate::test::*; + use crate::OptimizerContext; use datafusion_common::Result; use datafusion_expr::JoinType::Inner; - use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan}; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateJoin::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -95,7 +108,6 @@ mod tests { )? .build()?; - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation") } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 5d3a1b223b7a..2007e0c82045 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -90,7 +90,6 @@ impl OptimizerRule for EliminateLimit { #[cfg(test)] mod tests { use super::*; - use crate::optimizer::Optimizer; use crate::test::*; use crate::OptimizerContext; use datafusion_common::Column; @@ -100,36 +99,43 @@ mod tests { }; use std::sync::Arc; + use crate::assert_optimized_plan_eq_snapshot; use crate::push_down_limit::PushDownLimit; use datafusion_expr::test::function_stub::sum; - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); - let optimized_plan = - optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let rules: Vec> = vec![Arc::new(EliminateLimit::new())]; + let optimizer_ctx = OptimizerContext::new(); + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } - fn assert_optimized_plan_eq_with_pushdown( - plan: LogicalPlan, - expected: &str, - ) -> Result<()> { - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - let config = OptimizerContext::new().with_max_passes(1); - let optimizer = Optimizer::with_rules(vec![ - Arc::new(PushDownLimit::new()), - Arc::new(EliminateLimit::new()), - ]); - let optimized_plan = optimizer - .optimize(plan, &config, observe) - .expect("failed to optimize plan"); - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - Ok(()) + macro_rules! assert_optimized_plan_eq_with_pushdown { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![ + Arc::new(PushDownLimit::new()), + Arc::new(EliminateLimit::new()) + ]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -140,8 +146,10 @@ mod tests { .limit(0, Some(0))? .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r"EmptyRelation" + ) } #[test] @@ -157,11 +165,15 @@ mod tests { .build()?; // Left side is removed - let expected = "Union\ - \n EmptyRelation\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Union + EmptyRelation + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -174,8 +186,10 @@ mod tests { .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_eq_with_pushdown(plan, expected) + assert_optimized_plan_eq_with_pushdown!( + plan, + @ "EmptyRelation" + ) } #[test] @@ -190,12 +204,16 @@ mod tests { // After remove global-state, we don't record the parent // So, bottom don't know parent info, so can't eliminate. - let expected = "Limit: skip=2, fetch=1\ - \n Sort: test.a ASC NULLS LAST, fetch=3\ - \n Limit: skip=0, fetch=2\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq_with_pushdown(plan, expected) + assert_optimized_plan_eq_with_pushdown!( + plan, + @ r" + Limit: skip=2, fetch=1 + Sort: test.a ASC NULLS LAST, fetch=3 + Limit: skip=0, fetch=2 + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -208,12 +226,16 @@ mod tests { .limit(0, Some(1))? .build()?; - let expected = "Limit: skip=0, fetch=1\ - \n Sort: test.a ASC NULLS LAST\ - \n Limit: skip=0, fetch=2\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Limit: skip=0, fetch=1 + Sort: test.a ASC NULLS LAST + Limit: skip=0, fetch=2 + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -226,12 +248,16 @@ mod tests { .limit(3, Some(1))? 
.build()?; - let expected = "Limit: skip=3, fetch=1\ - \n Sort: test.a ASC NULLS LAST\ - \n Limit: skip=2, fetch=1\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Limit: skip=3, fetch=1 + Sort: test.a ASC NULLS LAST + Limit: skip=2, fetch=1 + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -248,12 +274,16 @@ mod tests { .limit(3, Some(1))? .build()?; - let expected = "Limit: skip=3, fetch=1\ - \n Inner Join: Using test.a = test1.a\ - \n Limit: skip=2, fetch=1\ - \n TableScan: test\ - \n TableScan: test1"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Limit: skip=3, fetch=1 + Inner Join: Using test.a = test1.a + Limit: skip=2, fetch=1 + TableScan: test + TableScan: test1 + " + ) } #[test] @@ -264,8 +294,12 @@ mod tests { .limit(0, None)? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 94da08243d78..f8f93727cd9b 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -116,7 +116,8 @@ mod tests { use super::*; use crate::analyzer::type_coercion::TypeCoercion; use crate::analyzer::Analyzer; - use crate::test::*; + use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{col, logical_plan::table_scan}; @@ -129,15 +130,23 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]) - .execute_and_check(plan, &options, |_, _| {})?; - assert_optimized_plan_eq( - Arc::new(EliminateNestedUnion::new()), - analyzed_plan, - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let options = ConfigOptions::default(); + let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]) + .execute_and_check($plan, &options, |_, _| {})?; + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateNestedUnion::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + analyzed_plan, + @ $expected, + ) + }}; } #[test] @@ -146,11 +155,11 @@ mod tests { let plan = plan_builder.clone().union(plan_builder.build()?)?.build()?; - let expected = "\ - Union\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + TableScan: table + ") } #[test] @@ -162,11 +171,12 @@ mod tests { .union_distinct(plan_builder.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + ") } #[test] @@ -180,13 +190,13 @@ mod tests { .union(plan_builder.build()?)? 
.build()?; - let expected = "\ - Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -200,14 +210,15 @@ mod tests { .union(plan_builder.build()?)? .build()?; - let expected = "Union\ - \n Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -222,14 +233,15 @@ mod tests { .union_distinct(plan_builder.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -243,13 +255,14 @@ mod tests { .union_distinct(plan_builder.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } // We don't need to use project_with_column_index in logical optimizer, @@ -273,13 +286,14 @@ mod tests { )? .build()?; - let expected = "Union\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + ") } #[test] @@ -301,14 +315,15 @@ mod tests { )? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + ") } #[test] @@ -348,13 +363,14 @@ mod tests { .union(table_3.build()?)? 
.build()?; - let expected = "Union\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + ") } #[test] @@ -394,13 +410,14 @@ mod tests { .union_distinct(table_3.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + ") } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 1ecb32ca2a43..621086e4a28a 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -304,7 +304,9 @@ fn extract_non_nullable_columns( #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use arrow::datatypes::DataType; use datafusion_expr::{ binary_expr, cast, col, lit, @@ -313,8 +315,20 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateOuterJoin::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -332,12 +346,13 @@ mod tests { )? .filter(col("t2.b").is_null())? .build()?; - let expected = "\ - Filter: t2.b IS NULL\ - \n Left Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IS NULL + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -355,12 +370,13 @@ mod tests { )? .filter(col("t2.b").is_not_null())? .build()?; - let expected = "\ - Filter: t2.b IS NOT NULL\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IS NOT NULL + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -382,12 +398,13 @@ mod tests { col("t1.c").lt(lit(20u32)), ))? 
.build()?; - let expected = "\ - Filter: t1.b > UInt32(10) OR t1.c < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b > UInt32(10) OR t1.c < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -409,12 +426,13 @@ mod tests { col("t2.c").lt(lit(20u32)), ))? .build()?; - let expected = "\ - Filter: t1.b > UInt32(10) AND t2.c < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b > UInt32(10) AND t2.c < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -436,11 +454,12 @@ mod tests { try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)), ))? .build()?; - let expected = "\ - Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 48191ec20631..a07b50ade5b8 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -155,6 +155,7 @@ fn split_eq_and_noneq_join_predicate( #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use crate::test::*; use arrow::datatypes::DataType; use datafusion_expr::{ @@ -162,14 +163,18 @@ mod tests { }; use std::sync::Arc; - fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq_display_indent( - Arc::new(ExtractEquijoinPredicate {}), - plan, - expected, - ); - - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(ExtractEquijoinPredicate {}); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] @@ -180,11 +185,15 @@ mod tests { let plan = LogicalPlanBuilder::from(t1) .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))? .build()?; - let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -199,11 +208,15 @@ mod tests { Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))), )? 
.build()?; - let expected = "Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -222,11 +235,15 @@ mod tests { ), )? .build()?; - let expected = "Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -249,11 +266,15 @@ mod tests { ), )? .build()?; - let expected = "Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -275,11 +296,15 @@ mod tests { ), )? .build()?; - let expected = "Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -310,13 +335,17 @@ mod tests { ), )? 
.build()?; - let expected = "Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -343,13 +372,17 @@ mod tests { Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))), )? .build()?; - let expected = "Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -369,10 +402,14 @@ mod tests { let plan = LogicalPlanBuilder::from(t1) .join_on(t2, JoinType::Left, Some(filter))? 
.build()?; - let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 2e7a751ca4c5..14a424b32687 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -107,35 +107,52 @@ fn create_not_null_predicate(filters: Vec) -> Expr { #[cfg(test)] mod tests { use super::*; - use crate::test::assert_optimized_plan_eq; + use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Column; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder}; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(FilterNullJoinKeys {})]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] fn left_nullable() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?; - let expected = "Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id = t2.id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] fn left_nullable_left_join() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?; - let expected = "Left Join: t1.optional_id = t2.id\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Left Join: t1.optional_id = t2.id + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -144,22 +161,26 @@ mod tests { // Note: order of tables is reversed let plan = build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?; - let expected = "Left Join: t2.id = t1.optional_id\ - \n TableScan: t2\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Left Join: t2.id = t1.optional_id + TableScan: t2 + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + ") } #[test] fn left_nullable_on_condition_reversed() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?; - let expected = "Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + 
assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id = t2.id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -189,14 +210,16 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id\ - \n Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL\ - \n TableScan: t3\ - \n Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id + Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL + TableScan: t3 + Inner Join: t1.optional_id = t2.id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -213,11 +236,13 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1)\ - \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1) + Filter: t1.optional_id + UInt32(1) IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -234,11 +259,13 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1)\ - \n TableScan: t1\ - \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1) + TableScan: t1 + Filter: t2.optional_id + UInt32(1) IS NOT NULL + TableScan: t2 + ") } #[test] @@ -255,13 +282,14 @@ mod tests { None, )? .build()?; - let expected = - "Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1)\ - \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t1\ - \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1) + Filter: t1.optional_id + UInt32(1) IS NOT NULL + TableScan: t1 + Filter: t2.optional_id + UInt32(1) IS NOT NULL + TableScan: t2 + ") } #[test] @@ -283,13 +311,22 @@ mod tests { None, )? 
.build()?; - let expected = "Inner Join: t1.optional_id = t2.optional_id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n Filter: t2.optional_id IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan_from_cols, expected)?; - assert_optimized_plan_equal(plan_from_exprs, expected) + + assert_optimized_plan_equal!(plan_from_cols, @r" + Inner Join: t1.optional_id = t2.optional_id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + Filter: t2.optional_id IS NOT NULL + TableScan: t2 + ")?; + + assert_optimized_plan_equal!(plan_from_exprs, @r" + Inner Join: t1.optional_id = t2.optional_id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + Filter: t2.optional_id IS NOT NULL + TableScan: t2 + ") } fn build_plan( diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 893cb249a2a8..280010e3d92c 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -40,6 +40,7 @@ pub mod analyzer; pub mod common_subexpr_eliminate; pub mod decorrelate; +pub mod decorrelate_lateral_join; pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; pub mod eliminate_duplicated_expr; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index b3a09e2dcbcc..d0457e709026 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -31,8 +31,7 @@ use datafusion_common::{ use datafusion_expr::expr::Alias; use datafusion_expr::Unnest; use datafusion_expr::{ - logical_plan::LogicalPlan, projection_schema, Aggregate, Distinct, Expr, Projection, - TableScan, Window, + logical_plan::LogicalPlan, Aggregate, Distinct, Expr, Projection, TableScan, Window, }; use crate::optimize_projections::required_indices::RequiredIndices; @@ -455,6 +454,17 @@ fn merge_consecutive_projections(proj: Projection) -> Result::new(); expr.iter() @@ -523,7 +533,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) + matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) } /// Rewrites a projection expression using the projection before it (i.e. its input) @@ -573,8 +583,18 @@ fn is_expr_trivial(expr: &Expr) -> bool { fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { expr.transform_up(|expr| { match expr { - // remove any intermediate aliases - Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + // remove any intermediate aliases if they do not carry metadata + Expr::Alias(alias) => { + match alias + .metadata + .as_ref() + .map(|h| h.is_empty()) + .unwrap_or(true) + { + true => Ok(Transformed::yes(*alias.expr)), + false => Ok(Transformed::no(Expr::Alias(alias))), + } + } Expr::Column(col) => { // Find index of column: let idx = input.schema.index_of_column(&col)?; @@ -774,9 +794,24 @@ fn rewrite_projection_given_requirements( /// Projection is unnecessary, when /// - input schema of the projection, output schema of the projection are same, and /// - all projection expressions are either Column or Literal -fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result { - let proj_schema = projection_schema(input, proj_exprs)?; - Ok(&proj_schema == input.schema() && proj_exprs.iter().all(is_expr_trivial)) +pub fn is_projection_unnecessary( + input: &LogicalPlan, + proj_exprs: &[Expr], +) -> Result { + // First check if the number of expressions is equal to the number of fields in the input schema. 
+ if proj_exprs.len() != input.schema().fields().len() { + return Ok(false); + } + Ok(input.schema().iter().zip(proj_exprs.iter()).all( + |((field_relation, field_name), expr)| { + // Check if the expression is a column and if it matches the field name + if let Expr::Column(col) = expr { + col.relation.as_ref() == field_relation && col.name.eq(field_name.name()) + } else { + false + } + }, + )) } #[cfg(test)] @@ -791,8 +826,8 @@ mod tests { use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::{ - assert_fields_eq, assert_optimized_plan_eq, scan_empty, test_table_scan, - test_table_scan_fields, test_table_scan_with_name, + assert_fields_eq, scan_empty, test_table_scan, test_table_scan_fields, + test_table_scan_with_name, }; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; @@ -810,13 +845,27 @@ mod tests { not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition, }; + use insta::assert_snapshot; + use crate::assert_optimized_plan_eq_snapshot; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::{count, max, min}; use datafusion_functions_aggregate::min_max::max_udaf; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(OptimizeProjections::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[derive(Debug, Hash, PartialEq, Eq)] @@ -1005,9 +1054,13 @@ mod tests { .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? .build()?; - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) + test.a + TableScan: test projection=[a] + " + ) } #[test] @@ -1019,9 +1072,13 @@ mod tests { .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? .build()?; - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) + test.a + TableScan: test projection=[a] + " + ) } #[test] @@ -1032,9 +1089,13 @@ mod tests { .project(vec![col("a").alias("alias")])? .build()?; - let expected = "Projection: test.a AS alias\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS alias + TableScan: test projection=[a] + " + ) } #[test] @@ -1045,9 +1106,13 @@ mod tests { .project(vec![col("alias2").alias("alias")])? .build()?; - let expected = "Projection: test.a AS alias\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS alias + TableScan: test projection=[a] + " + ) } #[test] @@ -1065,11 +1130,15 @@ mod tests { .build() .unwrap(); - let expected = "Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]\ - \n Projection: \ - \n Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]\ - \n TableScan: ?table? 
projection=[]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] + Projection: + Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] + TableScan: ?table? projection=[] + " + ) } #[test] @@ -1079,9 +1148,13 @@ mod tests { .project(vec![-col("a")])? .build()?; - let expected = "Projection: (- test.a)\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: (- test.a) + TableScan: test projection=[a] + " + ) } #[test] @@ -1091,9 +1164,13 @@ mod tests { .project(vec![col("a").is_null()])? .build()?; - let expected = "Projection: test.a IS NULL\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NULL + TableScan: test projection=[a] + " + ) } #[test] @@ -1103,9 +1180,13 @@ mod tests { .project(vec![col("a").is_not_null()])? .build()?; - let expected = "Projection: test.a IS NOT NULL\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT NULL + TableScan: test projection=[a] + " + ) } #[test] @@ -1115,9 +1196,13 @@ mod tests { .project(vec![col("a").is_true()])? .build()?; - let expected = "Projection: test.a IS TRUE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS TRUE + TableScan: test projection=[a] + " + ) } #[test] @@ -1127,9 +1212,13 @@ mod tests { .project(vec![col("a").is_not_true()])? .build()?; - let expected = "Projection: test.a IS NOT TRUE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT TRUE + TableScan: test projection=[a] + " + ) } #[test] @@ -1139,9 +1228,13 @@ mod tests { .project(vec![col("a").is_false()])? .build()?; - let expected = "Projection: test.a IS FALSE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS FALSE + TableScan: test projection=[a] + " + ) } #[test] @@ -1151,9 +1244,13 @@ mod tests { .project(vec![col("a").is_not_false()])? .build()?; - let expected = "Projection: test.a IS NOT FALSE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT FALSE + TableScan: test projection=[a] + " + ) } #[test] @@ -1163,9 +1260,13 @@ mod tests { .project(vec![col("a").is_unknown()])? .build()?; - let expected = "Projection: test.a IS UNKNOWN\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS UNKNOWN + TableScan: test projection=[a] + " + ) } #[test] @@ -1175,9 +1276,13 @@ mod tests { .project(vec![col("a").is_not_unknown()])? .build()?; - let expected = "Projection: test.a IS NOT UNKNOWN\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT UNKNOWN + TableScan: test projection=[a] + " + ) } #[test] @@ -1187,9 +1292,13 @@ mod tests { .project(vec![not(col("a"))])? 
.build()?; - let expected = "Projection: NOT test.a\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: NOT test.a + TableScan: test projection=[a] + " + ) } #[test] @@ -1199,9 +1308,13 @@ mod tests { .project(vec![try_cast(col("a"), DataType::Float64)])? .build()?; - let expected = "Projection: TRY_CAST(test.a AS Float64)\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: TRY_CAST(test.a AS Float64) + TableScan: test projection=[a] + " + ) } #[test] @@ -1215,9 +1328,13 @@ mod tests { .project(vec![similar_to_expr])? .build()?; - let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.a SIMILAR TO Utf8("[0-9]") + TableScan: test projection=[a] + "# + ) } #[test] @@ -1227,9 +1344,13 @@ mod tests { .project(vec![col("a").between(lit(1), lit(3))])? .build()?; - let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a BETWEEN Int32(1) AND Int32(3) + TableScan: test projection=[a] + " + ) } // Test Case expression @@ -1246,9 +1367,13 @@ mod tests { ])? .build()?; - let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d + TableScan: test projection=[a] + " + ) } // Test outer projection isn't discarded despite the same schema as inner @@ -1266,11 +1391,14 @@ mod tests { ])? .build()?; - let expected = - "Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d\ - \n Projection: test.a + Int32(1) AS a, Int32(0) AS d\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d + Projection: test.a + Int32(1) AS a, Int32(0) AS d + TableScan: test projection=[a] + " + ) } // Since only column `a` is referred at the output. Scan should only contain projection=[a]. @@ -1288,10 +1416,14 @@ mod tests { .project(vec![col("a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: test.a, Int32(0) AS d\ - \n NoOpUserDefined\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, Int32(0) AS d + NoOpUserDefined + TableScan: test projection=[a] + " + ) } // Only column `a` is referred at the output. However, User defined node itself uses column `b` @@ -1315,10 +1447,14 @@ mod tests { .project(vec![col("a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: test.a, Int32(0) AS d\ - \n NoOpUserDefined\ - \n TableScan: test projection=[a, b]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, Int32(0) AS d + NoOpUserDefined + TableScan: test projection=[a, b] + " + ) } // Only column `a` is referred at the output. 
However, User defined node itself uses expression `b+c` @@ -1350,10 +1486,14 @@ mod tests { .project(vec![col("a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: test.a, Int32(0) AS d\ - \n NoOpUserDefined\ - \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, Int32(0) AS d + NoOpUserDefined + TableScan: test projection=[a, b, c] + " + ) } // Columns `l.a`, `l.c`, `r.a` is referred at the output. @@ -1374,11 +1514,15 @@ mod tests { .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: l.a, l.c, r.a, Int32(0) AS d\ - \n UserDefinedCrossJoin\ - \n TableScan: l projection=[a, c]\ - \n TableScan: r projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: l.a, l.c, r.a, Int32(0) AS d + UserDefinedCrossJoin + TableScan: l projection=[a, c] + TableScan: r projection=[a] + " + ) } #[test] @@ -1389,10 +1533,13 @@ mod tests { .aggregate(Vec::::new(), vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ - \n TableScan: test projection=[b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(test.b)]] + TableScan: test projection=[b] + " + ) } #[test] @@ -1403,10 +1550,13 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]] + TableScan: test projection=[b, c] + " + ) } #[test] @@ -1418,11 +1568,14 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]\ - \n SubqueryAlias: a\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]] + SubqueryAlias: a + TableScan: test projection=[b, c] + " + ) } #[test] @@ -1434,12 +1587,15 @@ mod tests { .aggregate(Vec::::new(), vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ - \n Projection: test.b\ - \n Filter: test.c > Int32(1)\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(test.b)]] + Projection: test.b + Filter: test.c > Int32(1) + TableScan: test projection=[b, c] + " + ) } #[test] @@ -1460,11 +1616,13 @@ mod tests { .project([col(Column::new_unqualified("tag.one"))])? .build()?; - let expected = "\ - Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]]\ - \n TableScan: m4 projection=[tag.one]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]] + TableScan: m4 projection=[tag.one] + " + ) } #[test] @@ -1475,10 +1633,13 @@ mod tests { .project(vec![col("a"), col("b"), col("c")])? .project(vec![col("a"), col("c"), col("b")])? 
.build()?; - let expected = "Projection: test.a, test.c, test.b\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.c, test.b + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1486,9 +1647,10 @@ mod tests { let schema = Schema::new(test_table_scan_fields()); let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?; - let expected = "TableScan: test projection=[b, a, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[b, a, c]" + ) } #[test] @@ -1498,10 +1660,13 @@ mod tests { let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))? .project(vec![col("a"), col("b")])? .build()?; - let expected = "Projection: test.a, test.b\ - \n TableScan: test projection=[b, a]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + TableScan: test projection=[b, a] + " + ) } #[test] @@ -1511,10 +1676,13 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("c"), col("b"), col("a")])? .build()?; - let expected = "Projection: test.c, test.b, test.a\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, test.b, test.a + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1529,14 +1697,18 @@ mod tests { .filter(col("a").gt(lit(1)))? .project(vec![col("a"), col("c"), col("b")])? .build()?; - let expected = "Projection: test.a, test.c, test.b\ - \n Filter: test.a > Int32(1)\ - \n Filter: test.b > Int32(1)\ - \n Projection: test.c, test.a, test.b\ - \n Filter: test.c > Int32(1)\ - \n Projection: test.c, test.b, test.a\ - \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.c, test.b + Filter: test.a > Int32(1) + Filter: test.b > Int32(1) + Projection: test.c, test.a, test.b + Filter: test.c > Int32(1) + Projection: test.c, test.b, test.a + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1551,14 +1723,17 @@ mod tests { .project(vec![col("a"), col("b"), col("c1")])? .build()?; - // make sure projections are pushed down to both table scans - let expected = "Left Join: test.a = test2.c1\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); + + // make sure projections are pushed down to both table scans + assert_snapshot!( + optimized_plan.clone(), + @r" + Left Join: test.a = test2.c1 + TableScan: test projection=[a, b] + TableScan: test2 projection=[c1] + " + ); // make sure schema for join node include both join columns let optimized_join = optimized_plan; @@ -1602,15 +1777,18 @@ mod tests { .project(vec![col("a"), col("b")])? 
.build()?; - // make sure projections are pushed down to both table scans - let expected = "Projection: test.a, test.b\ - \n Left Join: test.a = test2.c1\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); + + // make sure projections are pushed down to both table scans + assert_snapshot!( + optimized_plan.clone(), + @r" + Projection: test.a, test.b + Left Join: test.a = test2.c1 + TableScan: test projection=[a, b] + TableScan: test2 projection=[c1] + " + ); // make sure schema for join node include both join columns let optimized_join = optimized_plan.inputs()[0]; @@ -1652,15 +1830,18 @@ mod tests { .project(vec![col("a"), col("b")])? .build()?; - // make sure projections are pushed down to table scan - let expected = "Projection: test.a, test.b\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[a]"; - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); + + // make sure projections are pushed down to table scan + assert_snapshot!( + optimized_plan.clone(), + @r" + Projection: test.a, test.b + Left Join: Using test.a = test2.a + TableScan: test projection=[a, b] + TableScan: test2 projection=[a] + " + ); // make sure schema for join node include both join columns let optimized_join = optimized_plan.inputs()[0]; @@ -1692,17 +1873,20 @@ mod tests { fn cast() -> Result<()> { let table_scan = test_table_scan()?; - let projection = LogicalPlanBuilder::from(table_scan) + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![Expr::Cast(Cast::new( Box::new(col("c")), DataType::Float64, ))])? 
.build()?; - let expected = "Projection: CAST(test.c AS Float64)\ - \n TableScan: test projection=[c]"; - - assert_optimized_plan_equal(projection, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: CAST(test.c AS Float64) + TableScan: test projection=[c] + " + ) } #[test] @@ -1716,9 +1900,10 @@ mod tests { assert_fields_eq(&table_scan, vec!["a", "b", "c"]); assert_fields_eq(&plan, vec!["a", "b"]); - let expected = "TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[a, b]" + ) } #[test] @@ -1737,9 +1922,10 @@ mod tests { assert_fields_eq(&plan, vec!["a", "b"]); - let expected = "TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[a, b]" + ) } #[test] @@ -1755,11 +1941,14 @@ mod tests { assert_fields_eq(&plan, vec!["c", "a"]); - let expected = "Limit: skip=0, fetch=5\ - \n Projection: test.c, test.a\ - \n TableScan: test projection=[a, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=5 + Projection: test.c, test.a + TableScan: test projection=[a, c] + " + ) } #[test] @@ -1767,8 +1956,10 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan).build()?; // should expand projection to all columns without projection - let expected = "TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[a, b, c]" + ) } #[test] @@ -1777,9 +1968,13 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![lit(1_i64), lit(2_i64)])? 
.build()?; - let expected = "Projection: Int64(1), Int64(2)\ - \n TableScan: test projection=[]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int64(1), Int64(2) + TableScan: test projection=[] + " + ) } /// tests that it removes unused columns in projections @@ -1799,13 +1994,15 @@ mod tests { assert_fields_eq(&plan, vec!["c", "max(test.a)"]); let plan = optimize(plan).expect("failed to optimize plan"); - let expected = "\ - Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]\ - \n Filter: test.c > Int32(1)\ - \n Projection: test.c, test.a\ - \n TableScan: test projection=[a, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]] + Filter: test.c > Int32(1) + Projection: test.c, test.a + TableScan: test projection=[a, c] + " + ) } /// tests that it removes un-needed projections @@ -1823,11 +2020,13 @@ mod tests { assert_fields_eq(&plan, vec!["a"]); - let expected = "\ - Projection: Int32(1) AS a\ - \n TableScan: test projection=[]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) AS a + TableScan: test projection=[] + " + ) } #[test] @@ -1852,11 +2051,13 @@ mod tests { assert_fields_eq(&plan, vec!["a"]); - let expected = "\ - Projection: Int32(1) AS a\ - \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) AS a + TableScan: test projection=[], full_filters=[b = Int32(1)] + " + ) } /// tests that optimizing twice yields same plan @@ -1895,12 +2096,15 @@ mod tests { assert_fields_eq(&plan, vec!["c", "a", "max(test.b)"]); - let expected = "Projection: test.c, test.a, max(test.b)\ - \n Filter: test.c > Int32(1)\ - \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]]\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, test.a, max(test.b) + Filter: test.c > Int32(1) + Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]] + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1917,10 +2121,13 @@ mod tests { )? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]] + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1933,18 +2140,21 @@ mod tests { .project(vec![col("a")])? 
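The old expected values above relied on `\`-line-continuations plus a hand-written `\n  ` per indentation level, which is easy to get subtly wrong; the raw-string snapshots spell out the same tree with real newlines. A standalone check (illustrative only, not part of the patch) showing that the two literal styles denote the same rendered plan:

#[test]
fn escaped_and_raw_literals_agree() {
    // A `\` at the end of a line skips the newline and the next line's leading
    // whitespace, so the old style encoded "\n  " explicitly inside the literal.
    let old_style = "Projection: Int64(1), Int64(2)\
    \n  TableScan: test projection=[]";
    let new_style = "Projection: Int64(1), Int64(2)\n  TableScan: test projection=[]";
    assert_eq!(old_style, new_style);
}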
.build()?; - let expected = "Projection: test.a\ - \n Distinct:\ - \n TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Distinct: + TableScan: test projection=[a, b] + " + ) } #[test] fn test_window() -> Result<()> { let table_scan = test_table_scan()?; - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let max1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], )) @@ -1952,7 +2162,7 @@ mod tests { .build() .unwrap(); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], )); @@ -1965,13 +2175,16 @@ mod tests { .project(vec![col1, col2])? .build()?; - let expected = "Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: test projection=[a, b] + " + ) } fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index b40121dbfeb7..4d2c2c7c79cd 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -33,6 +33,7 @@ use datafusion_common::{internal_err, DFSchema, DataFusionError, HashSet, Result use datafusion_expr::logical_plan::LogicalPlan; use crate::common_subexpr_eliminate::CommonSubexprEliminate; +use crate::decorrelate_lateral_join::DecorrelateLateralJoin; use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr; @@ -226,6 +227,7 @@ impl Optimizer { Arc::new(EliminateJoin::new()), Arc::new(DecorrelatePredicateSubquery::new()), Arc::new(ScalarSubqueryToJoin::new()), + Arc::new(DecorrelateLateralJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), @@ -413,7 +415,7 @@ impl Optimizer { previous_plans.insert(LogicalPlanSignature::new(&new_plan)); if !plan_is_fresh { // plan did not change, so no need to continue trying to optimize - debug!("optimizer pass {} did not make changes", i); + debug!("optimizer pass {i} did not make changes"); break; } i += 1; diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs 
b/datafusion/optimizer/src/propagate_empty_relation.rs index 344707ae8dbe..4fb9e117e2af 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -242,17 +242,31 @@ mod tests { binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Operator, }; + use crate::assert_optimized_plan_eq_snapshot; use crate::eliminate_filter::EliminateFilter; use crate::eliminate_nested_union::EliminateNestedUnion; use crate::test::{ - assert_optimized_plan_eq, assert_optimized_plan_with_rules, test_table_scan, - test_table_scan_fields, test_table_scan_with_name, + assert_optimized_plan_with_rules, test_table_scan, test_table_scan_fields, + test_table_scan_with_name, }; + use crate::OptimizerContext; use super::*; - fn assert_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(PropagateEmptyRelation::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } fn assert_together_optimized_plan( @@ -280,8 +294,7 @@ mod tests { .project(vec![binary_expr(lit(1), Operator::Plus, lit(1))])? .build()?; - let expected = "EmptyRelation"; - assert_eq(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation") } #[test] diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index c9617514e453..1c1996d6a241 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -254,7 +254,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { let mut is_evaluate = true; predicate.apply(|expr| match expr { Expr::Column(_) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::Placeholder(_) | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. } @@ -1391,7 +1391,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; - use datafusion_common::{DFSchemaRef, ScalarValue}; + use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_expr::expr::{ScalarFunction, WindowFunction}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ @@ -1401,38 +1401,47 @@ mod tests { WindowFunctionDefinition, }; + use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::Optimizer; use crate::simplify_expressions::SimplifyExpressions; use crate::test::*; use crate::OptimizerContext; use datafusion_expr::test::function_stub::sum; + use insta::assert_snapshot; use super::*; fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - crate::test::assert_optimized_plan_eq( - Arc::new(PushDownFilter::new()), - plan, - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(PushDownFilter::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } - fn assert_optimized_plan_eq_with_rewrite_predicate( - plan: LogicalPlan, - expected: &str, - ) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![ - Arc::new(SimplifyExpressions::new()), - Arc::new(PushDownFilter::new()), - ]); - let optimized_plan = - optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(expected, formatted_plan); - Ok(()) + macro_rules! assert_optimized_plan_eq_with_rewrite_predicate { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer = Optimizer::with_rules(vec![ + Arc::new(SimplifyExpressions::new()), + Arc::new(PushDownFilter::new()), + ]); + let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?; + assert_snapshot!(optimized_plan, @ $expected); + Ok::<(), DataFusionError>(()) + }}; } #[test] @@ -1443,10 +1452,13 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter is before projection - let expected = "\ - Projection: test.a, test.b\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } #[test] @@ -1458,12 +1470,15 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter is before single projection - let expected = "\ - Filter: test.a = Int64(1)\ - \n Limit: skip=0, fetch=10\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a = Int64(1) + Limit: skip=0, fetch=10 + Projection: test.a, test.b + TableScan: test + " + ) } #[test] @@ -1472,8 +1487,10 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(lit(0i64).eq(lit(1i64)))? .build()?; - let expected = "TableScan: test, full_filters=[Int64(0) = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test, full_filters=[Int64(0) = Int64(1)]" + ) } #[test] @@ -1485,11 +1502,14 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter is before double projection - let expected = "\ - Projection: test.c, test.b\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, test.b + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } #[test] @@ -1500,10 +1520,13 @@ mod tests { .filter(col("a").gt(lit(10i64)))? .build()?; // filter of key aggregation is commutative - let expected = "\ - Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]] + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } #[test] @@ -1513,10 +1536,14 @@ mod tests { .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])? .filter(col("b").gt(lit(10i64)))? 
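Each converted test module defines a local `assert_optimized_plan_equal!` that pins the `OptimizerContext` and rule list and forwards the plan plus the `@`-literal to the shared `assert_optimized_plan_eq_snapshot!` macro, whose body is not shown in this diff. The sketch below only illustrates the calling convention (the `@ $expected:literal $(,)?` capture and the trailing `Ok(..)` that lets a test end with the macro call); it substitutes a plain `assert_eq!` for the snapshot machinery, and every name in it is made up:

// Illustrative wrapper with the same calling convention as the macros above.
macro_rules! assert_rendered_equal {
    (
        $value:expr,
        @ $expected:literal $(,)?
    ) => {{
        let rendered = format!("{}", $value);
        assert_eq!(rendered, $expected);
        Ok::<(), String>(())
    }};
}

#[test]
fn wrapper_macro_shape() -> Result<(), String> {
    // The macro call is the final expression, exactly like the converted tests.
    assert_rendered_equal!(1 + 41, @"42")
}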
.build()?; - let expected = "Filter: test.b > Int64(10)\ - \n Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b > Int64(10) + Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]] + TableScan: test + " + ) } #[test] @@ -1525,10 +1552,13 @@ mod tests { .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])? .filter(col("test.b + test.a").gt(lit(10i64)))? .build()?; - let expected = - "Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]\ - \n TableScan: test, full_filters=[test.b + test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]] + TableScan: test, full_filters=[test.b + test.a > Int64(10)] + " + ) } #[test] @@ -1539,11 +1569,14 @@ mod tests { .filter(col("b").gt(lit(10i64)))? .build()?; // filter of aggregate is after aggregation since they are non-commutative - let expected = "\ - Filter: b > Int64(10)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: b > Int64(10) + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]] + TableScan: test + " + ) } /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed @@ -1551,7 +1584,7 @@ mod tests { fn filter_move_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1567,10 +1600,13 @@ mod tests { .filter(col("b").gt(lit(10i64)))? .build()?; - let expected = "\ - WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.b > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.b > Int64(10)] + " + ) } /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and @@ -1579,7 +1615,7 @@ mod tests { fn filter_move_complex_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1595,10 +1631,13 @@ mod tests { .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))? 
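The `assert_optimized_plan_eq_with_rewrite_predicate!` macro defined above also shows the general recipe for exercising a specific rule stack: build an `Optimizer` from an explicit rule list, run `optimize` with an `OptimizerContext` and a no-op observer, then assert on the rendered plan. A rough, self-contained version of that recipe; the crate and import paths are my assumption of how this looks from outside the optimizer crate, and it is not part of the patch:

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::{col, lit, logical_plan::table_scan, LogicalPlan};
use datafusion_optimizer::push_down_filter::PushDownFilter;
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::{Optimizer, OptimizerContext, OptimizerRule};

fn main() -> Result<()> {
    // A small scan + filter, similar to the plans built by the tests above.
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]);
    let plan = table_scan(Some("test"), &schema, None)?
        .filter(col("a").eq(lit(1i32)))?
        .build()?;

    // Run only the rules under test, in a fixed order.
    let optimizer = Optimizer::with_rules(vec![
        Arc::new(SimplifyExpressions::new()),
        Arc::new(PushDownFilter::new()),
    ]);
    let observe = |_plan: &LogicalPlan, _rule: &dyn OptimizerRule| {};
    let optimized = optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
    println!("{optimized}");
    Ok(())
}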
.build()?; - let expected = "\ - WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)] + " + ) } /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed @@ -1606,7 +1645,7 @@ mod tests { fn filter_move_partial_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1622,11 +1661,14 @@ mod tests { .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))? .build()?; - let expected = "\ - Filter: test.b = Int64(1)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b = Int64(1) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that filters on partition expressions are not pushed, as the single expression @@ -1635,7 +1677,7 @@ mod tests { fn filter_expression_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1653,11 +1695,14 @@ mod tests { .filter(add(col("a"), col("b")).gt(lit(10i64)))? .build()?; - let expected = "\ - Filter: test.a + test.b > Int64(10)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a + test.b > Int64(10) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test + " + ) } /// verifies that filters are not pushed on order by columns (that are not used in partitioning) @@ -1665,7 +1710,7 @@ mod tests { fn filter_order_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1681,11 +1726,14 @@ mod tests { .filter(col("c").gt(lit(10i64)))? 
.build()?; - let expected = "\ - Filter: test.c > Int64(10)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.c > Int64(10) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test + " + ) } /// verifies that when we use multiple window functions with a common partition key, the filter @@ -1694,7 +1742,7 @@ mod tests { fn filter_multiple_windows_common_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1705,7 +1753,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1721,10 +1769,13 @@ mod tests { .filter(col("a").gt(lit(10i64)))? // a appears in both window functions .build()?; - let expected = "\ - WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that when we use multiple window functions with different partitions keys, the @@ -1733,7 +1784,7 @@ mod tests { fn filter_multiple_windows_disjoint_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1744,7 +1795,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1760,11 +1811,14 @@ mod tests { .filter(col("b").gt(lit(10i64)))? 
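A second mechanical change running through these tests is `Expr::WindowFunction(WindowFunction::new(..))` becoming `Expr::from(WindowFunction::new(..))`. The diff does not show the reason, but constructing through `From` is the usual way to keep call sites stable when a variant's payload changes shape (for example, when it becomes boxed). A toy illustration of that pattern, deliberately not using DataFusion's real types:

// Toy types, illustrative only: callers write `MyExpr::from(spec)` and stay valid
// whether or not the variant boxes its payload.
#[derive(Debug)]
struct WindowSpec {
    name: String,
}

#[derive(Debug)]
enum MyExpr {
    WindowFunction(Box<WindowSpec>),
}

impl From<WindowSpec> for MyExpr {
    fn from(spec: WindowSpec) -> Self {
        MyExpr::WindowFunction(Box::new(spec))
    }
}

fn main() {
    let expr = MyExpr::from(WindowSpec { name: "rank".to_string() });
    println!("{expr:?}");
}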
// b only appears in one window function .build()?; - let expected = "\ - Filter: test.b > Int64(10)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b > Int64(10) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test + " + ) } /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written @@ -1776,10 +1830,13 @@ mod tests { .filter(col("b").eq(lit(1i64)))? .build()?; // filter is before projection - let expected = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } fn add(left: Expr, right: Expr) -> Expr { @@ -1811,19 +1868,21 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: b = Int64(1)\ - \n Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: b = Int64(1) + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test + ", ); - // filter is before projection - let expected = "\ - Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)] + " + ) } /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written @@ -1841,21 +1900,23 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: a = Int64(1)\ - \n Projection: b * Int32(3) AS a, test.c\ - \n Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: a = Int64(1) + Projection: b * Int32(3) AS a, test.c + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test + ", ); - // filter is before the projections - let expected = "\ - Projection: b * Int32(3) AS a, test.c\ - \n Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b * Int32(3) AS a, test.c + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)] + " + ) } #[derive(Debug, PartialEq, Eq, Hash)] @@ -1930,10 +1991,13 @@ mod tests { .build()?; // Push filter below NoopPlan - let expected = "\ - NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @r" + NoopPlan + TableScan: test, 
full_filters=[test.a = Int64(1)] + " + )?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1946,11 +2010,14 @@ mod tests { .build()?; // Push only predicate on `a` below NoopPlan - let expected = "\ - Filter: test.c = Int64(2)\ - \n NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.c = Int64(2) + NoopPlan + TableScan: test, full_filters=[test.a = Int64(1)] + " + )?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1963,11 +2030,14 @@ mod tests { .build()?; // Push filter below NoopPlan for each child branch - let expected = "\ - NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @r" + NoopPlan + TableScan: test, full_filters=[test.a = Int64(1)] + TableScan: test, full_filters=[test.a = Int64(1)] + " + )?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1980,12 +2050,15 @@ mod tests { .build()?; // Push only predicate on `a` below NoopPlan - let expected = "\ - Filter: test.c = Int64(2)\ - \n NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.c = Int64(2) + NoopPlan + TableScan: test, full_filters=[test.a = Int64(1)] + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed @@ -2002,23 +2075,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: sum(test.c) > Int64(10)\ - \n Filter: b > Int64(10)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: sum(test.c) > Int64(10) + Filter: b > Int64(10) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test + ", ); - // filter is before the projections - let expected = "\ - Filter: sum(test.c) > Int64(10)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: sum(test.c) > Int64(10) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed @@ -2037,22 +2112,24 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test + ", ); - // filter is before the projections - let expected = "\ - Filter: sum(test.c) > Int64(10) 
AND sum(test.c) < Int64(20)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that when two limits are in place, we jump neither @@ -2067,14 +2144,17 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter does not just any of the limits - let expected = "\ - Projection: test.a, test.b\ - \n Filter: test.a = Int64(1)\ - \n Limit: skip=0, fetch=10\ - \n Limit: skip=0, fetch=20\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + Filter: test.a = Int64(1) + Limit: skip=0, fetch=10 + Limit: skip=0, fetch=20 + Projection: test.a, test.b + TableScan: test + " + ) } #[test] @@ -2086,10 +2166,14 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter appears below Union - let expected = "Union\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n TableScan: test2, full_filters=[test2.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Union + TableScan: test, full_filters=[test.a = Int64(1)] + TableScan: test2, full_filters=[test2.a = Int64(1)] + " + ) } #[test] @@ -2106,13 +2190,18 @@ mod tests { .build()?; // filter appears below Union - let expected = "Union\n SubqueryAlias: test2\ - \n Projection: test.a AS b\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n SubqueryAlias: test2\ - \n Projection: test.a AS b\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Union + SubqueryAlias: test2 + Projection: test.a AS b + TableScan: test, full_filters=[test.a = Int64(1)] + SubqueryAlias: test2 + Projection: test.a AS b + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } #[test] @@ -2136,14 +2225,17 @@ mod tests { .filter(filter)? .build()?; - let expected = "Projection: test.a, test1.d\ - \n Cross Join: \ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a = Int32(1)]\ - \n Projection: test1.d, test1.e, test1.f\ - \n TableScan: test1, full_filters=[test1.d > Int32(2)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test1.d + Cross Join: + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a = Int32(1)] + Projection: test1.d, test1.e, test1.f + TableScan: test1, full_filters=[test1.d > Int32(2)] + " + ) } #[test] @@ -2163,13 +2255,17 @@ mod tests { .filter(filter)? 
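Several of the surrounding tests pin down how a predicate written against a projected alias is rewritten in terms of the underlying columns before being pushed below the projection (for example `b = Int64(1)` becoming `test.a * Int32(2) + test.c = Int64(1)` in the plans above). A toy, dependency-free sketch of that substitution idea, not DataFusion's implementation:

use std::collections::HashMap;

// Toy expression type; the real rewrite walks DataFusion `Expr` trees.
#[derive(Clone, Debug, PartialEq)]
enum Toy {
    Col(&'static str),
    Lit(i64),
    Add(Box<Toy>, Box<Toy>),
    Eq(Box<Toy>, Box<Toy>),
}

// Replace alias columns with the expressions the projection computes for them.
fn resolve(e: Toy, aliases: &HashMap<&'static str, Toy>) -> Toy {
    match e {
        Toy::Col(name) => aliases.get(name).cloned().unwrap_or(Toy::Col(name)),
        Toy::Add(l, r) => Toy::Add(Box::new(resolve(*l, aliases)), Box::new(resolve(*r, aliases))),
        Toy::Eq(l, r) => Toy::Eq(Box::new(resolve(*l, aliases)), Box::new(resolve(*r, aliases))),
        other => other,
    }
}

fn main() {
    // Projection: test.a + test.c AS b; predicate on the alias: b = 1
    let mut aliases = HashMap::new();
    aliases.insert("b", Toy::Add(Box::new(Toy::Col("test.a")), Box::new(Toy::Col("test.c"))));
    let pushed = resolve(Toy::Eq(Box::new(Toy::Col("b")), Box::new(Toy::Lit(1))), &aliases);
    // The pushed-down predicate now refers only to base table columns.
    assert_eq!(
        pushed,
        Toy::Eq(
            Box::new(Toy::Add(Box::new(Toy::Col("test.a")), Box::new(Toy::Col("test.c")))),
            Box::new(Toy::Lit(1)),
        )
    );
}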
.build()?; - let expected = "Projection: test.a, test1.a\ - \n Cross Join: \ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a = Int32(1)]\ - \n Projection: test1.a, test1.b, test1.c\ - \n TableScan: test1, full_filters=[test1.a > Int32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test1.a + Cross Join: + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a = Int32(1)] + Projection: test1.a, test1.b, test1.c + TableScan: test1, full_filters=[test1.a > Int32(2)] + " + ) } /// verifies that filters with the same columns are correctly placed @@ -2186,24 +2282,26 @@ mod tests { // Should be able to move both filters below the projections // not part of the test - assert_eq!( - format!("{plan}"), - "Filter: test.a >= Int64(1)\ - \n Projection: test.a\ - \n Limit: skip=0, fetch=1\ - \n Filter: test.a <= Int64(1)\ - \n Projection: test.a\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: test.a >= Int64(1) + Projection: test.a + Limit: skip=0, fetch=1 + Filter: test.a <= Int64(1) + Projection: test.a + TableScan: test + ", ); - - let expected = "\ - Projection: test.a\ - \n Filter: test.a >= Int64(1)\ - \n Limit: skip=0, fetch=1\ - \n Projection: test.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Filter: test.a >= Int64(1) + Limit: skip=0, fetch=1 + Projection: test.a + TableScan: test, full_filters=[test.a <= Int64(1)] + " + ) } /// verifies that filters to be placed on the same depth are ANDed @@ -2218,22 +2316,24 @@ mod tests { .build()?; // not part of the test - assert_eq!( - format!("{plan}"), - "Projection: test.a\ - \n Filter: test.a >= Int64(1)\ - \n Filter: test.a <= Int64(1)\ - \n Limit: skip=0, fetch=1\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Projection: test.a + Filter: test.a >= Int64(1) + Filter: test.a <= Int64(1) + Limit: skip=0, fetch=1 + TableScan: test + ", ); - - let expected = "\ - Projection: test.a\ - \n Filter: test.a >= Int64(1) AND test.a <= Int64(1)\ - \n Limit: skip=0, fetch=1\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Filter: test.a >= Int64(1) AND test.a <= Int64(1) + Limit: skip=0, fetch=1 + TableScan: test + " + ) } /// verifies that filters on a plan with user nodes are not lost @@ -2247,19 +2347,21 @@ mod tests { let plan = user_defined::new(plan); - let expected = "\ - TestUserDefined\ - \n Filter: test.a <= Int64(1)\ - \n TableScan: test"; - // not part of the test - assert_eq!(format!("{plan}"), expected); - - let expected = "\ - TestUserDefined\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - - assert_optimized_plan_eq(plan, expected) + assert_snapshot!(plan, + @r" + TestUserDefined + Filter: test.a <= Int64(1) + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + TestUserDefined + TableScan: test, full_filters=[test.a <= Int64(1)] + " + ) } /// post-on-join predicates on a column common to both sides is pushed to both sides @@ -2282,22 +2384,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Inner Join: test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Inner 
Join: test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to side before the join - let expected = "\ - Inner Join: test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// post-using-join predicates on a column common to both sides is pushed to both sides @@ -2319,22 +2424,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Inner Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Inner Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to side before the join - let expected = "\ - Inner Join: Using test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: Using test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// post-join predicates with columns from both sides are converted to join filters @@ -2359,24 +2467,27 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.c <= test2.b\ - \n Inner Join: test.a = test2.a\ - \n Projection: test.a, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.c <= test2.b + Inner Join: test.a = test2.a + Projection: test.a, test.c + TableScan: test + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Filter is converted to Join Filter - let expected = "\ - Inner Join: test.a = test2.a Filter: test.c <= test2.b\ - \n Projection: test.a, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a Filter: test.c <= test2.b + Projection: test.a, test.c + TableScan: test + Projection: test2.a, test2.b + TableScan: test2 + " + ) } /// post-join predicates with columns from one side of a join are pushed only to that side @@ -2402,23 +2513,26 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.b <= Int64(1)\ - \n Inner Join: test.a = test2.a\ - \n Projection: test.a, test.b\ - \n TableScan: test\ - \n Projection: test2.a, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.b <= Int64(1) + Inner Join: test.a = test2.a + Projection: test.a, test.b + TableScan: test + Projection: test2.a, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.a\ - \n Projection: test.a, test.b\ - \n TableScan: test, full_filters=[test.b <= Int64(1)]\ - \n Projection: test2.a, test2.c\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + 
plan, + @r" + Inner Join: test.a = test2.a + Projection: test.a, test.b + TableScan: test, full_filters=[test.b <= Int64(1)] + Projection: test2.a, test2.c + TableScan: test2 + " + ) } /// post-join predicates on the right side of a left join are not duplicated @@ -2441,23 +2555,26 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a <= Int64(1)\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test2.a <= Int64(1) + Left Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter not duplicated nor pushed down - i.e. noop - let expected = "\ - Filter: test2.a <= Int64(1)\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test2.a <= Int64(1) + Left Join: Using test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2 + " + ) } /// post-join predicates on the left side of a right join are not duplicated @@ -2479,23 +2596,26 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter not duplicated nor pushed down - i.e. noop - let expected = "\ - Filter: test.a <= Int64(1)\ - \n Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a <= Int64(1) + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// post-left-join predicate on a column common to both sides is only pushed to the left side @@ -2518,22 +2638,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Left Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to left side of the join, not the right - let expected = "\ - Left Join: Using test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: Using test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2 + " + ) } /// post-right-join predicate on a column common to both sides is only pushed to the right side @@ -2556,22 +2679,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a <= Int64(1)\ - \n Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n 
TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test2.a <= Int64(1) + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to right side of join, not duplicated to the left - let expected = "\ - Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// single table predicate parts of ON condition should be pushed to both inputs @@ -2599,22 +2725,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.a Filter: test.b < test2.b\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.c > UInt32(1)]\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a Filter: test.b < test2.b + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.c > UInt32(1)] + Projection: test2.a, test2.b, test2.c + TableScan: test2, full_filters=[test2.c > UInt32(4)] + " + ) } /// join filter should be completely removed after pushdown @@ -2641,22 +2770,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.a\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.b > UInt32(1)]\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.b > UInt32(1)] + Projection: test2.a, test2.b, test2.c + TableScan: test2, full_filters=[test2.c > UInt32(4)] + " + ) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2681,22 +2813,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\ - \n Projection: test.a\ - \n TableScan: test\ - \n Projection: test2.b\ - \n TableScan: test2" + 
assert_snapshot!(plan, + @r" + Inner Join: test.a = test2.b Filter: test.a > UInt32(1) + Projection: test.a + TableScan: test + Projection: test2.b + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.b\ - \n Projection: test.a\ - \n TableScan: test, full_filters=[test.a > UInt32(1)]\ - \n Projection: test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.b + Projection: test.a + TableScan: test, full_filters=[test.a > UInt32(1)] + Projection: test2.b + TableScan: test2, full_filters=[test2.b > UInt32(1)] + " + ) } /// single table predicate parts of ON condition should be pushed to right input @@ -2724,22 +2859,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2, full_filters=[test2.c > UInt32(4)] + " + ) } /// single table predicate parts of ON condition should be pushed to left input @@ -2767,22 +2905,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a > UInt32(1)]\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a > UInt32(1)] + Projection: test2.a, test2.b, test2.c + TableScan: test2 + " + ) } /// single table predicate parts of ON condition should not be pushed @@ -2810,17 +2951,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > 
UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = &format!("{plan}"); - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + " + ) } struct PushDownProvider { @@ -2887,9 +3036,10 @@ mod tests { fn filter_with_table_provider_exact() -> Result<()> { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?; - let expected = "\ - TableScan: test, full_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test, full_filters=[a = Int64(1)]" + ) } #[test] @@ -2897,10 +3047,13 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?; - let expected = "\ - Filter: a = Int64(1)\ - \n TableScan: test, partial_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: a = Int64(1) + TableScan: test, partial_filters=[a = Int64(1)] + " + ) } #[test] @@ -2913,13 +3066,15 @@ mod tests { .expect("failed to optimize plan") .data; - let expected = "\ - Filter: a = Int64(1)\ - \n TableScan: test, partial_filters=[a = Int64(1)]"; - // Optimizing the same plan multiple times should produce the same plan // each time. - assert_optimized_plan_eq(optimized_plan, expected) + assert_optimized_plan_equal!( + optimized_plan, + @r" + Filter: a = Int64(1) + TableScan: test, partial_filters=[a = Int64(1)] + " + ) } #[test] @@ -2927,10 +3082,13 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?; - let expected = "\ - Filter: a = Int64(1)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: a = Int64(1) + TableScan: test + " + ) } #[test] @@ -2944,11 +3102,14 @@ mod tests { .project(vec![col("a"), col("b")])? .build()?; - let expected = "Projection: a, b\ - \n Filter: a = Int64(10) AND b > Int64(11)\ - \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, b + Filter: a = Int64(10) AND b > Int64(11) + TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)] + " + ) } #[test] @@ -2962,13 +3123,13 @@ mod tests { .project(vec![col("a"), col("b")])? 
.build()?; - let expected = r#" -Projection: a, b - TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)] - "# - .trim(); - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, b + TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)] + " + ) } #[test] @@ -2983,20 +3144,21 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b > Int64(10) AND test.c > Int64(10)\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: b > Int64(10) AND test.c > Int64(10) + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ - "; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)] + " + ) } #[test] @@ -3012,23 +3174,23 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b > Int64(10) AND test.c > Int64(10)\ - \n Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b > Int64(10) AND test.c > Int64(10) + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ - "; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)] + " + ) } #[test] @@ -3040,20 +3202,21 @@ Projection: a, b .build()?; // filter on col b and d - assert_eq!( - format!("{plan}"), - "Filter: b > Int64(10) AND d > Int64(10)\ - \n Projection: test.a AS b, test.c AS d\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b > Int64(10) AND d > Int64(10) + Projection: test.a AS b, test.c AS d + TableScan: test + ", ); - // rewrite filter col b to test.a, col d to test.c - let expected = "\ - Projection: test.a AS b, test.c AS d\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c AS d + TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)] + " + ) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -3077,23 +3240,26 @@ Projection: a, b )? 
.build()?; - assert_eq!( - format!("{plan}"), - "Inner Join: c = d Filter: c > UInt32(1)\ - \n Projection: test.a AS c\ - \n TableScan: test\ - \n Projection: test2.b AS d\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Inner Join: c = d Filter: c > UInt32(1) + Projection: test.a AS c + TableScan: test + Projection: test2.b AS d + TableScan: test2 + ", ); - // Change filter on col `c`, 'd' to `test.a`, 'test.b' - let expected = "\ - Inner Join: c = d\ - \n Projection: test.a AS c\ - \n TableScan: test, full_filters=[test.a > UInt32(1)]\ - \n Projection: test2.b AS d\ - \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: c = d + Projection: test.a AS c + TableScan: test, full_filters=[test.a > UInt32(1)] + Projection: test2.b AS d + TableScan: test2, full_filters=[test2.b > UInt32(1)] + " + ) } #[test] @@ -3109,20 +3275,21 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)]) + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])] + " + ) } #[test] @@ -3139,22 +3306,23 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ - \n Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)]) + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])] + " + ) } #[test] @@ -3174,23 +3342,27 @@ Projection: a, b .build()?; // filter on col b in subquery - let expected_before = "\ - Filter: b IN ()\ - \n Subquery:\ - \n Projection: sq.c\ - \n TableScan: sq\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - + assert_snapshot!(plan, + @r" + Filter: b IN () + Subquery: + Projection: sq.c + TableScan: sq + Projection: test.a AS b, test.c + TableScan: test + ", + ); // rewrite filter col b to test.a - let expected_after = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a IN ()]\ - \n Subquery:\ - \n Projection: sq.c\ - \n TableScan: sq"; - assert_optimized_plan_eq(plan, expected_after) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a IN ()] + Subquery: + Projection: sq.c + TableScan: sq + " + ) } 
#[test] @@ -3205,25 +3377,31 @@ Projection: a, b .project(vec![col("b.a")])? .build()?; - let expected_before = "Projection: b.a\ - \n Filter: b.a = Int64(1)\ - \n SubqueryAlias: b\ - \n Projection: b.a\ - \n SubqueryAlias: b\ - \n Projection: Int64(0) AS a\ - \n EmptyRelation"; - assert_eq!(format!("{plan}"), expected_before); - + assert_snapshot!(plan, + @r" + Projection: b.a + Filter: b.a = Int64(1) + SubqueryAlias: b + Projection: b.a + SubqueryAlias: b + Projection: Int64(0) AS a + EmptyRelation + ", + ); // Ensure that the predicate without any columns (0 = 1) is // still there. - let expected_after = "Projection: b.a\ - \n SubqueryAlias: b\ - \n Projection: b.a\ - \n SubqueryAlias: b\ - \n Projection: Int64(0) AS a\ - \n Filter: Int64(0) = Int64(1)\ - \n EmptyRelation"; - assert_optimized_plan_eq(plan, expected_after) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b.a + SubqueryAlias: b + Projection: b.a + SubqueryAlias: b + Projection: Int64(0) AS a + Filter: Int64(0) = Int64(1) + EmptyRelation + " + ) } #[test] @@ -3245,13 +3423,14 @@ Projection: a, b .cross_join(right)? .filter(filter)? .build()?; - let expected = "\ - Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ - \n Projection: test1.a AS d, test1.a AS e\ - \n TableScan: test1"; - assert_optimized_plan_eq_with_rewrite_predicate(plan.clone(), expected)?; + + assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r" + Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10) + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)] + Projection: test1.a AS d, test1.a AS e + TableScan: test1 + ")?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. // Now the global state is removed. Need to double confirm that avoid duplicate Filters. @@ -3259,7 +3438,16 @@ Projection: a, b .rewrite(plan, &OptimizerContext::new()) .expect("failed to optimize plan") .data; - assert_optimized_plan_eq(optimized_plan, expected) + assert_optimized_plan_equal!( + optimized_plan, + @r" + Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10) + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)] + Projection: test1.a AS d, test1.a AS e + TableScan: test1 + " + ) } #[test] @@ -3283,23 +3471,26 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a <= Int64(1)\ - \n LeftSemi Join: test1.a = test2.a\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test2.a <= Int64(1) + LeftSemi Join: test1.a = test2.a + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side. 
- let expected = "\ - Filter: test2.a <= Int64(1)\ - \n LeftSemi Join: test1.a = test2.a\ - \n TableScan: test1, full_filters=[test1.a <= Int64(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test2.a <= Int64(1) + LeftSemi Join: test1.a = test2.a + TableScan: test1, full_filters=[test1.a <= Int64(1)] + Projection: test2.a, test2.b + TableScan: test2 + " + ) } #[test] @@ -3326,21 +3517,24 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Both side will be pushed down. - let expected = "\ - LeftSemi Join: test1.a = test2.a\ - \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + LeftSemi Join: test1.a = test2.a + TableScan: test1, full_filters=[test1.b > UInt32(1)] + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.b > UInt32(2)] + " + ) } #[test] @@ -3364,23 +3558,26 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test1.a <= Int64(1)\ - \n RightSemi Join: test1.a = test2.a\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + Filter: test1.a <= Int64(1) + RightSemi Join: test1.a = test2.a + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side. - let expected = "\ - Filter: test1.a <= Int64(1)\ - \n RightSemi Join: test1.a = test2.a\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test1.a <= Int64(1) + RightSemi Join: test1.a = test2.a + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } #[test] @@ -3407,21 +3604,24 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Both side will be pushed down. 
- let expected = "\ - RightSemi Join: test1.a = test2.a\ - \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + RightSemi Join: test1.a = test2.a + TableScan: test1, full_filters=[test1.b > UInt32(1)] + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.b > UInt32(2)] + " + ) } #[test] @@ -3448,25 +3648,28 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a > UInt32(2)\ - \n LeftAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + Filter: test2.a > UInt32(2) + LeftAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For left anti, filter of the right side filter can be pushed down. - let expected = "\ - Filter: test2.a > UInt32(2)\ - \n LeftAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1, full_filters=[test1.a > UInt32(2)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test2.a > UInt32(2) + LeftAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1, full_filters=[test1.a > UInt32(2)] + Projection: test2.a, test2.b + TableScan: test2 + " + ) } #[test] @@ -3496,23 +3699,26 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For left anti, filter of the right side filter can be pushed down. - let expected = "\ - LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.b > UInt32(2)] + " + ) } #[test] @@ -3539,25 +3745,28 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test1.a > UInt32(2)\ - \n RightAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + Filter: test1.a > UInt32(2) + RightAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For right anti, filter of the left side can be pushed down. 
- let expected = "\ - Filter: test1.a > UInt32(2)\ - \n RightAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.a > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test1.a > UInt32(2) + RightAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.a > UInt32(2)] + " + ) } #[test] @@ -3587,22 +3796,26 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For right anti, filter of the left side can be pushed down. - let expected = "RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2) + Projection: test1.a, test1.b + TableScan: test1, full_filters=[test1.b > UInt32(1)] + Projection: test2.a, test2.b + TableScan: test2 + " + ) } #[derive(Debug)] @@ -3648,21 +3861,27 @@ Projection: a, b .project(vec![col("t.a"), col("t.r")])? .build()?; - let expected_before = "Projection: t.a, t.r\ - \n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\ - \n SubqueryAlias: t\ - \n Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r\ - \n Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]\ - \n TableScan: test1"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: t.a, t.r\ - \n SubqueryAlias: t\ - \n Filter: r > Float64(0.5)\ - \n Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r\ - \n Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]\ - \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Projection: t.a, t.r + Filter: t.a > Int32(5) AND t.r > Float64(0.5) + SubqueryAlias: t + Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r + Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]] + TableScan: test1 + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: t.a, t.r + SubqueryAlias: t + Filter: r > Float64(0.5) + Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r + Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]] + TableScan: test1, full_filters=[test1.a > Int32(5)] + " + ) } #[test] @@ -3692,23 +3911,29 @@ Projection: a, b .project(vec![col("t.a"), col("t.r")])? 
.build()?; - let expected_before = "Projection: t.a, t.r\ - \n Filter: t.r > Float64(0.8)\ - \n SubqueryAlias: t\ - \n Projection: test1.a AS a, TestScalarUDF() AS r\ - \n Inner Join: test1.a = test2.a\ - \n TableScan: test1\ - \n TableScan: test2"; - assert_eq!(format!("{plan}"), expected_before); - - let expected = "Projection: t.a, t.r\ - \n SubqueryAlias: t\ - \n Filter: r > Float64(0.8)\ - \n Projection: test1.a AS a, TestScalarUDF() AS r\ - \n Inner Join: test1.a = test2.a\ - \n TableScan: test1\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_snapshot!(plan, + @r" + Projection: t.a, t.r + Filter: t.r > Float64(0.8) + SubqueryAlias: t + Projection: test1.a AS a, TestScalarUDF() AS r + Inner Join: test1.a = test2.a + TableScan: test1 + TableScan: test2 + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: t.a, t.r + SubqueryAlias: t + Filter: r > Float64(0.8) + Projection: test1.a AS a, TestScalarUDF() AS r + Inner Join: test1.a = test2.a + TableScan: test1 + TableScan: test2 + " + ) } #[test] @@ -3724,15 +3949,21 @@ Projection: a, b .filter(expr.gt(lit(0.1)))? .build()?; - let expected_before = "Filter: TestScalarUDF() > Float64(0.1)\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: test.a, test.b\ - \n Filter: TestScalarUDF() > Float64(0.1)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Filter: TestScalarUDF() > Float64(0.1) + Projection: test.a, test.b + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + Filter: TestScalarUDF() > Float64(0.1) + TableScan: test + " + ) } #[test] @@ -3752,15 +3983,21 @@ Projection: a, b )? .build()?; - let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: test.a, test.b\ - \n Filter: TestScalarUDF() > Float64(0.1)\ - \n TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10) + Projection: test.a, test.b + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + Filter: TestScalarUDF() > Float64(0.1) + TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)] + " + ) } #[test] @@ -3783,15 +4020,21 @@ Projection: a, b )? 
.build()?; - let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\ - \n Projection: a, b\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: a, b\ - \n Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10) + Projection: a, b + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, b + Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1) + TableScan: test + " + ) } #[test] @@ -3864,12 +4107,19 @@ Projection: a, b let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?; // Check the original plan format (not part of the test assertions) - let expected_before = "Filter: Boolean(false)\ - \n TestUserNode"; - assert_eq!(format!("{plan}"), expected_before); - + assert_snapshot!(plan, + @r" + Filter: Boolean(false) + TestUserNode + ", + ); // Check that the filter is pushed down to the user-defined node - let expected_after = "Filter: Boolean(false)\n TestUserNode"; - assert_optimized_plan_eq(plan, expected_after) + assert_optimized_plan_equal!( + plan, + @r" + Filter: Boolean(false) + TestUserNode + " + ) } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 1e9ef16bde67..ec042dd350ca 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -276,8 +276,10 @@ mod test { use std::vec; use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use datafusion_common::DFSchemaRef; use datafusion_expr::{ col, exists, logical_plan::builder::LogicalPlanBuilder, Expr, Extension, @@ -285,8 +287,20 @@ mod test { }; use datafusion_functions_aggregate::expr_fn::max; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(PushDownLimit::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[derive(Debug, PartialEq, Eq, Hash)] @@ -408,12 +422,15 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -430,12 +447,15 @@ mod test { .limit(10, Some(1000))? .build()?; - let expected = "Limit: skip=10, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -453,12 +473,15 @@ mod test { .limit(20, Some(500))? 
.build()?; - let expected = "Limit: skip=30, fetch=500\ - \n NoopPlan\ - \n Limit: skip=0, fetch=530\ - \n TableScan: test, fetch=530"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=30, fetch=500 + NoopPlan + Limit: skip=0, fetch=530 + TableScan: test, fetch=530 + " + ) } #[test] @@ -475,14 +498,17 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -499,11 +525,14 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoLimitNoopPlan\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoLimitNoopPlan + TableScan: test + " + ) } #[test] @@ -517,11 +546,14 @@ mod test { // Should push the limit down to table provider // When it has a select - let expected = "Projection: test.a\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -536,10 +568,13 @@ mod test { // Should push down the smallest limit // Towards table scan // This rule doesn't replace multiple limits - let expected = "Limit: skip=0, fetch=10\ - \n TableScan: test, fetch=10"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + TableScan: test, fetch=10 + " + ) } #[test] @@ -552,11 +587,14 @@ mod test { .build()?; // Limit should *not* push down aggregate node - let expected = "Limit: skip=0, fetch=1000\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + TableScan: test + " + ) } #[test] @@ -569,14 +607,17 @@ mod test { .build()?; // Limit should push down through union - let expected = "Limit: skip=0, fetch=1000\ - \n Union\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Union + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -589,11 +630,14 @@ mod test { .build()?; // Should push down limit to sort - let expected = "Limit: skip=0, fetch=10\ - \n Sort: test.a ASC NULLS LAST, fetch=10\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + Sort: test.a ASC NULLS LAST, fetch=10 + TableScan: test + " + ) } #[test] @@ -606,11 +650,14 @@ mod test { .build()?; // Should push down limit to sort - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS LAST, fetch=15\ - \n TableScan: test"; - - 
assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS LAST, fetch=15 + TableScan: test + " + ) } #[test] @@ -624,12 +671,15 @@ mod test { .build()?; // Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation - let expected = "Limit: skip=0, fetch=10\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -641,10 +691,13 @@ mod test { // Should not push any limit down to table provider // When it has a select - let expected = "Limit: skip=10, fetch=None\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=None + TableScan: test + " + ) } #[test] @@ -658,11 +711,14 @@ mod test { // Should push the limit down to table provider // When it has a select - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=1000\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=10, fetch=1000 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -675,11 +731,14 @@ mod test { .limit(10, None)? .build()?; - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=990\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=10, fetch=990 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -692,11 +751,14 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=1000\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=10, fetch=1000 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -709,10 +771,13 @@ mod test { .limit(0, Some(10))? 
.build()?; - let expected = "Limit: skip=10, fetch=10\ - \n TableScan: test, fetch=20"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=10 + TableScan: test, fetch=20 + " + ) } #[test] @@ -725,11 +790,14 @@ mod test { .build()?; // Limit should *not* push down aggregate node - let expected = "Limit: skip=10, fetch=1000\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + TableScan: test + " + ) } #[test] @@ -742,14 +810,17 @@ mod test { .build()?; // Limit should push down through union - let expected = "Limit: skip=10, fetch=1000\ - \n Union\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Union + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -768,12 +839,15 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Inner Join: test.a = test2.a\ - \n TableScan: test\ - \n TableScan: test2"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Inner Join: test.a = test2.a + TableScan: test + TableScan: test2 + " + ) } #[test] @@ -792,12 +866,15 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Inner Join: test.a = test2.a\ - \n TableScan: test\ - \n TableScan: test2"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Inner Join: test.a = test2.a + TableScan: test + TableScan: test2 + " + ) } #[test] @@ -817,16 +894,19 @@ mod test { .build()?; // Limit pushdown Not supported in sub_query - let expected = "Limit: skip=10, fetch=100\ - \n Filter: EXISTS ()\ - \n Subquery:\ - \n Filter: test1.a = test1.a\ - \n Projection: test1.a\ - \n TableScan: test1\ - \n Projection: test2.a\ - \n TableScan: test2"; - - assert_optimized_plan_equal(outer_query, expected) + assert_optimized_plan_equal!( + outer_query, + @r" + Limit: skip=10, fetch=100 + Filter: EXISTS () + Subquery: + Filter: test1.a = test1.a + Projection: test1.a + TableScan: test1 + Projection: test2.a + TableScan: test2 + " + ) } #[test] @@ -846,16 +926,19 @@ mod test { .build()?; // Limit pushdown Not supported in sub_query - let expected = "Limit: skip=10, fetch=100\ - \n Filter: EXISTS ()\ - \n Subquery:\ - \n Filter: test1.a = test1.a\ - \n Projection: test1.a\ - \n TableScan: test1\ - \n Projection: test2.a\ - \n TableScan: test2"; - - assert_optimized_plan_equal(outer_query, expected) + assert_optimized_plan_equal!( + outer_query, + @r" + Limit: skip=10, fetch=100 + Filter: EXISTS () + Subquery: + Filter: test1.a = test1.a + Projection: test1.a + TableScan: test1 + Projection: test2.a + TableScan: test2 + " + ) } #[test] @@ -874,13 +957,16 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Left Join: test.a = test2.a\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010\ - \n TableScan: test2"; - - 
assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Left Join: test.a = test2.a + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + TableScan: test2 + " + ) } #[test] @@ -899,13 +985,16 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=0, fetch=1000\ - \n Right Join: test.a = test2.a\ - \n TableScan: test\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test2, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Right Join: test.a = test2.a + TableScan: test + Limit: skip=0, fetch=1000 + TableScan: test2, fetch=1000 + " + ) } #[test] @@ -924,13 +1013,16 @@ mod test { .build()?; // Limit pushdown with offset supported in right outer join - let expected = "Limit: skip=10, fetch=1000\ - \n Right Join: test.a = test2.a\ - \n TableScan: test\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test2, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Right Join: test.a = test2.a + TableScan: test + Limit: skip=0, fetch=1010 + TableScan: test2, fetch=1010 + " + ) } #[test] @@ -943,14 +1035,17 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n Cross Join: \ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test2, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Cross Join: + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test2, fetch=1000 + " + ) } #[test] @@ -963,14 +1058,17 @@ mod test { .limit(1000, Some(1000))? .build()?; - let expected = "Limit: skip=1000, fetch=1000\ - \n Cross Join: \ - \n Limit: skip=0, fetch=2000\ - \n TableScan: test, fetch=2000\ - \n Limit: skip=0, fetch=2000\ - \n TableScan: test2, fetch=2000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=1000 + Cross Join: + Limit: skip=0, fetch=2000 + TableScan: test, fetch=2000 + Limit: skip=0, fetch=2000 + TableScan: test2, fetch=2000 + " + ) } #[test] @@ -982,10 +1080,13 @@ mod test { .limit(1000, None)? .build()?; - let expected = "Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } #[test] @@ -997,10 +1098,13 @@ mod test { .limit(1000, None)? .build()?; - let expected = "Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } #[test] @@ -1013,10 +1117,13 @@ mod test { .limit(1000, None)? 
.build()?; - let expected = "SubqueryAlias: a\ - \n Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + SubqueryAlias: a + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 48b2828faf45..2383787fa0e8 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -186,21 +186,29 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { mod tests { use std::sync::Arc; + use crate::assert_optimized_plan_eq_snapshot; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::test::*; + use crate::OptimizerContext; use datafusion_common::Result; - use datafusion_expr::{ - col, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan, - }; + use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder, Expr}; use datafusion_functions_aggregate::sum::sum; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq( - Arc::new(ReplaceDistinctWithAggregate::new()), - plan.clone(), - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(ReplaceDistinctWithAggregate::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -212,8 +220,11 @@ mod tests { .distinct()? .build()?; - let expected = "Projection: test.c\n Aggregate: groupBy=[[test.c]], aggr=[[]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Projection: test.c + Aggregate: groupBy=[[test.c]], aggr=[[]] + TableScan: test + ") } #[test] @@ -225,9 +236,11 @@ mod tests { .distinct()? .build()?; - let expected = - "Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Projection: test.a, test.b + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + TableScan: test + ") } #[test] @@ -238,8 +251,11 @@ mod tests { .distinct()? .build()?; - let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + Projection: test.a, test.b + TableScan: test + ") } #[test] @@ -251,8 +267,11 @@ mod tests { .distinct()? 
.build()?; - let expected = - "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + Projection: test.a, test.b + Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]] + TableScan: test + ") } } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 5c89bc29a596..2f9a2f6bb9ed 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -335,7 +335,7 @@ fn build_join( .join_on( sub_query_alias, JoinType::Left, - vec![Expr::Literal(ScalarValue::Boolean(Some(true)))], + vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)], )? .build()? } @@ -365,7 +365,7 @@ fn build_join( ), ( Box::new(Expr::Not(Box::new(filter.clone()))), - Box::new(Expr::Literal(ScalarValue::Null)), + Box::new(Expr::Literal(ScalarValue::Null, None)), ), ], else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( @@ -407,9 +407,24 @@ mod tests { use arrow::datatypes::DataType; use datafusion_expr::test::function_stub::sum; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between}; use datafusion_functions_aggregate::min_max::{max, min}; + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(ScalarSubqueryToJoin::new()); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; + } + /// Test multiple correlated subqueries #[test] fn multiple_subqueries() -> Result<()> { @@ -433,25 +448,24 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + 
Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test recursive correlated subqueries @@ -488,26 +502,25 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N]\ - \n Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]\ - \n Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean] + Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true 
[sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N] + Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean] + Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64] + " + ) } /// Test for correlated scalar subquery filter with additional subquery filters @@ -530,22 +543,20 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, 
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery with no columns in schema @@ -568,20 +579,19 @@ mod tests { .build()?; // it will optimize, but fail for the same reason the unoptimized query would - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for scalar subquery with both columns in schema @@ -600,22 +610,20 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery not equal @@ -638,21 +646,19 @@ mod tests { .build()?; // Unsupported predicate, subquery should not be decorrelated - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ - \n Subquery: [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + " + ) } /// Test for correlated scalar subquery less than @@ -675,21 +681,19 @@ mod tests { .build()?; // Unsupported predicate, subquery should not be 
decorrelated - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ - \n Subquery: [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + " + ) } /// Test for correlated scalar subquery filter with subquery disjunction @@ -713,21 +717,19 @@ mod tests { .build()?; // Unsupported predicate, subquery should not be decorrelated - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ - \n Subquery: [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + " + ) } /// Test for correlated scalar without projection @@ -768,21 +770,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery with non-strong project @@ -812,20 +812,18 @@ mod tests { .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])? 
.build()?; - let expected = "Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8(\"a\") ELSE Utf8(\"b\") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r#" + Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean] + Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + "# + ) } /// Test for correlated scalar subquery multiple projected columns @@ -875,21 +873,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -914,21 +910,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery filter with disjunctions @@ -954,21 +948,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery filter @@ -987,21 +979,19 @@ mod tests { .project(vec![col("test.c")])? 
.build()?; - let expected = "Projection: test.c [c:UInt32]\ - \n Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]\ - \n Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]\ - \n Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]\ - \n Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.c [c:UInt32] + Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N] + Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean] + Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean] + Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for non-correlated scalar subquery with no filters @@ -1019,21 +1009,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -1050,21 +1038,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -1102,26 +1088,24 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, 
o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -1151,25 +1135,23 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]\ - \n Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N] + Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5c13ddb17639..e91aea3305be 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -18,7 +18,7 @@ //! 
Expression simplification API use std::borrow::Cow; -use std::collections::HashSet; +use std::collections::{BTreeMap, HashSet}; use std::ops::Not; use arrow::{ @@ -34,11 +34,11 @@ use datafusion_common::{ use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, - Operator, Volatility, WindowFunctionDefinition, + Operator, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ - expr::{InList, InSubquery, WindowFunction}, + expr::{InList, InSubquery}, utils::{iter_conjunction, iter_conjunction_owned}, }; use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast}; @@ -188,7 +188,7 @@ impl ExprSimplifier { /// assert_eq!(expr, b_lt_2); /// ``` pub fn simplify(&self, expr: Expr) -> Result { - Ok(self.simplify_with_cycle_count(expr)?.0) + Ok(self.simplify_with_cycle_count_transformed(expr)?.0.data) } /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating @@ -198,7 +198,34 @@ impl ExprSimplifier { /// /// See [Self::simplify] for details and usage examples. /// + #[deprecated( + since = "48.0.0", + note = "Use `simplify_with_cycle_count_transformed` instead" + )] + #[allow(unused_mut)] pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> { + let (transformed, cycle_count) = + self.simplify_with_cycle_count_transformed(expr)?; + Ok((transformed.data, cycle_count)) + } + + /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating + /// constants and applying algebraic simplifications. Additionally returns a `u32` + /// representing the number of simplification cycles performed, which can be useful for testing + /// optimizations. + /// + /// # Returns + /// + /// A tuple containing: + /// - The simplified expression wrapped in a `Transformed` indicating if changes were made + /// - The number of simplification cycles that were performed + /// + /// See [Self::simplify] for details and usage examples. + /// + pub fn simplify_with_cycle_count_transformed( + &self, + mut expr: Expr, + ) -> Result<(Transformed, u32)> { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); @@ -212,6 +239,7 @@ impl ExprSimplifier { // simplifications can enable new constant evaluation // see `Self::with_max_cycles` let mut num_cycles = 0; + let mut has_transformed = false; loop { let Transformed { data, transformed, .. @@ -221,13 +249,18 @@ impl ExprSimplifier { .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?; expr = data; num_cycles += 1; + // Track if any transformation occurred + has_transformed = has_transformed || transformed; if !transformed || num_cycles >= self.max_simplifier_cycles { break; } } // shorten inlist should be started after other inlist rules are applied expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?; - Ok((expr, num_cycles)) + Ok(( + Transformed::new_transformed(expr, has_transformed), + num_cycles, + )) } /// Apply type coercion to an [`Expr`] so that it can be @@ -392,15 +425,15 @@ impl ExprSimplifier { /// let expr = col("a").is_not_null(); /// /// // When using default maximum cycles, 2 cycles will be performed. 
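For orientation on the API change above: `simplify_with_cycle_count` is deprecated since 48.0.0 in favor of `simplify_with_cycle_count_transformed`, which wraps the rewritten expression in `Transformed` so callers can tell whether any rule actually fired. A minimal usage sketch, assuming a `simplifier` built with `ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema))` as in the surrounding doc examples, and a schema in which column `a` is declared non-nullable:

    // Sketch only; `simplifier`, `col` and `lit` are set up as in the existing
    // ExprSimplifier doc examples in this file.
    let (result, cycles) = simplifier
        .simplify_with_cycle_count_transformed(col("a").is_not_null())
        .unwrap();
    assert_eq!(result.data, lit(true)); // the simplified expression
    assert!(result.transformed);        // at least one rewrite happened
    assert!(cycles >= 1);               // number of simplification passes run

The old `(Expr, u32)` shape stays available through the deprecated wrapper, which simply returns `transformed.data` alongside the cycle count.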
- /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap(); - /// assert_eq!(simplified_expr, lit(true)); + /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count_transformed(expr.clone()).unwrap(); + /// assert_eq!(simplified_expr.data, lit(true)); /// // 2 cycles were executed, but only 1 was needed /// assert_eq!(count, 2); /// /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1. - /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap(); + /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count_transformed(expr.clone()).unwrap(); /// // Expression has been rewritten to: (c = a AND b = 1) - /// assert_eq!(simplified_expr, lit(true)); + /// assert_eq!(simplified_expr.data, lit(true)); /// // Only 1 cycle was executed /// assert_eq!(count, 1); /// @@ -444,7 +477,7 @@ impl TreeNodeRewriter for Canonicalizer { }))) } // - (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => { + (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => { Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, @@ -487,11 +520,12 @@ struct ConstEvaluator<'a> { #[allow(dead_code)] /// The simplify result of ConstEvaluator +#[allow(clippy::large_enum_variant)] enum ConstSimplifyResult { // Expr was simplified and contains the new expression - Simplified(ScalarValue), + Simplified(ScalarValue, Option>), // Expr was not simplified and original value is returned - NotSimplified(ScalarValue), + NotSimplified(ScalarValue, Option>), // Evaluation encountered an error, contains the original expression SimplifyRuntimeError(DataFusionError, Expr), } @@ -533,11 +567,11 @@ impl TreeNodeRewriter for ConstEvaluator<'_> { // any error is countered during simplification, return the original // so that normal evaluation can occur Some(true) => match self.evaluate_to_scalar(expr) { - ConstSimplifyResult::Simplified(s) => { - Ok(Transformed::yes(Expr::Literal(s))) + ConstSimplifyResult::Simplified(s, m) => { + Ok(Transformed::yes(Expr::Literal(s, m))) } - ConstSimplifyResult::NotSimplified(s) => { - Ok(Transformed::no(Expr::Literal(s))) + ConstSimplifyResult::NotSimplified(s, m) => { + Ok(Transformed::no(Expr::Literal(s, m))) } ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { Ok(Transformed::yes(expr)) @@ -606,7 +640,7 @@ impl<'a> ConstEvaluator<'a> { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } - Expr::Literal(_) + Expr::Literal(_, _) | Expr::Alias(..) | Expr::Unnest(_) | Expr::BinaryExpr { .. 
} @@ -632,8 +666,8 @@ impl<'a> ConstEvaluator<'a> { /// Internal helper to evaluates an Expr pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult { - if let Expr::Literal(s) = expr { - return ConstSimplifyResult::NotSimplified(s); + if let Expr::Literal(s, m) = expr { + return ConstSimplifyResult::NotSimplified(s, m); } let phys_expr = @@ -641,6 +675,18 @@ impl<'a> ConstEvaluator<'a> { Ok(e) => e, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; + let metadata = phys_expr + .return_field(self.input_batch.schema_ref()) + .ok() + .and_then(|f| { + let m = f.metadata(); + match m.is_empty() { + true => None, + false => { + Some(m.iter().map(|(k, v)| (k.clone(), v.clone())).collect()) + } + } + }); let col_val = match phys_expr.evaluate(&self.input_batch) { Ok(v) => v, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), @@ -653,13 +699,15 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else if as_list_array(&a).is_ok() { - ConstSimplifyResult::Simplified(ScalarValue::List( - a.as_list::().to_owned().into(), - )) + ConstSimplifyResult::Simplified( + ScalarValue::List(a.as_list::().to_owned().into()), + metadata, + ) } else if as_large_list_array(&a).is_ok() { - ConstSimplifyResult::Simplified(ScalarValue::LargeList( - a.as_list::().to_owned().into(), - )) + ConstSimplifyResult::Simplified( + ScalarValue::LargeList(a.as_list::().to_owned().into()), + metadata, + ) } else { // Non-ListArray match ScalarValue::try_from_array(&a, 0) { @@ -671,7 +719,7 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else { - ConstSimplifyResult::Simplified(s) + ConstSimplifyResult::Simplified(s, metadata) } } Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr), @@ -689,7 +737,7 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else { - ConstSimplifyResult::Simplified(s) + ConstSimplifyResult::Simplified(s, metadata) } } } @@ -1104,9 +1152,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // @@ -1147,9 +1196,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseAnd, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // A & !A -> 0 (if A not nullable) @@ -1158,9 +1208,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseAnd, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) & A --> (..A..) @@ -1233,9 +1284,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseOr, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? 
=> { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // A | !A -> -1 (if A not nullable) @@ -1244,9 +1296,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseOr, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) | A --> (..A..) @@ -1319,9 +1372,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseXor, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // A ^ !A -> -1 (if A not nullable) @@ -1330,9 +1384,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseXor, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) @@ -1343,7 +1398,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { }) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); Transformed::yes(if expr == *right { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) + Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&right)?)?, + None, + ) } else { expr }) @@ -1357,7 +1415,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { }) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); Transformed::yes(if expr == *left { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + ) } else { expr }) @@ -1489,12 +1550,9 @@ impl TreeNodeRewriter for Simplifier<'_, S> { (_, expr) => Transformed::no(expr), }, - Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::WindowUDF(ref udwf), - .. - }) => match (udwf.simplify(), expr) { + Expr::WindowFunction(ref window_fun) => match (window_fun.simplify(), expr) { (Some(simplify_function), Expr::WindowFunction(wf)) => { - Transformed::yes(simplify_function(wf, info)?) + Transformed::yes(simplify_function(*wf, info)?) } (_, expr) => Transformed::no(expr), }, @@ -1611,7 +1669,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { expr, list, negated, - }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { + }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null, None) => { Transformed::yes(lit(negated)) } @@ -1794,7 +1852,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { info, &left, op, &right, ) && op.supports_propagation() => { - unwrap_cast_in_comparison_for_binary(info, left, right, op)? + unwrap_cast_in_comparison_for_binary(info, *left, *right, op)? 
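A note on the rewrite pattern that runs through this whole file: `Expr::Literal` now takes a second, optional metadata argument next to the `ScalarValue`, so constructions gain a trailing `None` and patterns gain an extra `_`. A small illustrative sketch (the variable names are invented for the example):

    // Building a literal with no attached field metadata, as the rules above now do.
    let zero = Expr::Literal(ScalarValue::Int64(Some(0)), None);

    // Matching must acknowledge the extra field even when it is ignored.
    if let Expr::Literal(value, _metadata) = &zero {
        assert_eq!(value, &ScalarValue::Int64(Some(0)));
    }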
} // literal op try_cast/cast(expr as data_type) // --> @@ -1807,8 +1865,8 @@ impl TreeNodeRewriter for Simplifier<'_, S> { { unwrap_cast_in_comparison_for_binary( info, - right, - left, + *right, + *left, op.swap().unwrap(), )? } @@ -1837,7 +1895,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { .into_iter() .map(|right| { match right { - Expr::Literal(right_lit_value) => { + Expr::Literal(right_lit_value, _) => { // if the right_lit_value can be casted to the type of internal_left_expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else { @@ -1871,18 +1929,18 @@ impl TreeNodeRewriter for Simplifier<'_, S> { fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { match expr { - Expr::Literal(ScalarValue::Utf8(s)) => Some((DataType::Utf8, s)), - Expr::Literal(ScalarValue::LargeUtf8(s)) => Some((DataType::LargeUtf8, s)), - Expr::Literal(ScalarValue::Utf8View(s)) => Some((DataType::Utf8View, s)), + Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)), + Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)), + Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)), _ => None, } } fn to_string_scalar(data_type: DataType, value: Option) -> Expr { match data_type { - DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value)), - DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value)), - DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value)), + DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None), + DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None), + DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None), _ => unreachable!(), } } @@ -1928,12 +1986,12 @@ fn as_inlist(expr: &Expr) -> Option> { Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(Cow::Owned(InList { + (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList { expr: left.clone(), list: vec![*right.clone()], negated: false, })), - (Expr::Literal(_), Expr::Column(_)) => Some(Cow::Owned(InList { + (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList { expr: right.clone(), list: vec![*left.clone()], negated: false, @@ -1953,12 +2011,12 @@ fn to_inlist(expr: Expr) -> Option { op: Operator::Eq, right, }) => match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(InList { + (Expr::Column(_), Expr::Literal(_, _)) => Some(InList { expr: left, list: vec![*right], negated: false, }), - (Expr::Literal(_), Expr::Column(_)) => Some(InList { + (Expr::Literal(_, _), Expr::Column(_)) => Some(InList { expr: right, list: vec![*left], negated: false, @@ -2110,10 +2168,13 @@ fn simplify_null_div_other_case( #[cfg(test)] mod tests { + use super::*; use crate::simplify_expressions::SimplifyContext; use crate::test::test_table_scan_with_name; + use arrow::datatypes::FieldRef; use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ + expr::WindowFunction, function::{ AccumulatorArgs, AggregateFunctionSimplification, WindowFunctionSimplification, @@ -2129,8 +2190,6 @@ mod tests { sync::Arc, }; - use super::*; - // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ @@ -2376,7 +2435,7 @@ mod tests { #[test] fn 
test_simplify_multiply_by_null() { - let null = Expr::Literal(ScalarValue::Null); + let null = Expr::Literal(ScalarValue::Null, None); // A * null --> null { let expr = col("c2") * null.clone(); @@ -3311,6 +3370,15 @@ mod tests { simplifier.simplify(expr) } + fn coerce(expr: Expr) -> Expr { + let schema = expr_test_schema(); + let execution_props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)), + ); + simplifier.coerce(expr, schema.as_ref()).unwrap() + } + fn simplify(expr: Expr) -> Expr { try_simplify(expr).unwrap() } @@ -3321,7 +3389,8 @@ mod tests { let simplifier = ExprSimplifier::new( SimplifyContext::new(&execution_props).with_schema(schema), ); - simplifier.simplify_with_cycle_count(expr) + let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?; + Ok((expr.data, count)) } fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) { @@ -3353,6 +3422,7 @@ mod tests { Field::new("c2_non_null", DataType::Boolean, false), Field::new("c3_non_null", DataType::Int64, false), Field::new("c4_non_null", DataType::UInt32, false), + Field::new("c5", DataType::FixedSizeBinary(3), true), ] .into(), HashMap::new(), @@ -4344,8 +4414,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -4353,8 +4422,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); @@ -4406,7 +4474,7 @@ mod tests { unimplemented!("not needed for tests") } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { unimplemented!("not needed for tests") } } @@ -4481,6 +4549,34 @@ mod tests { } } + #[test] + fn simplify_fixed_size_binary_eq_lit() { + let bytes = [1u8, 2, 3].as_slice(); + + // The expression starts simple. + let expr = col("c5").eq(lit(bytes)); + + // The type coercer introduces a cast. + let coerced = coerce(expr.clone()); + let schema = expr_test_schema(); + assert_eq!( + coerced, + col("c5") + .cast_to(&DataType::Binary, schema.as_ref()) + .unwrap() + .eq(lit(bytes)) + ); + + // The simplifier removes the cast. 
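+        // That is, `c5` keeps its `FixedSizeBinary(3)` type and the comparison literal is
+        // folded to `ScalarValue::FixedSizeBinary`, so no cast remains on the column side.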
+ assert_eq!( + simplify(coerced), + col("c5").eq(Expr::Literal( + ScalarValue::FixedSizeBinary(3, Some(bytes.to_vec()),), + None + )) + ); + } + fn if_not_null(expr: Expr, then: bool) -> Expr { Expr::Case(Case { expr: Some(expr.is_not_null().into()), diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 4700ab97b5f3..bbb023cfbad9 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -84,7 +84,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { low, high, }) => { - if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = ( + if let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( self.guarantees.get(inner.as_ref()), low.as_ref(), high.as_ref(), @@ -115,7 +115,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { .get(left.as_ref()) .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { - if let Expr::Literal(value) = left.as_ref() { + if let Expr::Literal(value, _) = left.as_ref() { Some(Cow::Owned(value.clone().into())) } else { None @@ -126,7 +126,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { .get(right.as_ref()) .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { - if let Expr::Literal(value) = right.as_ref() { + if let Expr::Literal(value, _) = right.as_ref() { Some(Cow::Owned(value.clone().into())) } else { None @@ -168,7 +168,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { let new_list: Vec = list .iter() .filter_map(|expr| { - if let Expr::Literal(item) = expr { + if let Expr::Literal(item, _) = expr { match interval .contains(NullableInterval::from(item.clone())) { @@ -244,8 +244,7 @@ mod tests { let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, - "{} simplified to {}, but expected {}", - expr, output, expected + "{expr} simplified to {output}, but expected {expected}" ); } } @@ -255,8 +254,7 @@ mod tests { let output = expr.clone().rewrite(rewriter).data().unwrap(); assert_eq!( &output, expr, - "{} was simplified to {}, but expected it to be unchanged", - expr, output + "{expr} was simplified to {output}, but expected it to be unchanged" ); } } @@ -417,7 +415,7 @@ mod tests { let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); let output = col("x").rewrite(&mut rewriter).data().unwrap(); - assert_eq!(output, Expr::Literal(scalar.clone())); + assert_eq!(output, Expr::Literal(scalar.clone(), None)); } } diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index 0b47cdee212f..82c5ea3d8d82 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -46,7 +46,7 @@ pub fn simplify_regex_expr( ) -> Result { let mode = OperatorMode::new(&op); - if let Expr::Literal(ScalarValue::Utf8(Some(pattern))) = right.as_ref() { + if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = right.as_ref() { // Handle the special case for ".*" pattern if pattern == ANY_CHAR_REGEX_PATTERN { let new_expr = if mode.not { @@ -121,7 +121,7 @@ impl OperatorMode { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::from(pattern))), + pattern: Box::new(Expr::Literal(ScalarValue::from(pattern), None)), escape_char: None, case_insensitive: self.i, }; @@ -255,9 +255,9 @@ fn partial_anchored_literal_to_like(v: &[Hir]) -> Option { }; if match_begin { - 
Some(format!("{}%", lit)) + Some(format!("{lit}%")) } else { - Some(format!("%{}", lit)) + Some(format!("%{lit}")) } } @@ -331,7 +331,7 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { } HirKind::Concat(inner) => { if let Some(pattern) = partial_anchored_literal_to_like(inner) - .or(collect_concat_to_like_string(inner)) + .or_else(|| collect_concat_to_like_string(inner)) { return Some(mode.expr(Box::new(left.clone()), pattern)); } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index e33869ca2b63..ccf90893e17e 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -123,10 +123,11 @@ impl SimplifyExpressions { let name_preserver = NamePreserver::new(&plan); let mut rewrite_expr = |expr: Expr| { let name = name_preserver.save(&expr); - let expr = simplifier.simplify(expr)?; - // TODO it would be nice to have a way to know if the expression was simplified - // or not. For now conservatively return Transformed::yes - Ok(Transformed::yes(name.restore(expr))) + let expr = simplifier.simplify_with_cycle_count_transformed(expr)?.0; + Ok(Transformed::new_transformed( + name.restore(expr.data), + expr.transformed, + )) }; plan.map_expressions(|expr| { @@ -154,12 +155,12 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use chrono::{DateTime, Utc}; - use crate::optimizer::Optimizer; use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::*; use datafusion_functions_aggregate::expr_fn::{max, min}; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::{assert_fields_eq, test_table_scan_with_name}; use crate::OptimizerContext; @@ -179,15 +180,20 @@ mod tests { .expect("building plan") } - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - // Use Optimizer to do plan traversal - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]); - let optimized_plan = - optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(SimplifyExpressions::new())]; + let optimizer_ctx = OptimizerContext::new(); + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -210,9 +216,10 @@ mod tests { assert_eq!(1, table_scan.schema().fields().len()); assert_fields_eq(&table_scan, vec!["a"]); - let expected = "TableScan: test projection=[a], full_filters=[Boolean(true)]"; - - assert_optimized_plan_eq(table_scan, expected) + assert_optimized_plan_equal!( + table_scan, + @ r"TableScan: test projection=[a], full_filters=[Boolean(true)]" + ) } #[test] @@ -223,12 +230,13 @@ mod tests { .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))? .build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.b > Int32(1)\ - \n Projection: test.a\ - \n TableScan: test", + @ r" + Filter: test.b > Int32(1) + Projection: test.a + TableScan: test + " ) } @@ -240,12 +248,13 @@ mod tests { .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))? 
.build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.b > Int32(1)\ - \n Projection: test.a\ - \n TableScan: test", + @ r" + Filter: test.b > Int32(1) + Projection: test.a + TableScan: test + " ) } @@ -257,12 +266,13 @@ mod tests { .filter(or(col("b").gt(lit(1)), col("b").gt(lit(1))))? .build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.b > Int32(1)\ - \n Projection: test.a\ - \n TableScan: test", + @ r" + Filter: test.b > Int32(1) + Projection: test.a + TableScan: test + " ) } @@ -278,12 +288,13 @@ mod tests { ))? .build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.a > Int32(5) AND test.b < Int32(6)\ - \n Projection: test.a, test.b\ - \n TableScan: test", + @ r" + Filter: test.a > Int32(5) AND test.b < Int32(6) + Projection: test.a, test.b + TableScan: test + " ) } @@ -296,13 +307,15 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.c\ - \n Filter: test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: NOT test.c + Filter: test.b + TableScan: test + " + ) } #[test] @@ -315,14 +328,16 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Limit: skip=0, fetch=1\ - \n Filter: test.c\ - \n Filter: NOT test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Limit: skip=0, fetch=1 + Filter: test.c + Filter: NOT test.b + TableScan: test + " + ) } #[test] @@ -333,12 +348,14 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.b AND test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: NOT test.b AND test.c + TableScan: test + " + ) } #[test] @@ -349,12 +366,14 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.b OR NOT test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: NOT test.b OR NOT test.c + TableScan: test + " + ) } #[test] @@ -365,12 +384,14 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Filter: test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: test.b + TableScan: test + " + ) } #[test] @@ -380,11 +401,13 @@ mod tests { .project(vec![col("a"), col("d"), col("b").eq(lit(false))])? .build()?; - let expected = "\ - Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false) + TableScan: test + " + ) } #[test] @@ -398,12 +421,14 @@ mod tests { )? 
.build()?; - let expected = "\ - Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b) AS max(test.b = Boolean(true)), min(test.b)]]\ - \n Projection: test.a, test.c, test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b) AS max(test.b = Boolean(true)), min(test.b)]] + Projection: test.a, test.c, test.b + TableScan: test + " + ) } #[test] @@ -421,10 +446,10 @@ mod tests { let values = vec![vec![expr1, expr2]]; let plan = LogicalPlanBuilder::values(values)?.build()?; - let expected = "\ - Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ "Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))" + ) } fn get_optimized_plan_formatted( @@ -481,10 +506,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).not())? .build()?; - let expected = "Filter: test.d <= Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d <= Int32(10) + TableScan: test + " + ) } #[test] @@ -494,10 +523,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).and(col("d").lt(lit(100))).not())? .build()?; - let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d <= Int32(10) OR test.d >= Int32(100) + TableScan: test + " + ) } #[test] @@ -507,10 +540,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).or(col("d").lt(lit(100))).not())? .build()?; - let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d <= Int32(10) AND test.d >= Int32(100) + TableScan: test + " + ) } #[test] @@ -520,10 +557,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).not().not())? .build()?; - let expected = "Filter: test.d > Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d > Int32(10) + TableScan: test + " + ) } #[test] @@ -533,10 +574,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("e").is_null().not())? .build()?; - let expected = "Filter: test.e IS NOT NULL\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.e IS NOT NULL + TableScan: test + " + ) } #[test] @@ -546,10 +591,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("e").is_not_null().not())? .build()?; - let expected = "Filter: test.e IS NULL\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.e IS NULL + TableScan: test + " + ) } #[test] @@ -559,11 +608,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not())? 
.build()?; - let expected = - "Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3) + TableScan: test + " + ) } #[test] @@ -573,11 +625,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not())? .build()?; - let expected = - "Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3) + TableScan: test + " + ) } #[test] @@ -588,10 +643,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(qual.not())? .build()?; - let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d < Int32(1) OR test.d > Int32(10) + TableScan: test + " + ) } #[test] @@ -602,10 +661,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(qual.not())? .build()?; - let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d >= Int32(1) AND test.d <= Int32(10) + TableScan: test + " + ) } #[test] @@ -622,10 +685,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("a").like(col("b")).not())? .build()?; - let expected = "Filter: test.a NOT LIKE test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a NOT LIKE test.b + TableScan: test + " + ) } #[test] @@ -642,10 +709,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("a").not_like(col("b")).not())? .build()?; - let expected = "Filter: test.a LIKE test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a LIKE test.b + TableScan: test + " + ) } #[test] @@ -662,10 +733,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("a").ilike(col("b")).not())? .build()?; - let expected = "Filter: test.a NOT ILIKE test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a NOT ILIKE test.b + TableScan: test + " + ) } #[test] @@ -675,10 +750,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(binary_expr(col("d"), Operator::IsDistinctFrom, lit(10)).not())? .build()?; - let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d IS NOT DISTINCT FROM Int32(10) + TableScan: test + " + ) } #[test] @@ -688,10 +767,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(binary_expr(col("d"), Operator::IsNotDistinctFrom, lit(10)).not())? 
.build()?; - let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d IS DISTINCT FROM Int32(10) + TableScan: test + " + ) } #[test] @@ -713,11 +796,14 @@ mod tests { // before simplify: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32) // after simplify: t1.a + UInt32(1) = t2.a + UInt32(2) AS t1.a + Int64(1) = t2.a + Int64(2) - let expected = "Inner Join: t1.a + UInt32(1) = t2.a + UInt32(2)\ - \n TableScan: t1\ - \n TableScan: t2"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Inner Join: t1.a + UInt32(1) = t2.a + UInt32(2) + TableScan: t1 + TableScan: t2 + " + ) } #[test] @@ -727,10 +813,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").is_not_null())? .build()?; - let expected = "Filter: Boolean(true)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: Boolean(true) + TableScan: test + " + ) } #[test] @@ -740,10 +830,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").is_null())? .build()?; - let expected = "Filter: Boolean(false)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: Boolean(false) + TableScan: test + " + ) } #[test] @@ -760,10 +854,13 @@ mod tests { )? .build()?; - let expected = "Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]] + TableScan: test + " + ) } #[test] @@ -778,19 +875,27 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexMatch, lit(".*")))? .build()?; - let expected = "Filter: test.a IS NOT NULL\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a IS NOT NULL + TableScan: test + " + )?; // Test `!= ".*"` transforms to checking if the column is empty let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexNotMatch, lit(".*")))? .build()?; - let expected = "Filter: test.a = Utf8(\"\")\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @ r#" + Filter: test.a = Utf8("") + TableScan: test + "# + )?; // Test case-insensitive versions @@ -798,18 +903,26 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("b"), Operator::RegexIMatch, lit(".*")))? .build()?; - let expected = "Filter: Boolean(true)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @ r" + Filter: Boolean(true) + TableScan: test + " + )?; // Test `!~ ".*"` (case-insensitive) transforms to checking if the column is empty let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexNotIMatch, lit(".*")))? 
.build()?; - let expected = "Filter: test.a = Utf8(\"\")\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r#" + Filter: test.a = Utf8("") + TableScan: test + "# + ) } } diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index be71a8cd19b0..7c8ff8305e84 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -69,14 +69,14 @@ use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast}; pub(super) fn unwrap_cast_in_comparison_for_binary( info: &S, - cast_expr: Box, - literal: Box, + cast_expr: Expr, + literal: Expr, op: Operator, ) -> Result> { - match (*cast_expr, *literal) { + match (cast_expr, literal) { ( Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, .. }), - Expr::Literal(lit_value), + Expr::Literal(lit_value, _), ) => { let Ok(expr_type) = info.get_data_type(&expr) else { return internal_err!("Can't get the data type of the expr {:?}", &expr); @@ -126,7 +126,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< | Expr::Cast(Cast { expr: left_expr, .. }), - Expr::Literal(lit_val), + Expr::Literal(lit_val, _), ) => { let Ok(expr_type) = info.get_data_type(left_expr) else { return false; @@ -183,7 +183,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< } match right { - Expr::Literal(lit_val) + Expr::Literal(lit_val, _) if try_cast_literal_to_type(lit_val, &expr_type).is_some() => {} _ => return false, } @@ -197,6 +197,7 @@ fn is_supported_type(data_type: &DataType) -> bool { is_supported_numeric_type(data_type) || is_supported_string_type(data_type) || is_supported_dictionary_type(data_type) + || is_supported_binary_type(data_type) } /// Returns true if unwrap_cast_in_comparison support this numeric type @@ -230,6 +231,10 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool { DataType::Dictionary(_, inner) if is_supported_type(inner)) } +fn is_supported_binary_type(data_type: &DataType) -> bool { + matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_)) +} + ///// Tries to move a cast from an expression (such as column) to the literal other side of a comparison operator./ /// /// Specifically, rewrites @@ -292,6 +297,7 @@ pub(super) fn try_cast_literal_to_type( try_cast_numeric_literal(lit_value, target_type) .or_else(|| try_cast_string_literal(lit_value, target_type)) .or_else(|| try_cast_dictionary(lit_value, target_type)) + .or_else(|| try_cast_binary(lit_value, target_type)) } /// Convert a numeric value from one numeric data type to another @@ -501,6 +507,20 @@ fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option } } +fn try_cast_binary( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + match (lit_value, target_type) { + (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n)) + if v.len() == *n as usize => + { + Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone()))) + } + _ => None, + } +} + #[cfg(test)] mod tests { use super::*; @@ -1450,4 +1470,13 @@ mod tests { ) } } + + #[test] + fn try_cast_to_fixed_size_binary() { + expect_cast( + ScalarValue::Binary(Some(vec![1, 2, 3])), + DataType::FixedSizeBinary(3), + ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))), + ) + } } diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs 
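The `unwrap_cast` change above teaches `try_cast_literal_to_type` to move a cast off a binary column by turning a variable-width `Binary` literal into a `FixedSizeBinary` literal when the lengths agree. A self-contained restatement of that conversion rule, for illustration only (the real code path goes through the private `try_cast_binary` helper shown above):

```rust
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;

/// Rewrite a Binary literal as FixedSizeBinary when its length matches the target width.
fn binary_to_fixed_size(lit: &ScalarValue, target: &DataType) -> Option<ScalarValue> {
    match (lit, target) {
        (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n))
            if v.len() == *n as usize =>
        {
            Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone())))
        }
        _ => None,
    }
}

fn main() {
    let lit = ScalarValue::Binary(Some(vec![1, 2, 3]));
    assert_eq!(
        binary_to_fixed_size(&lit, &DataType::FixedSizeBinary(3)),
        Some(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3])))
    );
    // Length mismatch: the cast is left in place.
    assert_eq!(binary_to_fixed_size(&lit, &DataType::FixedSizeBinary(4)), None);
}
```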
b/datafusion/optimizer/src/simplify_expressions/utils.rs index cf182175e48e..4df0e125eb18 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -139,34 +139,34 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> pub fn is_zero(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(0))) - | Expr::Literal(ScalarValue::Int16(Some(0))) - | Expr::Literal(ScalarValue::Int32(Some(0))) - | Expr::Literal(ScalarValue::Int64(Some(0))) - | Expr::Literal(ScalarValue::UInt8(Some(0))) - | Expr::Literal(ScalarValue::UInt16(Some(0))) - | Expr::Literal(ScalarValue::UInt32(Some(0))) - | Expr::Literal(ScalarValue::UInt64(Some(0))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s)) if *v == 0 => true, + Expr::Literal(ScalarValue::Int8(Some(0)), _) + | Expr::Literal(ScalarValue::Int16(Some(0)), _) + | Expr::Literal(ScalarValue::Int32(Some(0)), _) + | Expr::Literal(ScalarValue::Int64(Some(0)), _) + | Expr::Literal(ScalarValue::UInt8(Some(0)), _) + | Expr::Literal(ScalarValue::UInt16(Some(0)), _) + | Expr::Literal(ScalarValue::UInt32(Some(0)), _) + | Expr::Literal(ScalarValue::UInt64(Some(0)), _) => true, + Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 0. => true, + Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 0. => true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s), _) if *v == 0 => true, _ => false, } } pub fn is_one(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(1))) - | Expr::Literal(ScalarValue::Int16(Some(1))) - | Expr::Literal(ScalarValue::Int32(Some(1))) - | Expr::Literal(ScalarValue::Int64(Some(1))) - | Expr::Literal(ScalarValue::UInt8(Some(1))) - | Expr::Literal(ScalarValue::UInt16(Some(1))) - | Expr::Literal(ScalarValue::UInt32(Some(1))) - | Expr::Literal(ScalarValue::UInt64(Some(1))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s)) => { + Expr::Literal(ScalarValue::Int8(Some(1)), _) + | Expr::Literal(ScalarValue::Int16(Some(1)), _) + | Expr::Literal(ScalarValue::Int32(Some(1)), _) + | Expr::Literal(ScalarValue::Int64(Some(1)), _) + | Expr::Literal(ScalarValue::UInt8(Some(1)), _) + | Expr::Literal(ScalarValue::UInt16(Some(1)), _) + | Expr::Literal(ScalarValue::UInt32(Some(1)), _) + | Expr::Literal(ScalarValue::UInt64(Some(1)), _) => true, + Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 1. => true, + Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 1. 
=> true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s), _) => { *s >= 0 && POWS_OF_TEN .get(*s as usize) @@ -179,7 +179,7 @@ pub fn is_one(s: &Expr) -> bool { pub fn is_true(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => *v, + Expr::Literal(ScalarValue::Boolean(Some(v)), _) => *v, _ => false, } } @@ -187,24 +187,24 @@ pub fn is_true(expr: &Expr) -> bool { /// returns true if expr is a /// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise pub fn is_bool_lit(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(ScalarValue::Boolean(_))) + matches!(expr, Expr::Literal(ScalarValue::Boolean(_), _)) } /// Return a literal NULL value of Boolean data type pub fn lit_bool_null() -> Expr { - Expr::Literal(ScalarValue::Boolean(None)) + Expr::Literal(ScalarValue::Boolean(None), None) } pub fn is_null(expr: &Expr) -> bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v, _) => v.is_null(), _ => false, } } pub fn is_false(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => !(*v), + Expr::Literal(ScalarValue::Boolean(Some(v)), _) => !(*v), _ => false, } } @@ -247,7 +247,7 @@ pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool { /// `Expr::Literal(ScalarValue::Boolean(v))`. pub fn as_bool_lit(expr: &Expr) -> Result> { match expr { - Expr::Literal(ScalarValue::Boolean(v)) => Ok(*v), + Expr::Literal(ScalarValue::Boolean(v), _) => Ok(*v), _ => internal_err!("Expected boolean literal, got {expr:?}"), } } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 7337d2ffce5c..50783a214342 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -206,7 +206,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation } else { index += 1; - let alias_str = format!("alias{}", index); + let alias_str = format!("alias{index}"); inner_aggr_exprs.push( Expr::AggregateFunction(AggregateFunction::new_udf( Arc::clone(&func), @@ -280,6 +280,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use crate::test::*; use datafusion_expr::expr::GroupingSet; use datafusion_expr::ExprFunctionExt; @@ -300,13 +301,18 @@ mod tests { )) } - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq_display_indent( - Arc::new(SingleDistinctToGroupBy::new()), - plan, - expected, - ); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let rule: Arc = Arc::new(SingleDistinctToGroupBy::new()); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] @@ -318,11 +324,13 @@ mod tests { .build()?; // Do nothing - let expected = - "Aggregate: groupBy=[[]], aggr=[[max(test.b)]] [max(test.b):UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(test.b)]] [max(test.b):UInt32;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -334,12 +342,15 @@ mod tests { .build()?; // Should work - let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64] + Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -357,10 +368,13 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -375,10 +389,13 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -394,10 +411,13 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -408,12 +428,15 @@ mod tests { .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? 
.build()?; - let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ - \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64] + Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64] + Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -425,12 +448,15 @@ mod tests { .build()?; // Should work - let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -445,10 +471,13 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -461,13 +490,17 @@ mod tests { vec![count_distinct(col("b")), max_distinct(col("b"))], )? 
.build()?; - // Should work - let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), max(alias1)]] [a:UInt32, count(alias1):Int64, max(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N] + Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), max(alias1)]] [a:UInt32, count(alias1):Int64, max(alias1):UInt32;N] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -482,10 +515,13 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -497,12 +533,15 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64] + Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64] + Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -519,13 +558,17 @@ mod tests { ], )? 
.build()?; - // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), max(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, max(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N] + Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), max(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, max(alias1):UInt32;N] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -538,13 +581,17 @@ mod tests { vec![sum(col("c")), max(col("c")), count_distinct(col("b"))], )? .build()?; - // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -557,13 +604,17 @@ mod tests { vec![min(col("a")), count_distinct(col("b"))], )? 
.build()?; - // Should work - let expected = "Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64] + Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -582,11 +633,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -602,11 +657,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -625,11 +684,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? 
.build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -645,11 +708,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -666,10 +733,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 94d07a0791b3..6e0b734bb928 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -21,7 +21,7 @@ use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{assert_contains, Result}; -use datafusion_expr::{col, logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder}; use std::sync::Arc; pub mod user_defined; @@ -64,15 +64,6 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { assert_eq!(actual, expected); } -pub fn test_subquery_with_name(name: &str) -> Result> { - let table_scan = test_table_scan_with_name(name)?; - Ok(Arc::new( - LogicalPlanBuilder::from(table_scan) - .project(vec![col("c")])? 
- .build()?, - )) -} - pub fn scan_tpch_table(table: &str) -> LogicalPlan { let schema = Arc::new(get_tpch_table_schema(table)); table_scan(Some(table), &schema, None) @@ -108,43 +99,20 @@ pub fn get_tpch_table_schema(table: &str) -> Schema { } } -pub fn assert_analyzed_plan_eq( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let options = ConfigOptions::default(); - assert_analyzed_plan_with_config_eq(options, rule, plan, expected)?; - - Ok(()) -} +#[macro_export] +macro_rules! assert_analyzed_plan_with_config_eq_snapshot { + ( + $options:expr, + $rule:expr, + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let analyzed_plan = $crate::Analyzer::with_rules(vec![$rule]).execute_and_check($plan, &$options, |_, _| {})?; -pub fn assert_analyzed_plan_with_config_eq( - options: ConfigOptions, - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = format!("{analyzed_plan}"); - assert_eq!(formatted_plan, expected); + insta::assert_snapshot!(analyzed_plan, @ $expected); - Ok(()) -} - -pub fn assert_analyzed_plan_eq_display_indent( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = analyzed_plan.display_indent_schema().to_string(); - assert_eq!(formatted_plan, expected); - - Ok(()) + Ok::<(), datafusion_common::DataFusionError>(()) + }}; } pub fn assert_analyzer_check_err( @@ -165,27 +133,26 @@ pub fn assert_analyzer_check_err( fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} -pub fn assert_optimized_plan_eq( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - // Apply the rule once - let opt_context = OptimizerContext::new().with_max_passes(1); +#[macro_export] +macro_rules! assert_optimized_plan_eq_snapshot { + ( + $optimizer_context:expr, + $rules:expr, + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let optimizer = $crate::Optimizer::with_rules($rules); + let optimized_plan = optimizer.optimize($plan, &$optimizer_context, |_, _| {})?; + insta::assert_snapshot!(optimized_plan, @ $expected); - let optimizer = Optimizer::with_rules(vec![Arc::clone(&rule)]); - let optimized_plan = optimizer.optimize(plan, &opt_context, observe)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - - Ok(()) + Ok::<(), datafusion_common::DataFusionError>(()) + }}; } fn generate_optimized_plan_with_rules( rules: Vec>, plan: LogicalPlan, ) -> LogicalPlan { - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} let config = &mut OptimizerContext::new() .with_max_passes(1) .with_skip_failing_rules(false); @@ -211,60 +178,20 @@ pub fn assert_optimized_plan_with_rules( Ok(()) } -pub fn assert_optimized_plan_eq_display_indent( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) { - let optimizer = Optimizer::with_rules(vec![rule]); - let optimized_plan = optimizer - .optimize(plan, &OptimizerContext::new(), observe) - .expect("failed to optimize plan"); - let formatted_plan = optimized_plan.display_indent_schema().to_string(); - assert_eq!(formatted_plan, expected); -} - -pub fn assert_multi_rules_optimized_plan_eq_display_indent( - rules: Vec>, - plan: LogicalPlan, - expected: &str, -) { - let optimizer = Optimizer::with_rules(rules); - let optimized_plan = optimizer - .optimize(plan, &OptimizerContext::new(), observe) - .expect("failed to optimize plan"); - let formatted_plan = optimized_plan.display_indent_schema().to_string(); - assert_eq!(formatted_plan, expected); -} - -pub fn assert_optimizer_err( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) { - let optimizer = Optimizer::with_rules(vec![rule]); - let res = optimizer.optimize(plan, &OptimizerContext::new(), observe); - match res { - Ok(plan) => assert_eq!(format!("{}", plan.display_indent()), "An error"), - Err(ref e) => { - let actual = format!("{e}"); - if expected.is_empty() || !actual.contains(expected) { - assert_eq!(actual, expected) - } - } - } -} - -pub fn assert_optimization_skipped( - rule: Arc, - plan: LogicalPlan, -) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule]); - let new_plan = optimizer.optimize(plan.clone(), &OptimizerContext::new(), observe)?; - - assert_eq!( - format!("{}", plan.display_indent()), - format!("{}", new_plan.display_indent()) - ); - Ok(()) +#[macro_export] +macro_rules! assert_optimized_plan_eq_display_indent_snapshot { + ( + $rule:expr, + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let optimizer = $crate::Optimizer::with_rules(vec![$rule]); + let optimized_plan = optimizer + .optimize($plan, &$crate::OptimizerContext::new(), |_, _| {}) + .expect("failed to optimize plan"); + let formatted_plan = optimized_plan.display_indent_schema(); + insta::assert_snapshot!(formatted_plan, @ $expected); + + Ok::<(), datafusion_common::DataFusionError>(()) + }}; } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 41c40ec06d65..0aa0bf3ea430 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -163,7 +163,11 @@ mod tests { (Expr::IsNotNull(Box::new(col("a"))), true), // a = NULL ( - binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)), + binary_expr( + col("a"), + Operator::Eq, + Expr::Literal(ScalarValue::Null, None), + ), true, ), // a > 8 @@ -226,12 +230,16 @@ mod tests { ), // a IN (NULL) ( - in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], false), + in_list( + col("a"), + vec![Expr::Literal(ScalarValue::Null, None)], + false, + ), true, ), // a NOT IN (NULL) ( - in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true), + in_list(col("a"), vec![Expr::Literal(ScalarValue::Null, None)], true), true, ), ]; @@ -241,7 +249,7 @@ mod tests { let join_cols_of_predicate = std::iter::once(&column_a); let actual = is_restrict_null_predicate(predicate.clone(), join_cols_of_predicate)?; - assert_eq!(actual, expected, "{}", predicate); + assert_eq!(actual, expected, "{predicate}"); } Ok(()) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 941e5bd7b4d7..95a9db6c8abd 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -250,7 +250,7 @@ fn between_date32_plus_interval() -> Result<()> { format!("{plan}"), @r#" Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] - Projection: + Projection: Filter: test.col_date32 >= Date32("1998-03-18") AND test.col_date32 <= Date32("1998-06-16") TableScan: test projection=[col_date32] "# @@ -268,7 +268,7 @@ fn between_date64_plus_interval() -> Result<()> { format!("{plan}"), @r#" Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] - Projection: + Projection: Filter: test.col_date64 >= Date64("1998-03-18") AND test.col_date64 <= Date64("1998-06-16") TableScan: test projection=[col_date64] "# @@ -492,7 +492,32 @@ fn test_sql(sql: &str) -> Result { .with_expr_planners(vec![ Arc::new(AggregateFunctionPlanner), Arc::new(WindowFunctionPlanner), - ]); + ]) + .with_schema( + "test", + Schema::new_with_metadata( + vec![ + Field::new("col_int32", DataType::Int32, true), + Field::new("col_uint32", DataType::UInt32, true), + Field::new("col_utf8", DataType::Utf8, true), + Field::new("col_date32", DataType::Date32, true), + Field::new("col_date64", DataType::Date64, true), + // timestamp with no timezone + Field::new( + "col_ts_nano_none", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + // timestamp with UTC timezone + Field::new( + "col_ts_nano_utc", + DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), + true, + ), + ], + HashMap::new(), + ), + ); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone())?; @@ -510,6 +535,7 @@ fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} #[derive(Default)] struct MyContextProvider { options: ConfigOptions, + tables: HashMap>, udafs: HashMap>, expr_planners: Vec>, } @@ -525,38 
+551,23 @@ impl MyContextProvider { self.expr_planners = expr_planners; self } + + fn with_schema(mut self, name: impl Into, schema: Schema) -> Self { + self.tables.insert( + name.into(), + Arc::new(MyTableSource { + schema: Arc::new(schema), + }), + ); + self + } } impl ContextProvider for MyContextProvider { fn get_table_source(&self, name: TableReference) -> Result> { let table_name = name.table(); - if table_name.starts_with("test") { - let schema = Schema::new_with_metadata( - vec![ - Field::new("col_int32", DataType::Int32, true), - Field::new("col_uint32", DataType::UInt32, true), - Field::new("col_utf8", DataType::Utf8, true), - Field::new("col_date32", DataType::Date32, true), - Field::new("col_date64", DataType::Date64, true), - // timestamp with no timezone - Field::new( - "col_ts_nano_none", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - ), - // timestamp with UTC timezone - Field::new( - "col_ts_nano_utc", - DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), - true, - ), - ], - HashMap::new(), - ); - - Ok(Arc::new(MyTableSource { - schema: Arc::new(schema), - })) + if let Some(table) = self.tables.get(table_name) { + Ok(table.clone()) } else { plan_err!("table does not exist") } diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 3bc41d2652d9..7be132fa6123 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -25,7 +25,7 @@ use crate::utils::scatter; use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; @@ -71,11 +71,23 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { /// downcast to a specific implementation. 
fn as_any(&self) -> &dyn Any; /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result; + fn data_type(&self, input_schema: &Schema) -> Result { + Ok(self.return_field(input_schema)?.data_type().to_owned()) + } /// Determine whether this expression is nullable, given the schema of the input - fn nullable(&self, input_schema: &Schema) -> Result; + fn nullable(&self, input_schema: &Schema) -> Result { + Ok(self.return_field(input_schema)?.is_nullable()) + } /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; + /// The output field associated with this expression + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(Arc::new(Field::new( + format!("{self}"), + self.data_type(input_schema)?, + self.nullable(input_schema)?, + ))) + } /// Evaluate an expression against a RecordBatch after first applying a /// validity array fn evaluate_selection( @@ -434,10 +446,10 @@ where let mut iter = self.0.clone(); write!(f, "[")?; if let Some(expr) = iter.next() { - write!(f, "{}", expr)?; + write!(f, "{expr}")?; } for expr in iter { - write!(f, ", {}", expr)?; + write!(f, ", {expr}")?; } write!(f, "]")?; Ok(()) @@ -453,19 +465,21 @@ where /// ``` /// # // The boiler plate needed to create a `PhysicalExpr` for the example /// # use std::any::Any; +/// use std::collections::HashMap; /// # use std::fmt::Formatter; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; -/// # use arrow::datatypes::{DataType, Schema}; +/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema}; /// # use datafusion_common::Result; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::{fmt_sql, DynEq, PhysicalExpr}; /// # #[derive(Debug, Hash, PartialOrd, PartialEq)] -/// # struct MyExpr {}; +/// # struct MyExpr {} /// # impl PhysicalExpr for MyExpr {fn as_any(&self) -> &dyn Any { unimplemented!() } /// # fn data_type(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn nullable(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result { unimplemented!() } +/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc>{ unimplemented!() } /// # fn with_new_children(self: Arc, children: Vec>) -> Result> { unimplemented!() } /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") } diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 3a54b5b40399..2572e8679484 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -37,13 +37,14 @@ use itertools::Itertools; /// Example: /// ``` /// # use std::any::Any; +/// # use std::collections::HashMap; /// # use std::fmt::{Display, Formatter}; /// # use std::hash::Hasher; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; /// # use datafusion_common::Result; /// # use arrow::compute::SortOptions; -/// # use arrow::datatypes::{DataType, Schema}; +/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema}; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// # use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; @@ -56,6 +57,7 @@ use itertools::Itertools; /// 
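With the `return_field` addition above, `PhysicalExpr::data_type` and `nullable` gain default implementations that read the new method, so callers can ask an expression for its complete output field in one call. A small usage sketch assuming the existing `expressions::col` helper; it is illustrative and not taken from the patch:

```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::PhysicalExpr;

fn output_field_of_column() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    let expr = col("a", &schema)?;
    // One call yields the output field; data_type()/nullable() now default to reading it.
    let field = expr.return_field(&schema)?;
    assert_eq!(field.data_type(), &DataType::Int32);
    assert!(field.is_nullable());
    Ok(())
}

fn main() {
    output_field_of_column().unwrap();
}
```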
# fn data_type(&self, input_schema: &Schema) -> Result {todo!()} /// # fn nullable(&self, input_schema: &Schema) -> Result {todo!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result {todo!() } +/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc> {todo!()} /// # fn with_new_children(self: Arc, children: Vec>) -> Result> {todo!()} /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { todo!() } @@ -265,10 +267,10 @@ pub fn format_physical_sort_requirement_list( let mut iter = self.0.iter(); write!(f, "[")?; if let Some(expr) = iter.next() { - write!(f, "{}", expr)?; + write!(f, "{expr}")?; } for expr in iter { - write!(f, ", {}", expr)?; + write!(f, ", {expr}")?; } write!(f, "]")?; Ok(()) @@ -508,7 +510,7 @@ impl Display for LexOrdering { } else { write!(f, ", ")?; } - write!(f, "{}", sort_expr)?; + write!(f, "{sort_expr}")?; } Ok(()) } diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 47e3291e5cb4..2cce585b7f15 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -51,7 +51,7 @@ indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" -petgraph = "0.7.1" +petgraph = "0.8.1" [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/benches/binary_op.rs b/datafusion/physical-expr/benches/binary_op.rs index 59a602df053c..5b0f700fdb8a 100644 --- a/datafusion/physical-expr/benches/binary_op.rs +++ b/datafusion/physical-expr/benches/binary_op.rs @@ -126,14 +126,25 @@ fn generate_boolean_cases( )); } + // Scenario 7: Test all true or all false in AND/OR + // This situation won't cause a short circuit, but it can skip the bool calculation + if TEST_ALL_FALSE { + let all_true = vec![true; len]; + cases.push(("all_true_in_and".to_string(), BooleanArray::from(all_true))); + } else { + let all_false = vec![false; len]; + cases.push(("all_false_in_or".to_string(), BooleanArray::from(all_false))); + } + cases } /// Benchmarks AND/OR operator short-circuiting by evaluating complex regex conditions. /// -/// Creates 6 test scenarios per operator: +/// Creates 7 test scenarios per operator: /// 1. All values enable short-circuit (all_true/all_false) /// 2. 2-6 Single true/false value at different positions to measure early exit +/// 3. 
Test all true or all false in AND/OR /// /// You can run this benchmark with: /// ```sh @@ -203,16 +214,16 @@ fn benchmark_binary_op_in_short_circuit(c: &mut Criterion) { // Each scenario when the test operator is `and` { - for (name, batch) in batches_and { - c.bench_function(&format!("short_circuit/and/{}", name), |b| { + for (name, batch) in batches_and.into_iter() { + c.bench_function(&format!("short_circuit/and/{name}"), |b| { b.iter(|| expr_and.evaluate(black_box(&batch)).unwrap()) }); } } // Each scenario when the test operator is `or` { - for (name, batch) in batches_or { - c.bench_function(&format!("short_circuit/or/{}", name), |b| { + for (name, batch) in batches_or.into_iter() { + c.bench_function(&format!("short_circuit/or/{name}"), |b| { b.iter(|| expr_or.evaluate(black_box(&batch)).unwrap()) }); } diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index 90bfc5efb61e..e91e8d1f137c 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -21,7 +21,7 @@ use arrow::record_batch::RecordBatch; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::{col, in_list, lit}; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::prelude::*; use std::sync::Arc; @@ -51,7 +51,7 @@ fn do_benches( for string_length in [5, 10, 20] { let values: StringArray = (0..array_length) .map(|_| { - rng.gen_bool(null_percent) + rng.random_bool(null_percent) .then(|| random_string(&mut rng, string_length)) }) .collect(); @@ -71,11 +71,11 @@ fn do_benches( } let values: Float32Array = (0..array_length) - .map(|_| rng.gen_bool(null_percent).then(|| rng.gen())) + .map(|_| rng.random_bool(null_percent).then(|| rng.random())) .collect(); let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Float32(Some(rng.gen()))) + .map(|_| ScalarValue::Float32(Some(rng.random()))) .collect(); do_bench( @@ -86,11 +86,11 @@ fn do_benches( ); let values: Int32Array = (0..array_length) - .map(|_| rng.gen_bool(null_percent).then(|| rng.gen())) + .map(|_| rng.random_bool(null_percent).then(|| rng.random())) .collect(); let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Int32(Some(rng.gen()))) + .map(|_| ScalarValue::Int32(Some(rng.random()))) .collect(); do_bench( diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 49912954ac81..be04b9c6b8ea 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -41,7 +41,7 @@ use std::sync::Arc; use crate::expressions::Column; use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity}; use datafusion_expr_common::accumulator::Accumulator; @@ -106,7 +106,7 @@ impl AggregateExprBuilder { /// ``` /// # use std::any::Any; /// # use std::sync::Arc; - /// # use arrow::datatypes::DataType; + /// # use arrow::datatypes::{DataType, FieldRef}; /// # use datafusion_common::{Result, ScalarValue}; /// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility, Expr}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; @@ -143,7 
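The `in_list` benchmark hunk above migrates to the rand 0.9 spellings: `rand::distr` replaces `rand::distributions`, and `gen`/`gen_bool` become `random`/`random_bool`. A short sketch of the renamed API surface, using a seeded `StdRng` only to keep the example deterministic:

```rust
use rand::distr::Alphanumeric;
use rand::prelude::*;

fn main() {
    let mut rng = StdRng::seed_from_u64(42);
    let keep: bool = rng.random_bool(0.1); // was rng.gen_bool(0.1)
    let x: f32 = rng.random(); // was rng.gen()
    // Alphanumeric still yields bytes; convert to chars to build a random string.
    let s: String = (0..8).map(|_| char::from(rng.sample(Alphanumeric))).collect();
    println!("{keep} {x} {s}");
}
```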
+143,7 @@ impl AggregateExprBuilder { /// # unimplemented!() /// # } /// # - /// # fn state_fields(&self, args: StateFieldsArgs) -> Result> { + /// # fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// # unimplemented!() /// # } /// # @@ -213,18 +213,18 @@ impl AggregateExprBuilder { utils::ordering_fields(ordering_req.as_ref(), &ordering_types); } - let input_exprs_types = args + let input_exprs_fields = args .iter() - .map(|arg| arg.data_type(&schema)) + .map(|arg| arg.return_field(&schema)) .collect::>>()?; check_arg_count( fun.name(), - &input_exprs_types, + &input_exprs_fields, &fun.signature().type_signature, )?; - let data_type = fun.return_type(&input_exprs_types)?; + let return_field = fun.return_field(&input_exprs_fields)?; let is_nullable = fun.is_nullable(); let name = match alias { None => { @@ -238,7 +238,7 @@ impl AggregateExprBuilder { Ok(AggregateFunctionExpr { fun: Arc::unwrap_or_clone(fun), args, - data_type, + return_field, name, human_display, schema: Arc::unwrap_or_clone(schema), @@ -246,7 +246,7 @@ impl AggregateExprBuilder { ignore_nulls, ordering_fields, is_distinct, - input_types: input_exprs_types, + input_fields: input_exprs_fields, is_reversed, is_nullable, }) @@ -310,8 +310,8 @@ impl AggregateExprBuilder { pub struct AggregateFunctionExpr { fun: AggregateUDF, args: Vec>, - /// Output / return type of this aggregate - data_type: DataType, + /// Output / return field of this aggregate + return_field: FieldRef, /// Output column name that this expression creates name: String, /// Simplified name for `tree` explain. @@ -322,10 +322,10 @@ pub struct AggregateFunctionExpr { // Whether to ignore null values ignore_nulls: bool, // fields used for order sensitive aggregation functions - ordering_fields: Vec, + ordering_fields: Vec, is_distinct: bool, is_reversed: bool, - input_types: Vec, + input_fields: Vec, is_nullable: bool, } @@ -372,8 +372,12 @@ impl AggregateFunctionExpr { } /// the field of the final result of this aggregation. - pub fn field(&self) -> Field { - Field::new(&self.name, self.data_type.clone(), self.is_nullable) + pub fn field(&self) -> FieldRef { + self.return_field + .as_ref() + .clone() + .with_name(&self.name) + .into() } /// the accumulator used to accumulate values from the expressions. @@ -381,7 +385,7 @@ impl AggregateFunctionExpr { /// return states with the same description as `state_fields` pub fn create_accumulator(&self) -> Result> { let acc_args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -395,11 +399,11 @@ impl AggregateFunctionExpr { } /// the field of the final result of this aggregation. 
- pub fn state_fields(&self) -> Result> { + pub fn state_fields(&self) -> Result> { let args = StateFieldsArgs { name: &self.name, - input_types: &self.input_types, - return_type: &self.data_type, + input_fields: &self.input_fields, + return_field: Arc::clone(&self.return_field), ordering_fields: &self.ordering_fields, is_distinct: self.is_distinct, }; @@ -472,7 +476,7 @@ impl AggregateFunctionExpr { /// Creates accumulator implementation that supports retract pub fn create_sliding_accumulator(&self) -> Result> { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -541,7 +545,7 @@ impl AggregateFunctionExpr { /// `[Self::create_groups_accumulator`] will be called. pub fn groups_accumulator_supported(&self) -> bool { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -560,7 +564,7 @@ impl AggregateFunctionExpr { /// implemented in addition to [`Accumulator`]. pub fn create_groups_accumulator(&self) -> Result> { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -640,7 +644,7 @@ impl AggregateFunctionExpr { /// output_field is the name of the column produced by this aggregate /// /// Note: this is used to use special aggregate implementations in certain conditions - pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { + pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { self.fun.is_descending().map(|flag| (self.field(), flag)) } @@ -685,7 +689,7 @@ pub struct AggregatePhysicalExpressions { impl PartialEq for AggregateFunctionExpr { fn eq(&self, other: &Self) -> bool { self.name == other.name - && self.data_type == other.data_type + && self.return_field == other.return_field && self.fun == other.fun && self.args.len() == other.args.len() && self diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index 5abd50f6d1b4..1d59dab8fd6d 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -100,7 +100,7 @@ impl ExprBoundaries { ) -> Result { let field = schema.fields().get(col_index).ok_or_else(|| { internal_datafusion_err!( - "Could not create `ExprBoundaries`: in `try_from_column` `col_index` + "Could not create `ExprBoundaries`: in `try_from_column` `col_index` has gone out of bounds with a value of {col_index}, the schema has {} columns.", schema.fields.len() ) @@ -112,7 +112,7 @@ impl ExprBoundaries { .min_value .get_value() .cloned() - .unwrap_or(empty_field.clone()), + .unwrap_or_else(|| empty_field.clone()), col_stats .max_value .get_value() @@ -425,7 +425,7 @@ mod tests { fn test_analyze_invalid_boundary_exprs() { let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)])); let expr = col("a").lt(lit(10)).or(col("a").gt(lit(20))); - let expected_error = "Interval arithmetic does not support the operator OR"; + let expected_error = "OR operator cannot yet propagate true intervals"; let boundaries = ExprBoundaries::try_new_unbounded(&schema).unwrap(); let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); let physical_expr = diff --git a/datafusion/physical-expr/src/equivalence/class.rs 
b/datafusion/physical-expr/src/equivalence/class.rs index 13a3c79a47a2..98b1299a2ec6 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -166,7 +166,7 @@ impl ConstExpr { } else { write!(f, ",")?; } - write!(f, "{}", const_expr)?; + write!(f, "{const_expr}")?; } Ok(()) } @@ -184,7 +184,7 @@ impl Display for ConstExpr { } AcrossPartitions::Uniform(value) => { if let Some(val) = value { - write!(f, "(uniform: {})", val)?; + write!(f, "(uniform: {val})")?; } else { write!(f, "(uniform: unknown)")?; } @@ -747,10 +747,10 @@ impl Display for EquivalenceGroup { write!(f, "[")?; let mut iter = self.iter(); if let Some(cls) = iter.next() { - write!(f, "{}", cls)?; + write!(f, "{cls}")?; } for cls in iter { - write!(f, ", {}", cls)?; + write!(f, ", {cls}")?; } write!(f, "]") } @@ -798,12 +798,11 @@ mod tests { eq_groups.bridge_classes(); let eq_groups = eq_groups.classes; let err_msg = format!( - "error in test entries: {:?}, expected: {:?}, actual:{:?}", - entries, expected, eq_groups + "error in test entries: {entries:?}, expected: {expected:?}, actual:{eq_groups:?}" ); - assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); + assert_eq!(eq_groups.len(), expected.len(), "{err_msg}"); for idx in 0..eq_groups.len() { - assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); + assert_eq!(&eq_groups[idx], &expected[idx], "{err_msg}"); } } Ok(()) @@ -1040,8 +1039,7 @@ mod tests { let actual = eq_group.exprs_equal(&left, &right); assert_eq!( actual, expected, - "{}: Failed comparing {:?} and {:?}, expected {}, got {}", - description, left, right, expected, actual + "{description}: Failed comparing {left:?} and {right:?}, expected {expected}, got {actual}" ); } diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index e94d2bad5712..ef98b4812265 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -97,8 +97,7 @@ mod tests { "ASC" => sort_expr.asc(), "DESC" => sort_expr.desc(), _ => panic!( - "unknown sort options. Expected 'ASC' or 'DESC', got {}", - options + "unknown sort options. 
Expected 'ASC' or 'DESC', got {options}" ), } } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 0efd46ad912e..819f8905bda5 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -329,10 +329,10 @@ impl Display for OrderingEquivalenceClass { write!(f, "[")?; let mut iter = self.orderings.iter(); if let Some(ordering) = iter.next() { - write!(f, "[{}]", ordering)?; + write!(f, "[{ordering}]")?; } for ordering in iter { - write!(f, ", [{}]", ordering)?; + write!(f, ", [{ordering}]")?; } write!(f, "]")?; Ok(()) @@ -684,8 +684,7 @@ mod tests { assert_eq!( eq_properties.ordering_satisfy(reqs.as_ref()), expected, - "{}", - err_msg + "{err_msg}" ); } @@ -739,13 +738,12 @@ mod tests { for (reqs, expected) in test_cases { let err_msg = - format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); + format!("error in test reqs: {reqs:?}, expected: {expected:?}",); let reqs = convert_to_sort_exprs(&reqs); assert_eq!( eq_properties.ordering_satisfy(reqs.as_ref()), expected, - "{}", - err_msg + "{err_msg}" ); } @@ -978,10 +976,9 @@ mod tests { let actual = OrderingEquivalenceClass::new(orderings.clone()); let actual = actual.orderings; let err_msg = format!( - "orderings: {:?}, expected: {:?}, actual :{:?}", - orderings, expected, actual + "orderings: {orderings:?}, expected: {expected:?}, actual :{actual:?}" ); - assert_eq!(actual.len(), expected.len(), "{}", err_msg); + assert_eq!(actual.len(), expected.len(), "{err_msg}"); for elem in actual { assert!(expected.contains(&elem), "{}", err_msg); } diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index a33339091c85..efbb50bc40e1 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -634,11 +634,10 @@ mod tests { let orderings = projected_eq.oeq_class(); let err_msg = format!( - "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings, expected, projection_mapping + "test_idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}" ); - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + assert_eq!(orderings.len(), expected.len(), "{err_msg}"); for expected_ordering in &expected { assert!(orderings.contains(expected_ordering), "{}", err_msg) } @@ -822,11 +821,10 @@ mod tests { let orderings = projected_eq.oeq_class(); let err_msg = format!( - "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings, expected, projection_mapping + "test idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}" ); - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + assert_eq!(orderings.len(), expected.len(), "{err_msg}"); for expected_ordering in &expected { assert!(orderings.contains(expected_ordering), "{}", err_msg) } @@ -968,11 +966,10 @@ mod tests { let orderings = projected_eq.oeq_class(); let err_msg = format!( - "actual: {:?}, expected: {:?}, projection_mapping: {:?}", - orderings, expected, projection_mapping + "actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}" ); - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + assert_eq!(orderings.len(), expected.len(), "{err_msg}"); for expected_ordering in &expected { 
assert!(orderings.contains(expected_ordering), "{}", err_msg) } diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index 9eba295e562e..fa52ae8686f7 100644 --- a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -39,10 +39,10 @@ impl Display for Dependencies { write!(f, "[")?; let mut iter = self.inner.iter(); if let Some(dep) = iter.next() { - write!(f, "{}", dep)?; + write!(f, "{dep}")?; } for dep in iter { - write!(f, ", {}", dep)?; + write!(f, ", {dep}")?; } write!(f, "]") } @@ -279,7 +279,7 @@ impl DependencyNode { impl Display for DependencyNode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Some(target) = &self.target_sort_expr { - write!(f, "(target: {}, ", target)?; + write!(f, "(target: {target}, ")?; } else { write!(f, "(")?; } @@ -764,7 +764,7 @@ mod tests { "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", expr, expected, expr_props.sort_properties ); - assert_eq!(expr_props.sort_properties, expected, "{}", err_msg); + assert_eq!(expr_props.sort_properties, expected, "{err_msg}"); } Ok(()) @@ -1224,7 +1224,7 @@ mod tests { "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - DataType::Utf8, + Field::new("f", DataType::Utf8, true).into(), )); // Assume existing ordering is [c ASC, a ASC, b ASC] @@ -1315,7 +1315,7 @@ mod tests { "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - DataType::Utf8, + Field::new("f", DataType::Utf8, true).into(), )); // Assume existing ordering is [concat(a, b) ASC, a ASC, b ASC] @@ -1735,9 +1735,7 @@ mod tests { for ordering in &satisfied_orderings { assert!( !eq_properties.ordering_satisfy(ordering), - "{}: ordering {:?} should not be satisfied before adding constraints", - name, - ordering + "{name}: ordering {ordering:?} should not be satisfied before adding constraints" ); } @@ -1752,9 +1750,7 @@ mod tests { for ordering in &satisfied_orderings { assert!( eq_properties.ordering_satisfy(ordering), - "{}: ordering {:?} should be satisfied after adding constraints", - name, - ordering + "{name}: ordering {ordering:?} should be satisfied after adding constraints" ); } @@ -1762,9 +1758,7 @@ mod tests { for ordering in &unsatisfied_orderings { assert!( !eq_properties.ordering_satisfy(ordering), - "{}: ordering {:?} should not be satisfied after adding constraints", - name, - ordering + "{name}: ordering {ordering:?} should not be satisfied after adding constraints" ); } } diff --git a/datafusion/physical-expr/src/equivalence/properties/joins.rs b/datafusion/physical-expr/src/equivalence/properties/joins.rs index 7944e89d0305..344cf54a57a8 100644 --- a/datafusion/physical-expr/src/equivalence/properties/joins.rs +++ b/datafusion/physical-expr/src/equivalence/properties/joins.rs @@ -218,13 +218,11 @@ mod tests { ); let err_msg = format!("expected: {:?}, actual:{:?}", expected, &join_eq.oeq_class); - assert_eq!(join_eq.oeq_class.len(), expected.len(), "{}", err_msg); + assert_eq!(join_eq.oeq_class.len(), expected.len(), "{err_msg}"); for ordering in join_eq.oeq_class { assert!( expected.contains(&ordering), - "{}, ordering: {:?}", - err_msg, - ordering + "{err_msg}, ordering: {ordering:?}" ); } } diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index 5b34a02a9142..8f6391bc0b5e 100644 --- 
a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -1364,7 +1364,7 @@ impl EquivalenceProperties { .transform_up(|expr| update_properties(expr, self)) .data() .map(|node| node.data) - .unwrap_or(ExprProperties::new_unknown()) + .unwrap_or_else(|_| ExprProperties::new_unknown()) } /// Transforms this `EquivalenceProperties` into a new `EquivalenceProperties` @@ -1600,7 +1600,7 @@ fn get_expr_properties( } else if let Some(literal) = expr.as_any().downcast_ref::() { Ok(ExprProperties { sort_properties: SortProperties::Singleton, - range: Interval::try_new(literal.value().clone(), literal.value().clone())?, + range: literal.value().into(), preserves_lex_ordering: true, }) } else { diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 84374f4a2970..798e68a459ce 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -17,19 +17,20 @@ mod kernels; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::expressions::binary::kernels::concat_elements_utf8view; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::PhysicalExpr; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; use arrow::array::*; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; use arrow::compute::kernels::cmp::*; use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scalar}; use arrow::compute::kernels::concat_elements::concat_elements_utf8; -use arrow::compute::{cast, ilike, like, nilike, nlike}; +use arrow::compute::{ + cast, filter_record_batch, ilike, like, nilike, nlike, SlicesIterator, +}; use arrow::datatypes::*; use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; @@ -358,11 +359,24 @@ impl PhysicalExpr for BinaryExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { use arrow::compute::kernels::numeric::*; + // Evaluate left-hand side expression. let lhs = self.left.evaluate(batch)?; - // Optimize for short-circuiting `Operator::And` or `Operator::Or` operations and return early. - if check_short_circuit(&lhs, &self.op) { - return Ok(lhs); + // Check if we can apply short-circuit evaluation. + match check_short_circuit(&lhs, &self.op) { + ShortCircuitStrategy::None => {} + ShortCircuitStrategy::ReturnLeft => return Ok(lhs), + ShortCircuitStrategy::ReturnRight => { + let rhs = self.right.evaluate(batch)?; + return Ok(rhs); + } + ShortCircuitStrategy::PreSelection(selection) => { + // The function `evaluate_selection` was not called for filtering and calculation, + // as it takes into account cases where the selection contains null values. 
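+ // Filter the batch down to the rows selected by the LHS, evaluate the RHS only on that + // smaller batch, and then scatter the partial results back to their original row positions; + // rows filtered out by the LHS become `false` in the final result.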
+ let batch = filter_record_batch(batch, selection)?; + let right_ret = self.right.evaluate(&batch)?; + return pre_selection_scatter(selection, right_ret); + } } let rhs = self.right.evaluate(batch)?; @@ -405,23 +419,19 @@ impl PhysicalExpr for BinaryExpr { let result_type = self.data_type(input_schema)?; - // Attempt to use special kernels if one input is scalar and the other is an array - let scalar_result = match (&lhs, &rhs) { - (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => { - // if left is array and right is literal(not NULL) - use scalar operations - if scalar.is_null() { - None - } else { - self.evaluate_array_scalar(array, scalar.clone())?.map(|r| { - r.and_then(|a| to_result_type_array(&self.op, a, &result_type)) - }) + // If the left-hand side is an array and the right-hand side is a non-null scalar, try the optimized kernel. + if let (ColumnarValue::Array(array), ColumnarValue::Scalar(ref scalar)) = + (&lhs, &rhs) + { + if !scalar.is_null() { + if let Some(result_array) = + self.evaluate_array_scalar(array, scalar.clone())? + { + let final_array = result_array + .and_then(|a| to_result_type_array(&self.op, a, &result_type)); + return final_array.map(ColumnarValue::Array); } } - (_, _) => None, // default to array implementation - }; - - if let Some(result) = scalar_result { - return result.map(ColumnarValue::Array); } // if both arrays or both literals - extract arrays and continue execution @@ -506,7 +516,7 @@ impl PhysicalExpr for BinaryExpr { } } else if self.op.eq(&Operator::Or) { if interval.eq(&Interval::CERTAINLY_FALSE) { - // A certainly false logical conjunction can only derive from certainly + // A certainly false logical disjunction can only derive from certainly // false operands. Otherwise, we prove infeasibility. Ok((!left_interval.eq(&Interval::CERTAINLY_TRUE) && !right_interval.eq(&Interval::CERTAINLY_TRUE)) @@ -811,58 +821,199 @@ impl BinaryExpr { } } +enum ShortCircuitStrategy<'a> { + None, + ReturnLeft, + ReturnRight, + PreSelection(&'a BooleanArray), +} + +/// If the operator is `AND` and the fraction of `true` values in the left-hand side result is at most +/// this threshold, the `RecordBatch` is filtered down to the selected rows before the right-hand side +/// is evaluated. +const PRE_SELECTION_THRESHOLD: f32 = 0.2; + /// Checks if a logical operator (`AND`/`OR`) can short-circuit evaluation based on the left-hand side (lhs) result. /// -/// Short-circuiting occurs when evaluating the right-hand side (rhs) becomes unnecessary: -/// - For `AND`: if ALL values in `lhs` are `false`, the expression must be `false` regardless of rhs. -/// - For `OR`: if ALL values in `lhs` are `true`, the expression must be `true` regardless of rhs. -/// -/// Returns `true` if short-circuiting is possible, `false` otherwise. -/// +/// Short-circuiting occurs under these circumstances: +/// - For `AND`: +/// - if LHS is all false => short-circuit → return LHS +/// - if LHS is all true => short-circuit → return RHS +/// - if LHS is mixed and true_count/row_count <= [`PRE_SELECTION_THRESHOLD`] -> pre-selection +/// - For `OR`: +/// - if LHS is all true => short-circuit → return LHS +/// - if LHS is all false => short-circuit → return RHS /// # Arguments -/// * `arg` - The left-hand side (lhs) columnar value (array or scalar) +/// * `lhs` - The left-hand side columnar value (array or scalar) /// * `op` - The logical operator (`AND` or `OR`) /// /// # Implementation Notes /// 1. 
Only works with Boolean-typed arguments (other types automatically return `false`) /// 2. Handles both scalar values and array values -/// 3. For arrays, uses optimized `true_count()`/`false_count()` methods from arrow-rs. -/// `bool_or`/`bool_and` maybe a better choice too,for detailed discussion,see:[link](https://github.com/apache/datafusion/pull/15462#discussion_r2020558418) -fn check_short_circuit(arg: &ColumnarValue, op: &Operator) -> bool { - let data_type = arg.data_type(); - match (data_type, op) { - (DataType::Boolean, Operator::And) => { - match arg { - ColumnarValue::Array(array) => { - if let Ok(array) = as_boolean_array(&array) { - return array.false_count() == array.len(); - } +/// 3. For arrays, uses optimized bit counting techniques for boolean arrays +fn check_short_circuit<'a>( + lhs: &'a ColumnarValue, + op: &Operator, +) -> ShortCircuitStrategy<'a> { + // Quickly reject non-logical operators and record whether the operator is `AND` + let is_and = match op { + Operator::And => true, + Operator::Or => false, + _ => return ShortCircuitStrategy::None, + }; + + // Non-boolean types can't be short-circuited + if lhs.data_type() != DataType::Boolean { + return ShortCircuitStrategy::None; + } + + match lhs { + ColumnarValue::Array(array) => { + // Fast path for arrays - try to downcast to boolean array + if let Ok(bool_array) = as_boolean_array(array) { + // Arrays with nulls can't be short-circuited + if bool_array.null_count() > 0 { + return ShortCircuitStrategy::None; + } - ColumnarValue::Scalar(scalar) => { - if let ScalarValue::Boolean(Some(value)) = scalar { - return !value; + + let len = bool_array.len(); + if len == 0 { + return ShortCircuitStrategy::None; + } + + let true_count = bool_array.values().count_set_bits(); + if is_and { + // For AND, prioritize checking for all-false (short circuit case) + // using the true count computed above with `count_set_bits` + + // Short circuit if all values are false + if true_count == 0 { + return ShortCircuitStrategy::ReturnLeft; + } + + // If no false values, then all must be true + if true_count == len { + return ShortCircuitStrategy::ReturnRight; + } + + // Determine whether pre-selection is worthwhile + if true_count as f32 / len as f32 <= PRE_SELECTION_THRESHOLD { + return ShortCircuitStrategy::PreSelection(bool_array); + } + } else { + // For OR, prioritize checking for all-true (short circuit case) + // using the true count computed above with `count_set_bits` + + // Short circuit if all values are true + if true_count == len { + return ShortCircuitStrategy::ReturnLeft; + } + + // If no true values, then all must be false + if true_count == 0 { + return ShortCircuitStrategy::ReturnRight; + } } } - false } - (DataType::Boolean, Operator::Or) => { - match arg { - ColumnarValue::Array(array) => { - if let Ok(array) = as_boolean_array(&array) { - return array.true_count() == array.len(); - } - } - ColumnarValue::Scalar(scalar) => { - if let ScalarValue::Boolean(Some(value)) = scalar { - return *value; - } + ColumnarValue::Scalar(scalar) => { + // Fast path for scalar values + if let ScalarValue::Boolean(Some(is_true)) = scalar { + // Return Left for: + // - AND with false value + // - OR with true value + if (is_and && !is_true) || (!is_and && *is_true) { + return ShortCircuitStrategy::ReturnLeft; + } else { + return ShortCircuitStrategy::ReturnRight; + } } - false } - _ => false, } + + // If we can't short-circuit, indicate that normal evaluation should continue + ShortCircuitStrategy::None +} + +/// Creates a new boolean array based on the evaluation of 
the right expression, +/// but only for positions where the left_result is true. +/// +/// This function is used for short-circuit evaluation optimization of logical AND operations: +/// - When left_result has few true values, we only evaluate the right expression for those positions +/// - Values are copied from right_array where left_result is true +/// - All other positions are filled with false values +/// +/// # Parameters +/// - `left_result` Boolean array with selection mask (typically from left side of AND) +/// - `right_result` Result of evaluating right side of expression (only for selected positions) +/// +/// # Returns +/// A combined ColumnarValue with values from right_result where left_result is true +/// +/// # Example +/// Initial Data: { 1, 2, 3, 4, 5 } +/// Left Evaluation +/// (Condition: Equal to 2 or 3) +/// ↓ +/// Filtered Data: {2, 3} +/// Left Bitmap: { 0, 1, 1, 0, 0 } +/// ↓ +/// Right Evaluation +/// (Condition: Even numbers) +/// ↓ +/// Right Data: { 2 } +/// Right Bitmap: { 1, 0 } +/// ↓ +/// Combine Results +/// Final Bitmap: { 0, 1, 0, 0, 0 } +/// +/// # Note +/// Perhaps it would be better to modify `left_result` directly without creating a copy? +/// In practice, `left_result` should have only one owner, so making changes should be safe. +/// However, this is difficult to achieve under the immutable constraints of [`Arc`] and [`BooleanArray`]. +fn pre_selection_scatter( + left_result: &BooleanArray, + right_result: ColumnarValue, +) -> Result { + let right_boolean_array = match &right_result { + ColumnarValue::Array(array) => array.as_boolean(), + ColumnarValue::Scalar(_) => return Ok(right_result), + }; + + let result_len = left_result.len(); + + let mut result_array_builder = BooleanArray::builder(result_len); + + // keep track of current position we have in right boolean array + let mut right_array_pos = 0; + + // keep track of how much is filled + let mut last_end = 0; + SlicesIterator::new(left_result).for_each(|(start, end)| { + // the gap needs to be filled with false + if start > last_end { + result_array_builder.append_n(start - last_end, false); + } + + // copy values from right array for this slice + let len = end - start; + right_boolean_array + .slice(right_array_pos, len) + .iter() + .for_each(|v| result_array_builder.append_option(v)); + + right_array_pos += len; + last_end = end; + }); + + // Fill any remaining positions with false + if last_end < result_len { + result_array_builder.append_n(result_len - last_end, false); + } + let boolean_result = result_array_builder.finish(); + + Ok(ColumnarValue::Array(Arc::new(boolean_result))) } fn concat_elements(left: Arc, right: Arc) -> Result { @@ -919,10 +1070,14 @@ pub fn similar_to( mod tests { use super::*; use crate::expressions::{col, lit, try_cast, Column, Literal}; + use datafusion_expr::lit as expr_lit; use datafusion_common::plan_datafusion_err; use datafusion_physical_expr_common::physical_expr::fmt_sql; + use crate::planner::logical2physical; + use arrow::array::BooleanArray; + use datafusion_expr::col as logical_col; /// Performs a binary operation, applying any type coercion necessary fn binary_op( left: Arc, @@ -4895,9 +5050,7 @@ mod tests { #[test] fn test_check_short_circuit() { - use crate::planner::logical2physical; - use datafusion_expr::col as logical_col; - use datafusion_expr::lit; + // Test with non-nullable arrays let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -4911,20 +5064,339 @@ mod tests { 
.unwrap(); // op: AND left: all false - let left_expr = logical2physical(&logical_col("a").eq(lit(2)), &schema); + let left_expr = logical2physical(&logical_col("a").eq(expr_lit(2)), &schema); let left_value = left_expr.evaluate(&batch).unwrap(); - assert!(check_short_circuit(&left_value, &Operator::And)); + assert!(matches!( + check_short_circuit(&left_value, &Operator::And), + ShortCircuitStrategy::ReturnLeft + )); + // op: AND left: not all false - let left_expr = logical2physical(&logical_col("a").eq(lit(3)), &schema); + let left_expr = logical2physical(&logical_col("a").eq(expr_lit(3)), &schema); let left_value = left_expr.evaluate(&batch).unwrap(); - assert!(!check_short_circuit(&left_value, &Operator::And)); + let ColumnarValue::Array(array) = &left_value else { + panic!("Expected ColumnarValue::Array"); + }; + let ShortCircuitStrategy::PreSelection(value) = + check_short_circuit(&left_value, &Operator::And) + else { + panic!("Expected ShortCircuitStrategy::PreSelection"); + }; + let expected_boolean_arr: Vec<_> = + as_boolean_array(array).unwrap().iter().collect(); + let boolean_arr: Vec<_> = value.iter().collect(); + assert_eq!(expected_boolean_arr, boolean_arr); + // op: OR left: all true - let left_expr = logical2physical(&logical_col("a").gt(lit(0)), &schema); + let left_expr = logical2physical(&logical_col("a").gt(expr_lit(0)), &schema); let left_value = left_expr.evaluate(&batch).unwrap(); - assert!(check_short_circuit(&left_value, &Operator::Or)); + assert!(matches!( + check_short_circuit(&left_value, &Operator::Or), + ShortCircuitStrategy::ReturnLeft + )); + // op: OR left: not all true - let left_expr = logical2physical(&logical_col("a").gt(lit(2)), &schema); + let left_expr: Arc = + logical2physical(&logical_col("a").gt(expr_lit(2)), &schema); let left_value = left_expr.evaluate(&batch).unwrap(); - assert!(!check_short_circuit(&left_value, &Operator::Or)); + assert!(matches!( + check_short_circuit(&left_value, &Operator::Or), + ShortCircuitStrategy::None + )); + + // Test with nullable arrays and null values + let schema_nullable = Arc::new(Schema::new(vec![ + Field::new("c", DataType::Boolean, true), + Field::new("d", DataType::Boolean, true), + ])); + + // Create arrays with null values + let c_array = Arc::new(BooleanArray::from(vec![ + Some(true), + Some(false), + None, + Some(true), + None, + ])) as ArrayRef; + let d_array = Arc::new(BooleanArray::from(vec![ + Some(false), + Some(true), + Some(false), + None, + Some(true), + ])) as ArrayRef; + + let batch_nullable = RecordBatch::try_new( + Arc::clone(&schema_nullable), + vec![Arc::clone(&c_array), Arc::clone(&d_array)], + ) + .unwrap(); + + // Case: Mixed values with nulls - shouldn't short-circuit for AND + let mixed_nulls = logical2physical(&logical_col("c"), &schema_nullable); + let mixed_nulls_value = mixed_nulls.evaluate(&batch_nullable).unwrap(); + assert!(matches!( + check_short_circuit(&mixed_nulls_value, &Operator::And), + ShortCircuitStrategy::None + )); + + // Case: Mixed values with nulls - shouldn't short-circuit for OR + assert!(matches!( + check_short_circuit(&mixed_nulls_value, &Operator::Or), + ShortCircuitStrategy::None + )); + + // Test with all nulls + let all_nulls = Arc::new(BooleanArray::from(vec![None, None, None])) as ArrayRef; + let null_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("e", DataType::Boolean, true)])), + vec![all_nulls], + ) + .unwrap(); + + let null_expr = logical2physical(&logical_col("e"), &null_batch.schema()); + let null_value = 
null_expr.evaluate(&null_batch).unwrap(); + + // All nulls shouldn't short-circuit for AND or OR + assert!(matches!( + check_short_circuit(&null_value, &Operator::And), + ShortCircuitStrategy::None + )); + assert!(matches!( + check_short_circuit(&null_value, &Operator::Or), + ShortCircuitStrategy::None + )); + + // Test with scalar values + // Scalar true + let scalar_true = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); + assert!(matches!( + check_short_circuit(&scalar_true, &Operator::Or), + ShortCircuitStrategy::ReturnLeft + )); // Should short-circuit OR + assert!(matches!( + check_short_circuit(&scalar_true, &Operator::And), + ShortCircuitStrategy::ReturnRight + )); // Should return the RHS for AND + + // Scalar false + let scalar_false = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + assert!(matches!( + check_short_circuit(&scalar_false, &Operator::And), + ShortCircuitStrategy::ReturnLeft + )); // Should short-circuit AND + assert!(matches!( + check_short_circuit(&scalar_false, &Operator::Or), + ShortCircuitStrategy::ReturnRight + )); // Should return the RHS for OR + + // Scalar null + let scalar_null = ColumnarValue::Scalar(ScalarValue::Boolean(None)); + assert!(matches!( + check_short_circuit(&scalar_null, &Operator::And), + ShortCircuitStrategy::None + )); + assert!(matches!( + check_short_circuit(&scalar_null, &Operator::Or), + ShortCircuitStrategy::None + )); + } + + /// Test for [pre_selection_scatter] + /// Since [check_short_circuit] ensures that the left side contains no nulls, is not empty, and is neither all true nor all false, + /// the following cases are tested: + /// 1. Test sparse left with interleaved true/false + /// 2. Test multiple consecutive true blocks + /// 3. Test single true at first position + /// 4. Test single true at last position + /// 5. Test nulls in right array + /// 6. 
Test scalar right handling + #[test] + fn test_pre_selection_scatter() { + fn create_bool_array(bools: Vec) -> BooleanArray { + BooleanArray::from(bools.into_iter().map(Some).collect::>()) + } + // Test sparse left with interleaved true/false + { + // Left: [T, F, T, F, T] + // Right: [F, T, F] (values for 3 true positions) + let left = create_bool_array(vec![true, false, true, false, true]); + let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![ + false, true, false, + ]))); + + let result = pre_selection_scatter(&left, right).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = create_bool_array(vec![false, false, true, false, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test multiple consecutive true blocks + { + // Left: [F, T, T, F, T, T, T] + // Right: [T, F, F, T, F] + let left = + create_bool_array(vec![false, true, true, false, true, true, true]); + let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![ + true, false, false, true, false, + ]))); + + let result = pre_selection_scatter(&left, right).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = + create_bool_array(vec![false, true, false, false, false, true, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test single true at first position + { + // Left: [T, F, F] + // Right: [F] + let left = create_bool_array(vec![true, false, false]); + let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![false]))); + + let result = pre_selection_scatter(&left, right).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = create_bool_array(vec![false, false, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test single true at last position + { + // Left: [F, F, T] + // Right: [F] + let left = create_bool_array(vec![false, false, true]); + let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![false]))); + + let result = pre_selection_scatter(&left, right).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = create_bool_array(vec![false, false, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test nulls in right array + { + // Left: [F, T, F, T] + // Right: [None, Some(false)] (with null at first position) + let left = create_bool_array(vec![false, true, false, true]); + let right_arr = BooleanArray::from(vec![None, Some(false)]); + let right = ColumnarValue::Array(Arc::new(right_arr)); + + let result = pre_selection_scatter(&left, right).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = BooleanArray::from(vec![ + Some(false), + None, // null from right + Some(false), + Some(false), + ]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test scalar right handling + { + // Left: [T, F, T] + // Right: Scalar true + let left = create_bool_array(vec![true, false, true]); + let right = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); + + let result = pre_selection_scatter(&left, right).unwrap(); + assert!(matches!(result, ColumnarValue::Scalar(_))); + } + } + + #[test] + fn test_evaluate_bounds_int32() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let a = Arc::new(Column::new("a", 0)) as _; + let b = Arc::new(Column::new("b", 1)) as _; + + // Test addition bounds + let add_expr = + binary_expr(Arc::clone(&a), Operator::Plus, Arc::clone(&b), 
&schema).unwrap(); + let add_bounds = add_expr + .evaluate_bounds(&[ + &Interval::make(Some(1), Some(10)).unwrap(), + &Interval::make(Some(5), Some(15)).unwrap(), + ]) + .unwrap(); + assert_eq!(add_bounds, Interval::make(Some(6), Some(25)).unwrap()); + + // Test subtraction bounds + let sub_expr = + binary_expr(Arc::clone(&a), Operator::Minus, Arc::clone(&b), &schema) + .unwrap(); + let sub_bounds = sub_expr + .evaluate_bounds(&[ + &Interval::make(Some(1), Some(10)).unwrap(), + &Interval::make(Some(5), Some(15)).unwrap(), + ]) + .unwrap(); + assert_eq!(sub_bounds, Interval::make(Some(-14), Some(5)).unwrap()); + + // Test multiplication bounds + let mul_expr = + binary_expr(Arc::clone(&a), Operator::Multiply, Arc::clone(&b), &schema) + .unwrap(); + let mul_bounds = mul_expr + .evaluate_bounds(&[ + &Interval::make(Some(1), Some(10)).unwrap(), + &Interval::make(Some(5), Some(15)).unwrap(), + ]) + .unwrap(); + assert_eq!(mul_bounds, Interval::make(Some(5), Some(150)).unwrap()); + + // Test division bounds + let div_expr = + binary_expr(Arc::clone(&a), Operator::Divide, Arc::clone(&b), &schema) + .unwrap(); + let div_bounds = div_expr + .evaluate_bounds(&[ + &Interval::make(Some(10), Some(20)).unwrap(), + &Interval::make(Some(2), Some(5)).unwrap(), + ]) + .unwrap(); + assert_eq!(div_bounds, Interval::make(Some(2), Some(10)).unwrap()); + } + + #[test] + fn test_evaluate_bounds_bool() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Boolean, false), + ]); + + let a = Arc::new(Column::new("a", 0)) as _; + let b = Arc::new(Column::new("b", 1)) as _; + + // Test OR bounds + let or_expr = + binary_expr(Arc::clone(&a), Operator::Or, Arc::clone(&b), &schema).unwrap(); + let or_bounds = or_expr + .evaluate_bounds(&[ + &Interval::make(Some(true), Some(true)).unwrap(), + &Interval::make(Some(false), Some(false)).unwrap(), + ]) + .unwrap(); + assert_eq!(or_bounds, Interval::make(Some(true), Some(true)).unwrap()); + + // Test AND bounds + let and_expr = + binary_expr(Arc::clone(&a), Operator::And, Arc::clone(&b), &schema).unwrap(); + let and_bounds = and_expr + .evaluate_bounds(&[ + &Interval::make(Some(true), Some(true)).unwrap(), + &Interval::make(Some(false), Some(false)).unwrap(), + ]) + .unwrap(); + assert_eq!( + and_bounds, + Interval::make(Some(false), Some(false)).unwrap() + ); } } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 854c715eb0a2..1a74e78f1075 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. 
+use crate::expressions::try_cast; +use crate::PhysicalExpr; use std::borrow::Cow; use std::hash::Hash; use std::{any::Any, sync::Arc}; -use crate::expressions::try_cast; -use crate::PhysicalExpr; - use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; @@ -603,7 +602,7 @@ mod tests { use crate::expressions::{binary, cast, col, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; - use arrow::datatypes::*; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index a6766687a881..7e345e60271f 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; -use arrow::datatypes::{DataType, DataType::*, Schema}; +use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; @@ -144,6 +144,16 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(self + .expr + .return_field(input_schema)? + .as_ref() + .clone() + .with_data_type(self.cast_type.clone()) + .into()) + } + fn children(&self) -> Vec<&Arc> { vec![&self.expr] } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index ab5b35984753..5a11783a87e9 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -22,6 +22,7 @@ use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema, SchemaRef}, record_batch::RecordBatch, @@ -127,6 +128,10 @@ impl PhysicalExpr for Column { Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(input_schema.field(self.index).clone().into()) + } + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index c0a3285f0e78..756fb638af2b 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -75,7 +75,7 @@ impl Eq for DynamicFilterPhysicalExpr {} impl Display for DynamicFilterPhysicalExpr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let inner = self.current().expect("Failed to get current expression"); - write!(f, "DynamicFilterPhysicalExpr [ {} ]", inner) + write!(f, "DynamicFilterPhysicalExpr [ {inner} ]") } } @@ -342,7 +342,7 @@ mod test { ) .unwrap(); let snap = dynamic_filter_1.snapshot().unwrap().unwrap(); - insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"#); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 
}, op: Eq, right: Literal { value: Int32(42), field: Field { name: "42", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#); let dynamic_filter_2 = reassign_predicate_columns( Arc::clone(&dynamic_filter) as Arc, &filter_schema_2, @@ -350,7 +350,7 @@ mod test { ) .unwrap(); let snap = dynamic_filter_2.snapshot().unwrap().unwrap(); - insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"#); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "42", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#); // Both filters allow evaluating the same expression let batch_1 = RecordBatch::try_new( Arc::clone(&filter_schema_1), diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 469f7bbee317..a1a14b2f30ff 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -1451,7 +1451,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a IN (a, b)"); - assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])"); + assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"a\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"b\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); // Test: a NOT IN ('a', 'b') let list = vec![lit("a"), lit("b")]; @@ -1459,7 +1459,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a NOT IN (a, b)"); - assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])"); + assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"a\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"b\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); // Test: a IN ('a', 'b', NULL) let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; @@ -1467,7 +1467,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a IN (a, b, NULL)"); - assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])"); + assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"a\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"b\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(NULL), field: Field { name: \"NULL\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); // Test: a NOT IN ('a', 'b', NULL) let list = vec![lit("a"), 
lit("b"), lit(ScalarValue::Utf8(None))]; @@ -1475,7 +1475,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a NOT IN (a, b, NULL)"); - assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])"); + assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"a\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"b\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(NULL), field: Field { name: \"NULL\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 0619e7248858..ff05dab40126 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -17,10 +17,8 @@ //! IS NOT NULL expression -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -28,6 +26,8 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; /// IS NOT NULL expression #[derive(Debug, Eq)] @@ -94,6 +94,10 @@ impl PhysicalExpr for IsNotNullExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 4c6081f35cad..15c7c645bda0 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -17,10 +17,8 @@ //! IS NULL expression -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -28,6 +26,8 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; /// IS NULL expression #[derive(Debug, Eq)] @@ -93,6 +93,10 @@ impl PhysicalExpr for IsNullExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index ebf9882665ba..e86c778d5161 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. 
-use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::apply_cmp; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; // Like expression #[derive(Debug, Eq)] diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 0d0c0ecc62c7..0d4d62ef4719 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -18,11 +18,13 @@ //! Literal expressions for physical operations use std::any::Any; +use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::{Field, FieldRef}; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -34,15 +36,48 @@ use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; /// Represents a literal value -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq)] pub struct Literal { value: ScalarValue, + field: FieldRef, +} + +impl Hash for Literal { + fn hash(&self, state: &mut H) { + self.value.hash(state); + let metadata = self.field.metadata(); + let mut keys = metadata.keys().collect::>(); + keys.sort(); + for key in keys { + key.hash(state); + metadata.get(key).unwrap().hash(state); + } + } } impl Literal { /// Create a literal value expression pub fn new(value: ScalarValue) -> Self { - Self { value } + Self::new_with_metadata(value, None) + } + + /// Create a literal value expression + pub fn new_with_metadata( + value: ScalarValue, + metadata: impl Into>>, + ) -> Self { + let metadata = metadata.into(); + let mut field = + Field::new(format!("{value}"), value.data_type(), value.is_null()); + + if let Some(metadata) = metadata { + field = field.with_metadata(metadata); + } + + Self { + value, + field: field.into(), + } } /// Get the scalar value @@ -71,6 +106,10 @@ impl PhysicalExpr for Literal { Ok(self.value.is_null()) } + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + fn evaluate(&self, _batch: &RecordBatch) -> Result { Ok(ColumnarValue::Scalar(self.value.clone())) } @@ -102,7 +141,7 @@ impl PhysicalExpr for Literal { /// Create a literal expression pub fn lit(value: T) -> Arc { match value.lit() { - Expr::Literal(v) => Arc::new(Literal::new(v)), + Expr::Literal(v, _) => Arc::new(Literal::new(v)), _ => unreachable!(), } } @@ -112,7 +151,7 @@ mod tests { use super::*; use arrow::array::Int32Array; - use arrow::datatypes::*; + use arrow::datatypes::Field; use datafusion_common::cast::as_int32_array; use datafusion_physical_expr_common::physical_expr::fmt_sql; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 33a1bae14d42..fa7224768a77 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use crate::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ compute::kernels::numeric::neg_wrapping, datatypes::{DataType, Schema}, @@ -103,6 +104,10 @@ impl PhysicalExpr for NegativeExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + 
fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 24d2f4d9e074..94610996c6b0 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -21,12 +21,11 @@ use std::any::Any; use std::hash::Hash; use std::sync::Arc; +use crate::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; - -use crate::PhysicalExpr; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 8a3348b43d20..8184ef601e54 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::PhysicalExpr; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{cast::as_boolean_array, internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; @@ -101,6 +101,10 @@ impl PhysicalExpr for NotExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index e49815cd8b64..b593dfe83209 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::PhysicalExpr; use arrow::compute; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; @@ -110,6 +110,13 @@ impl PhysicalExpr for TryCastExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.expr + .return_field(input_schema) + .map(|f| f.as_ref().clone().with_data_type(self.cast_type.clone())) + .map(Arc::new) + } + fn children(&self) -> Vec<&Arc> { vec![&self.expr] } diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index a53814c3ad2b..28f76bbfd1c8 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -148,12 +148,12 @@ use std::sync::Arc; use super::utils::{ convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op, }; -use crate::expressions::Literal; +use crate::expressions::{BinaryExpr, Literal}; use crate::utils::{build_dag, ExprTreeNode}; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval}; use datafusion_expr::Operator; @@ -645,6 +645,17 @@ impl ExprIntervalGraph { .map(|child| self.graph[*child].interval()) .collect::>(); let node_interval = self.graph[node].interval(); + // Special case: true OR could in principle be propagated by 3 interval sets, + // (i.e. left true, or right true, or both true) however we do not support this yet. 
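+ // For example, knowing `a OR b` is true only constrains the pair to the union of + // {a = true, b unconstrained} and {b = true, a unconstrained}; the tightest single + // interval for each child is still [false, true], so one interval per child propagates no new information.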
+ if node_interval == &Interval::CERTAINLY_TRUE + && self.graph[node] + .expr + .as_any() + .downcast_ref::() + .is_some_and(|expr| expr.op() == &Operator::Or) + { + return not_impl_err!("OR operator cannot yet propagate true intervals"); + } let propagated_intervals = self.graph[node] .expr .propagate_constraints(node_interval, &children_intervals)?; @@ -857,8 +868,8 @@ mod tests { let mut r = StdRng::seed_from_u64(seed); let (left_given, right_given, left_expected, right_expected) = if ASC { - let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); - let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); + let left = r.random_range((0 as $TYPE)..(1000 as $TYPE)); + let right = r.random_range((0 as $TYPE)..(1000 as $TYPE)); ( (Some(left), None), (Some(right), None), @@ -866,8 +877,8 @@ mod tests { (Some(<$TYPE>::max(right, left + expr_right)), None), ) } else { - let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); - let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); + let left = r.random_range((0 as $TYPE)..(1000 as $TYPE)); + let right = r.random_range((0 as $TYPE)..(1000 as $TYPE)); ( (None, Some(left)), (None, Some(right)), diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8660bff796d5..6f1417ec23bf 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::sync::Arc; use crate::ScalarFunctionExpr; @@ -111,14 +112,42 @@ pub fn create_physical_expr( let input_schema: &Schema = &input_dfschema.into(); match e { - Expr::Alias(Alias { expr, .. }) => { - Ok(create_physical_expr(expr, input_dfschema, execution_props)?) + Expr::Alias(Alias { expr, metadata, .. }) => { + if let Expr::Literal(v, prior_metadata) = expr.as_ref() { + let mut new_metadata = prior_metadata + .as_ref() + .map(|m| { + m.iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>() + }) + .unwrap_or_default(); + if let Some(metadata) = metadata { + new_metadata.extend(metadata.clone()); + } + let new_metadata = match new_metadata.is_empty() { + true => None, + false => Some(new_metadata), + }; + + Ok(Arc::new(Literal::new_with_metadata( + v.clone(), + new_metadata, + ))) + } else { + Ok(create_physical_expr(expr, input_dfschema, execution_props)?) 
+ } } Expr::Column(c) => { let idx = input_dfschema.index_of_column(c)?; Ok(Arc::new(Column::new(&c.name, idx))) } - Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), + Expr::Literal(value, metadata) => Ok(Arc::new(Literal::new_with_metadata( + value.clone(), + metadata + .as_ref() + .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect()), + ))), Expr::ScalarVariable(_, variable_names) => { if is_system_variables(variable_names) { match execution_props.get_var_provider(VarType::System) { @@ -168,7 +197,7 @@ pub fn create_physical_expr( let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::Literal(ScalarValue::Boolean(None), None), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } @@ -176,7 +205,7 @@ pub fn create_physical_expr( let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::Literal(ScalarValue::Boolean(None), None), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } @@ -347,7 +376,7 @@ pub fn create_physical_expr( list, negated, }) => match expr.as_ref() { - Expr::Literal(ScalarValue::Utf8(None)) => { + Expr::Literal(ScalarValue::Utf8(None), _) => { Ok(expressions::lit(ScalarValue::Boolean(None))) } _ => { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 44bbcc4928c6..d014bbb74caa 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -38,13 +38,13 @@ use crate::expressions::Literal; use crate::PhysicalExpr; use arrow::array::{Array, RecordBatch}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, + expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, }; /// Physical expression of a scalar function @@ -53,8 +53,7 @@ pub struct ScalarFunctionExpr { fun: Arc, name: String, args: Vec>, - return_type: DataType, - nullable: bool, + return_field: FieldRef, } impl Debug for ScalarFunctionExpr { @@ -63,7 +62,7 @@ impl Debug for ScalarFunctionExpr { .field("fun", &"") .field("name", &self.name) .field("args", &self.args) - .field("return_type", &self.return_type) + .field("return_field", &self.return_field) .finish() } } @@ -74,14 +73,13 @@ impl ScalarFunctionExpr { name: &str, fun: Arc, args: Vec>, - return_type: DataType, + return_field: FieldRef, ) -> Self { Self { fun, name: name.to_owned(), args, - return_type, - nullable: true, + return_field, } } @@ -92,18 +90,17 @@ impl ScalarFunctionExpr { schema: &Schema, ) -> Result { let name = fun.name().to_string(); - let arg_types = args + let arg_fields = args .iter() - .map(|e| e.data_type(schema)) + .map(|e| e.return_field(schema)) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` - data_types_with_scalar_udf(&arg_types, &fun)?; - - let nullables = args + let arg_types = arg_fields .iter() - .map(|e| e.nullable(schema)) - .collect::>>()?; + .map(|f| f.data_type().clone()) + .collect::>(); + data_types_with_scalar_udf(&arg_types, &fun)?; let 
arguments = args .iter() @@ -113,18 +110,16 @@ impl ScalarFunctionExpr { .map(|literal| literal.value()) }) .collect::>(); - let ret_args = ReturnTypeArgs { - arg_types: &arg_types, + let ret_args = ReturnFieldArgs { + arg_fields: &arg_fields, scalar_arguments: &arguments, - nullables: &nullables, }; - let (return_type, nullable) = fun.return_type_from_args(ret_args)?.into_parts(); + let return_field = fun.return_field_from_args(ret_args)?; Ok(Self { fun, name, args, - return_type, - nullable, + return_field, }) } @@ -145,16 +140,21 @@ impl ScalarFunctionExpr { /// Data type produced by this expression pub fn return_type(&self) -> &DataType { - &self.return_type + self.return_field.data_type() } pub fn with_nullable(mut self, nullable: bool) -> Self { - self.nullable = nullable; + self.return_field = self + .return_field + .as_ref() + .clone() + .with_nullable(nullable) + .into(); self } pub fn nullable(&self) -> bool { - self.nullable + self.return_field.is_nullable() } } @@ -171,11 +171,11 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.return_type.clone()) + Ok(self.return_field.data_type().clone()) } fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(self.nullable) + Ok(self.return_field.is_nullable()) } fn evaluate(&self, batch: &RecordBatch) -> Result { @@ -185,6 +185,12 @@ impl PhysicalExpr for ScalarFunctionExpr { .map(|e| e.evaluate(batch)) .collect::>>()?; + let arg_fields = self + .args + .iter() + .map(|e| e.return_field(batch.schema_ref())) + .collect::>>()?; + let input_empty = args.is_empty(); let input_all_scalar = args .iter() @@ -193,8 +199,9 @@ impl PhysicalExpr for ScalarFunctionExpr { // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { args, + arg_fields, number_rows: batch.num_rows(), - return_type: &self.return_type, + return_field: Arc::clone(&self.return_field), })?; if let ColumnarValue::Array(array) = &output { @@ -214,6 +221,10 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.return_field)) + } + fn children(&self) -> Vec<&Arc> { self.args.iter().collect() } @@ -222,15 +233,12 @@ impl PhysicalExpr for ScalarFunctionExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new( - ScalarFunctionExpr::new( - &self.name, - Arc::clone(&self.fun), - children, - self.return_type().clone(), - ) - .with_nullable(self.nullable), - )) + Ok(Arc::new(ScalarFunctionExpr::new( + &self.name, + Arc::clone(&self.fun), + children, + Arc::clone(&self.return_field), + ))) } fn evaluate_bounds(&self, children: &[&Interval]) -> Result { diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index a94d5b1212f5..9b959796136a 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -30,8 +30,9 @@ use crate::window::{ use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; use arrow::array::Array; +use arrow::array::ArrayRef; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; -use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{Accumulator, WindowFrame}; use datafusion_physical_expr_common::sort_expr::LexOrdering; @@ -95,7 +96,7 @@ impl WindowExpr for PlainAggregateWindowExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { 
Ok(self.aggregate.field()) } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 23967e78f07a..2b22299f9386 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -29,7 +29,7 @@ use crate::window::{ use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{Accumulator, WindowFrame}; @@ -80,7 +80,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { Ok(self.aggregate.field()) } diff --git a/datafusion/physical-expr/src/window/standard.rs b/datafusion/physical-expr/src/window/standard.rs index 22e8aea83fe7..73f47b0b6863 100644 --- a/datafusion/physical-expr/src/window/standard.rs +++ b/datafusion/physical-expr/src/window/standard.rs @@ -27,7 +27,7 @@ use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState}; use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; use arrow::array::{new_empty_array, ArrayRef}; use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; @@ -92,7 +92,7 @@ impl WindowExpr for StandardWindowExpr { self.expr.name() } - fn field(&self) -> Result { + fn field(&self) -> Result { self.expr.field() } diff --git a/datafusion/physical-expr/src/window/standard_window_function_expr.rs b/datafusion/physical-expr/src/window/standard_window_function_expr.rs index 624b747d93f9..871f735e9a96 100644 --- a/datafusion/physical-expr/src/window/standard_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/standard_window_function_expr.rs @@ -18,7 +18,7 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; -use arrow::datatypes::{Field, SchemaRef}; +use arrow::datatypes::{FieldRef, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_expr::PartitionEvaluator; @@ -41,7 +41,7 @@ pub trait StandardWindowFunctionExpr: Send + Sync + std::fmt::Debug { fn as_any(&self) -> &dyn Any; /// The field of the final result of evaluating this window function. - fn field(&self) -> Result; + fn field(&self) -> Result; /// Expressions that are passed to the [`PartitionEvaluator`]. fn expressions(&self) -> Vec>; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 793f2e5ee586..8d72604a6af5 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -25,7 +25,7 @@ use crate::{LexOrdering, PhysicalExpr}; use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::utils::compare_rows; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; @@ -67,7 +67,7 @@ pub trait WindowExpr: Send + Sync + Debug { fn as_any(&self) -> &dyn Any; /// The field of the final result of this window function. 
- fn field(&self) -> Result; + fn field(&self) -> Result; /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default /// implementation returns placeholder text. diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index 28ee10eb650a..6c44c8fe86c5 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -53,7 +53,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { .as_any() .downcast_ref::() .expect("take_optimizable() ensures that this is a AggregateExec"); - let stats = partial_agg_exec.input().statistics()?; + let stats = partial_agg_exec.input().partition_statistics(None)?; let mut projections = vec![]; for expr in partial_agg_exec.aggr_expr() { let field = expr.field(); diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index b314b43c6a14..700b00c19dd5 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -42,7 +42,6 @@ use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ physical_exprs_equal, EquivalenceProperties, PhysicalExpr, PhysicalExprRef, }; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; @@ -950,11 +949,7 @@ fn add_spm_on_top(input: DistributionContext) -> DistributionContext { let new_plan = if should_preserve_ordering { Arc::new(SortPreservingMergeExec::new( - input - .plan - .output_ordering() - .unwrap_or(&LexOrdering::default()) - .clone(), + input.plan.output_ordering().cloned().unwrap_or_default(), Arc::clone(&input.plan), )) as _ } else { @@ -1018,7 +1013,7 @@ fn remove_dist_changing_operators( /// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", /// " DataSourceExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC], file_type=parquet", /// ``` -fn replace_order_preserving_variants( +pub fn replace_order_preserving_variants( mut context: DistributionContext, ) -> Result { context.children = context @@ -1035,7 +1030,9 @@ fn replace_order_preserving_variants( if is_sort_preserving_merge(&context.plan) { let child_plan = Arc::clone(&context.children[0].plan); - context.plan = Arc::new(CoalescePartitionsExec::new(child_plan)); + context.plan = Arc::new( + CoalescePartitionsExec::new(child_plan).with_fetch(context.plan.fetch()), + ); return Ok(context); } else if let Some(repartition) = context.plan.as_any().downcast_ref::() @@ -1112,7 +1109,8 @@ fn get_repartition_requirement_status( { // Decide whether adding a round robin is beneficial depending on // the statistical information we have on the number of rows: - let roundrobin_beneficial_stats = match child.statistics()?.num_rows { + let roundrobin_beneficial_stats = match child.partition_statistics(None)?.num_rows + { Precision::Exact(n_rows) => n_rows > batch_size, Precision::Inexact(n_rows) => !should_use_estimates || (n_rows > batch_size), Precision::Absent => true, @@ -1155,6 +1153,10 @@ fn get_repartition_requirement_status( /// operators to satisfy distribution requirements. Since this function /// takes care of such requirements, we should avoid manually adding data /// exchange operators in other places. 
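Several call sites above migrate from `statistics()` to `partition_statistics(None)`. My reading of the new signature, based only on the hunks shown here, is that `Some(i)` asks a node for the statistics of a single output partition while `None` asks for statistics across all partitions. A dependency-free toy model of that contract (the `Node`, `Source`, and `Stats` names are illustrative, not the DataFusion API):

```rust
#[derive(Clone, Debug, Default)]
struct Stats {
    num_rows: Option<usize>,
}

trait Node {
    /// `Some(i)`: stats for partition `i`; `None`: stats for the whole node.
    fn partition_statistics(&self, partition: Option<usize>) -> Stats;
}

struct Source {
    per_partition: Vec<Stats>,
}

impl Node for Source {
    fn partition_statistics(&self, partition: Option<usize>) -> Stats {
        match partition {
            Some(i) => self.per_partition.get(i).cloned().unwrap_or_default(),
            None => Stats {
                // Combine row counts; unknown in any partition means unknown overall.
                num_rows: self
                    .per_partition
                    .iter()
                    .try_fold(0usize, |acc, s| s.num_rows.map(|n| acc + n)),
            },
        }
    }
}

fn main() {
    let src = Source {
        per_partition: vec![Stats { num_rows: Some(10) }, Stats { num_rows: Some(5) }],
    };
    assert_eq!(src.partition_statistics(None).num_rows, Some(15));
    assert_eq!(src.partition_statistics(Some(1)).num_rows, Some(5));
}
```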
+/// +/// This function is intended to be used in a bottom up traversal, as it +/// can first repartition (or newly partition) at the datasources -- these +/// source partitions may be later repartitioned with additional data exchange operators. pub fn ensure_distribution( dist_context: DistributionContext, config: &ConfigOptions, @@ -1244,6 +1246,10 @@ pub fn ensure_distribution( // When `repartition_file_scans` is set, attempt to increase // parallelism at the source. + // + // If repartitioning is not possible (a.k.a. None is returned from `ExecutionPlan::repartitioned`) + // then no repartitioning will have occurred. As the default implementation returns None, it is only + // specific physical plan nodes, such as certain datasources, which are repartitioned. if repartition_file_scans && roundrobin_beneficial_stats { if let Some(new_child) = child.plan.repartitioned(target_partitions, config)? diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index 20733b65692f..37fec2eab3f9 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -400,6 +400,7 @@ pub fn parallelize_sorts( ), )) } else if is_coalesce_partitions(&requirements.plan) { + let fetch = requirements.plan.fetch(); // There is an unnecessary `CoalescePartitionsExec` in the plan. // This will handle the recursive `CoalescePartitionsExec` plans. requirements = remove_bottleneck_in_subplan(requirements)?; @@ -408,7 +409,10 @@ pub fn parallelize_sorts( Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( - Arc::new(CoalescePartitionsExec::new(Arc::clone(&requirements.plan))), + Arc::new( + CoalescePartitionsExec::new(Arc::clone(&requirements.plan)) + .with_fetch(fetch), + ), false, vec![requirements], ), @@ -501,7 +505,7 @@ fn analyze_immediate_sort_removal( sort_exec .properties() .output_ordering() - .unwrap_or(LexOrdering::empty()), + .unwrap_or_else(|| LexOrdering::empty()), ) { node.plan = if !sort_exec.preserve_partitioning() && sort_input.output_partitioning().partition_count() > 1 diff --git a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs index 2c5c0d4d510e..9769e2e0366f 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs @@ -27,8 +27,10 @@ use crate::utils::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; +use datafusion_physical_expr::LexOrdering; +use datafusion_physical_plan::internal_err; + use datafusion_common::Result; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::repartition::RepartitionExec; @@ -93,7 +95,7 @@ pub fn update_order_preservation_ctx_children_data(opc: &mut OrderPreservationCo /// inside `sort_input` with their order-preserving variants. This will /// generate an alternative plan, which will be accepted or rejected later on /// depending on whether it helps us remove a `SortExec`. 
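The fetch handling introduced in the hunk below (a `CoalescePartitionsExec` limit may only be replaced by a `SortPreservingMergeExec` limit that is at least as restrictive, and a coalesce-only limit must not be dropped) boils down to reconciling two optional fetch values. A small sketch of that rule under those assumptions; `merge_fetch` is an illustrative helper, not a DataFusion function:

```rust
/// Reconcile an optional coalesce fetch with an optional sort fetch.
fn merge_fetch(
    coalesce_fetch: Option<usize>,
    sort_fetch: Option<usize>,
) -> Result<Option<usize>, String> {
    match (coalesce_fetch, sort_fetch) {
        (Some(c), Some(s)) if c < s => Err(format!(
            "coalesce fetch {c} should be greater than or equal to sort fetch {s}"
        )),
        // The sort fetch is at least as tight, so it is the one to keep.
        (Some(_), Some(s)) => Ok(Some(s)),
        // No sort fetch: keep the coalesce node's fetch so no limit is lost.
        (Some(c), None) => Ok(Some(c)),
        (None, s) => Ok(s),
    }
}

fn main() {
    assert_eq!(merge_fetch(Some(10), Some(5)), Ok(Some(5)));
    assert_eq!(merge_fetch(Some(10), None), Ok(Some(10)));
    assert!(merge_fetch(Some(3), Some(5)).is_err());
}
```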
-fn plan_with_order_preserving_variants( +pub fn plan_with_order_preserving_variants( mut sort_input: OrderPreservationContext, // Flag indicating that it is desirable to replace `RepartitionExec`s with // `SortPreservingRepartitionExec`s: @@ -138,6 +140,19 @@ fn plan_with_order_preserving_variants( } else if is_coalesce_partitions(&sort_input.plan) && is_spm_better { let child = &sort_input.children[0].plan; if let Some(ordering) = child.output_ordering() { + let mut fetch = fetch; + if let Some(coalesce_fetch) = sort_input.plan.fetch() { + if let Some(sort_fetch) = fetch { + if coalesce_fetch < sort_fetch { + return internal_err!( + "CoalescePartitionsExec fetch [{:?}] should be greater than or equal to SortExec fetch [{:?}]", coalesce_fetch, sort_fetch + ); + } + } else { + // If the sort node does not have a fetch, we need to keep the coalesce node's fetch. + fetch = Some(coalesce_fetch); + } + }; // When the input of a `CoalescePartitionsExec` has an ordering, // replace it with a `SortPreservingMergeExec` if appropriate: let spm = SortPreservingMergeExec::new(ordering.clone(), Arc::clone(child)) @@ -154,7 +169,7 @@ fn plan_with_order_preserving_variants( /// Calculates the updated plan by replacing operators that preserve ordering /// inside `sort_input` with their order-breaking variants. This will restore /// the original plan modified by [`plan_with_order_preserving_variants`]. -fn plan_with_order_breaking_variants( +pub fn plan_with_order_breaking_variants( mut sort_input: OrderPreservationContext, ) -> Result { let plan = &sort_input.plan; @@ -189,10 +204,12 @@ fn plan_with_order_breaking_variants( let partitioning = plan.output_partitioning().clone(); sort_input.plan = Arc::new(RepartitionExec::try_new(child, partitioning)?) as _; } else if is_sort_preserving_merge(plan) { - // Replace `SortPreservingMergeExec` with a `CoalescePartitionsExec`: + // Replace `SortPreservingMergeExec` with a `CoalescePartitionsExec` + // SPM may have `fetch`, so pass it to the `CoalescePartitionsExec` let child = Arc::clone(&sort_input.children[0].plan); - let coalesce = CoalescePartitionsExec::new(child); - sort_input.plan = Arc::new(coalesce) as _; + let coalesce = + Arc::new(CoalescePartitionsExec::new(child).with_fetch(plan.fetch())); + sort_input.plan = coalesce; } else { return sort_input.update_plan_from_children(); } @@ -271,7 +288,7 @@ pub fn replace_with_order_preserving_variants( requirements .plan .output_ordering() - .unwrap_or(LexOrdering::empty()), + .unwrap_or_else(|| LexOrdering::empty()), ) { for child in alternate_plan.children.iter_mut() { diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index 2e20608d0e9e..6d2c014f9e7c 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -233,7 +233,7 @@ fn pushdown_requirement_to_children( .properties() .output_ordering() .cloned() - .unwrap_or(LexOrdering::default()), + .unwrap_or_else(LexOrdering::default), ); if sort_exec .properties() @@ -258,7 +258,7 @@ fn pushdown_requirement_to_children( plan.properties() .output_ordering() .cloned() - .unwrap_or(LexOrdering::default()), + .unwrap_or_else(LexOrdering::default), ); // Push down through operator with fetch when: // - requirement is aligned with output ordering diff --git a/datafusion/physical-optimizer/src/filter_pushdown.rs b/datafusion/physical-optimizer/src/filter_pushdown.rs new file 
mode 100644 index 000000000000..5b2d47106b8d --- /dev/null +++ b/datafusion/physical-optimizer/src/filter_pushdown.rs @@ -0,0 +1,541 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::PhysicalOptimizerRule; + +use datafusion_common::{config::ConfigOptions, Result}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_plan::filter_pushdown::{ + ChildPushdownResult, FilterPushdownPropagation, PredicateSupport, PredicateSupports, +}; +use datafusion_physical_plan::{with_new_children_if_necessary, ExecutionPlan}; + +use itertools::izip; + +/// Attempts to recursively push given filters from the top of the tree into leafs. +/// +/// # Default Implementation +/// +/// The default implementation in [`ExecutionPlan::gather_filters_for_pushdown`] +/// and [`ExecutionPlan::handle_child_pushdown_result`] assumes that: +/// +/// * Parent filters can't be passed onto children (determined by [`ExecutionPlan::gather_filters_for_pushdown`]) +/// * This node has no filters to contribute (determined by [`ExecutionPlan::gather_filters_for_pushdown`]). +/// * Any filters that could not be pushed down to the children are marked as unsupported (determined by [`ExecutionPlan::handle_child_pushdown_result`]). +/// +/// # Example: Push filter into a `DataSourceExec` +/// +/// For example, consider the following plan: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [ id=1] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// Our goal is to move the `id = 1` filter from the [`FilterExec`] node to the `DataSourceExec` node. +/// +/// If this filter is selective pushing it into the scan can avoid massive +/// amounts of data being read from the source (the projection is `*` so all +/// matching columns are read). +/// +/// The new plan looks like: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [ id=1] │ +/// └──────────────────────┘ +/// ``` +/// +/// # Example: Push filters with `ProjectionExec` +/// +/// Let's consider a more complex example involving a [`ProjectionExec`] +/// node in between the [`FilterExec`] and `DataSourceExec` nodes that +/// creates a new column that the filter depends on. 
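Before the next example, a note on the bookkeeping the rule performs: each node classifies every parent predicate as one it can absorb (pushed further down) or one it cannot (left in a `FilterExec` above). A dependency-free sketch of that split, using a toy enum in place of the real `PredicateSupport` type from `datafusion_physical_plan::filter_pushdown`:

```rust
/// Toy stand-in for the supported/unsupported classification used by the rule.
#[derive(Debug, PartialEq)]
enum Support<P> {
    Supported(P),
    Unsupported(P),
}

/// Split parent predicates into the ones a child absorbed and the ones that
/// must remain in a filter above it.
fn split_results<P>(results: Vec<Support<P>>) -> (Vec<P>, Vec<P>) {
    let mut pushed = vec![];
    let mut kept = vec![];
    for r in results {
        match r {
            Support::Supported(p) => pushed.push(p),
            Support::Unsupported(p) => kept.push(p),
        }
    }
    (pushed, kept)
}

fn main() {
    // e.g. a scan can absorb `id = 1` but not `sum > 10`
    let (pushed, kept) = split_results(vec![
        Support::Supported("id = 1"),
        Support::Unsupported("sum > 10"),
    ]);
    assert_eq!(pushed, vec!["id = 1"]);
    assert_eq!(kept, vec!["sum > 10"]);
}
```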
+/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = │ +/// │ [cost>50,id=1] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ cost = price * 1.2 │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// We want to push down the filters `[id=1]` to the `DataSourceExec` node, +/// but can't push down `cost>50` because it requires the [`ProjectionExec`] +/// node to be executed first. A simple thing to do would be to split up the +/// filter into two separate filters and push down the first one: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = │ +/// │ [cost>50] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ cost = price * 1.2 │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [ id=1] │ +/// └──────────────────────┘ +/// ``` +/// +/// We can actually however do better by pushing down `price * 1.2 > 50` +/// instead of `cost > 50`: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ cost = price * 1.2 │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [id=1, │ +/// │ price * 1.2 > 50] │ +/// └──────────────────────┘ +/// ``` +/// +/// # Example: Push filters within a subtree +/// +/// There are also cases where we may be able to push down filters within a +/// subtree but not the entire tree. A good example of this is aggregation +/// nodes: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [sum > 10] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌───────────────────────┐ +/// │ AggregateExec │ +/// │ group by = [id] │ +/// │ aggregate = │ +/// │ [sum(price)] │ +/// └───────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [id=1] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// The transformation here is to push down the `id=1` filter to the +/// `DataSourceExec` node: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [sum > 10] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌───────────────────────┐ +/// │ AggregateExec │ +/// │ group by = [id] │ +/// │ aggregate = │ +/// │ [sum(price)] │ +/// └───────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [id=1] │ +/// └──────────────────────┘ +/// ``` +/// +/// The point here is that: +/// 1. 
We cannot push down `sum > 10` through the [`AggregateExec`] node into the `DataSourceExec` node. +/// Any filters above the [`AggregateExec`] node are not pushed down. +/// This is determined by calling [`ExecutionPlan::gather_filters_for_pushdown`] on the [`AggregateExec`] node. +/// 2. We need to keep recursing into the tree so that we can discover the other [`FilterExec`] node and push +/// down the `id=1` filter. +/// +/// # Example: Push filters through Joins +/// +/// It is also possible to push down filters through joins and filters that +/// originate from joins. For example, a hash join where we build a hash +/// table of the left side and probe the right side (ignoring why we would +/// choose this order, typically it depends on the size of each table, +/// etc.). +/// +/// ```text +/// ┌─────────────────────┐ +/// │ FilterExec │ +/// │ filters = │ +/// │ [d.size > 100] │ +/// └─────────────────────┘ +/// │ +/// │ +/// ┌──────────▼──────────┐ +/// │ │ +/// │ HashJoinExec │ +/// │ [u.dept@hash(d.id)] │ +/// │ │ +/// └─────────────────────┘ +/// │ +/// ┌────────────┴────────────┐ +/// ┌──────────▼──────────┐ ┌──────────▼──────────┐ +/// │ DataSourceExec │ │ DataSourceExec │ +/// │ alias [users as u] │ │ alias [dept as d] │ +/// │ │ │ │ +/// └─────────────────────┘ └─────────────────────┘ +/// ``` +/// +/// There are two pushdowns we can do here: +/// 1. Push down the `d.size > 100` filter through the `HashJoinExec` node to the `DataSourceExec` +/// node for the `departments` table. +/// 2. Push down the hash table state from the `HashJoinExec` node to the `DataSourceExec` node to avoid reading +/// rows from the `users` table that will be eliminated by the join. +/// This can be done via a bloom filter or similar and is not (yet) supported +/// in DataFusion. See . +/// +/// ```text +/// ┌─────────────────────┐ +/// │ │ +/// │ HashJoinExec │ +/// │ [u.dept@hash(d.id)] │ +/// │ │ +/// └─────────────────────┘ +/// │ +/// ┌────────────┴────────────┐ +/// ┌──────────▼──────────┐ ┌──────────▼──────────┐ +/// │ DataSourceExec │ │ DataSourceExec │ +/// │ alias [users as u] │ │ alias [dept as d] │ +/// │ filters = │ │ filters = │ +/// │ [depg@hash(d.id)] │ │ [ d.size > 100] │ +/// └─────────────────────┘ └─────────────────────┘ +/// ``` +/// +/// You may notice in this case that the filter is *dynamic*: the hash table +/// is built _after_ the `departments` table is read and at runtime. We +/// don't have a concrete `InList` filter or similar to push down at +/// optimization time. These sorts of dynamic filters are handled by +/// building a specialized [`PhysicalExpr`] that can be evaluated at runtime +/// and internally maintains a reference to the hash table or other state. +/// +/// To make working with these sorts of dynamic filters more tractable we have the method [`PhysicalExpr::snapshot`] +/// which attempts to simplify a dynamic filter into a "basic" non-dynamic filter. +/// For a join this could mean converting it to an `InList` filter or a min/max filter for example. +/// See `datafusion/physical-plan/src/dynamic_filters.rs` for more details. 
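To make the "dynamic filter" idea above concrete: such a filter is a predicate whose state is only filled in at runtime (by a join build side, a TopK heap, and so on), and `snapshot` is described as deriving a plain filter from that state when possible. A std-only toy model under those assumptions; none of the names below are the DataFusion implementation:

```rust
use std::sync::{Arc, Mutex};

/// A predicate on `id` whose upper bound is only known once some operator
/// (e.g. a TopK heap or a join build side) has observed data.
#[derive(Clone, Default)]
struct DynamicUpperBound {
    bound: Arc<Mutex<Option<i64>>>,
}

impl DynamicUpperBound {
    /// Called by the producing operator as it learns a tighter bound.
    fn update(&self, new_bound: i64) {
        let mut guard = self.bound.lock().unwrap();
        *guard = Some(guard.map_or(new_bound, |b| b.min(new_bound)));
    }

    /// Evaluate the live predicate.
    fn matches(&self, id: i64) -> bool {
        self.bound.lock().unwrap().map_or(true, |b| id < b)
    }

    /// "Snapshot" the current state into a plain, non-dynamic filter, if one
    /// can be derived yet.
    fn snapshot(&self) -> Option<String> {
        self.bound.lock().unwrap().map(|b| format!("id < {b}"))
    }
}

fn main() {
    let filter = DynamicUpperBound::default();
    assert!(filter.matches(100)); // no state yet: everything passes
    filter.update(42);
    assert!(!filter.matches(100));
    assert_eq!(filter.snapshot().as_deref(), Some("id < 42"));
}
```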
+/// +/// # Example: Push TopK filters into Scans +/// +/// Another form of dynamic filter is pushing down the state of a `TopK` +/// operator for queries like `SELECT * FROM t ORDER BY id LIMIT 10`: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ TopK │ +/// │ limit = 10 │ +/// │ order by = [id] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// We can avoid large amounts of data processing by transforming this into: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ TopK │ +/// │ limit = 10 │ +/// │ order by = [id] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = │ +/// │ [id < @ TopKHeap] │ +/// └──────────────────────┘ +/// ``` +/// +/// Now as we fill our `TopK` heap we can push down the state of the heap to +/// the `DataSourceExec` node to avoid reading files / row groups / pages / +/// rows that could not possibly be in the top 10. +/// +/// This is not yet implemented in DataFusion. See +/// +/// +/// [`PhysicalExpr`]: datafusion_physical_plan::PhysicalExpr +/// [`PhysicalExpr::snapshot`]: datafusion_physical_plan::PhysicalExpr::snapshot +/// [`FilterExec`]: datafusion_physical_plan::filter::FilterExec +/// [`ProjectionExec`]: datafusion_physical_plan::projection::ProjectionExec +/// [`AggregateExec`]: datafusion_physical_plan::aggregates::AggregateExec +#[derive(Debug)] +pub struct FilterPushdown {} + +impl FilterPushdown { + pub fn new() -> Self { + Self {} + } +} + +impl Default for FilterPushdown { + fn default() -> Self { + Self::new() + } +} + +impl PhysicalOptimizerRule for FilterPushdown { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + Ok(push_down_filters(Arc::clone(&plan), vec![], config)? + .updated_node + .unwrap_or(plan)) + } + + fn name(&self) -> &str { + "FilterPushdown" + } + + fn schema_check(&self) -> bool { + true // Filter pushdown does not change the schema of the plan + } +} + +/// Support state of each predicate for the children of the node. +/// These predicates are coming from the parent node. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ParentPredicateStates { + NoChildren, + Unsupported, + Supported, +} + +fn push_down_filters( + node: Arc, + parent_predicates: Vec>, + config: &ConfigOptions, +) -> Result>> { + // If the node has any child, these will be rewritten as supported or unsupported + let mut parent_predicates_pushdown_states = + vec![ParentPredicateStates::NoChildren; parent_predicates.len()]; + let mut self_filters_pushdown_supports = vec![]; + let mut new_children = Vec::with_capacity(node.children().len()); + + let children = node.children(); + let filter_description = + node.gather_filters_for_pushdown(parent_predicates.clone(), config)?; + + for (child, parent_filters, self_filters) in izip!( + children, + filter_description.parent_filters(), + filter_description.self_filters() + ) { + // Here, `parent_filters` are the predicates which are provided by the parent node of + // the current node, and tried to be pushed down over the child which the loop points + // currently. `self_filters` are the predicates which are provided by the current node, + // and tried to be pushed down over the child similarly. 
+
+        let num_self_filters = self_filters.len();
+        let mut parent_supported_predicate_indices = vec![];
+        let mut all_predicates = self_filters;
+
+        // Iterate over each predicate coming from the parent
+        for (idx, filter) in parent_filters.into_iter().enumerate() {
+            // Check if we can push this filter down to our child.
+            // These supports are defined in `gather_filters_for_pushdown()`
+            match filter {
+                PredicateSupport::Supported(predicate) => {
+                    // Queue this filter up for pushdown to this child
+                    all_predicates.push(predicate);
+                    parent_supported_predicate_indices.push(idx);
+                    // Mark this filter as supported by our children unless a previous
+                    // child has already marked it as unsupported
+                    if parent_predicates_pushdown_states[idx]
+                        != ParentPredicateStates::Unsupported
+                    {
+                        parent_predicates_pushdown_states[idx] =
+                            ParentPredicateStates::Supported;
+                    }
+                }
+                PredicateSupport::Unsupported(_) => {
+                    // Mark as unsupported by our children
+                    parent_predicates_pushdown_states[idx] =
+                        ParentPredicateStates::Unsupported;
+                }
+            }
+        }
+
+        // Recurse into this child; any filters it cannot absorb are reported back
+        // to us as unsupported
+        let result = push_down_filters(Arc::clone(child), all_predicates, config)?;
+
+        if let Some(new_child) = result.updated_node {
+            // The pushdown rewrote this child, so adopt the new node
+            new_children.push(new_child);
+        } else {
+            // The pushdown did not change this child, so keep the original node
+            new_children.push(Arc::clone(child));
+        }
+
+        // Our child doesn't know the difference between filters that were passed down
+        // from our parents and filters that the current node injected. We need to
+        // disentangle this since we do need to distinguish between them.
+        let mut all_filters = result.filters.into_inner();
+        let parent_predicates = all_filters.split_off(num_self_filters);
+        let self_predicates = all_filters;
+        self_filters_pushdown_supports.push(PredicateSupports::new(self_predicates));
+
+        for (idx, result) in parent_supported_predicate_indices
+            .iter()
+            .zip(parent_predicates)
+        {
+            let current_node_state = match result {
+                PredicateSupport::Supported(_) => ParentPredicateStates::Supported,
+                PredicateSupport::Unsupported(_) => ParentPredicateStates::Unsupported,
+            };
+            match (current_node_state, parent_predicates_pushdown_states[*idx]) {
+                (r, ParentPredicateStates::NoChildren) => {
+                    // No child has reported on this filter yet; take this child's state
+                    parent_predicates_pushdown_states[*idx] = r;
+                }
+                (ParentPredicateStates::Supported, ParentPredicateStates::Supported) => {
+                    // The current child and all previous children support this filter,
+                    // so it remains supported
+                    parent_predicates_pushdown_states[*idx] =
+                        ParentPredicateStates::Supported;
+                }
+                _ => {
+                    // Either the current child or a previous child marked this filter as unsupported
+                    parent_predicates_pushdown_states[*idx] =
+                        ParentPredicateStates::Unsupported;
+                }
+            }
+        }
+    }
+    // Re-create this node with new children
+    let updated_node = with_new_children_if_necessary(Arc::clone(&node), new_children)?;
+    // Remap the result onto the parent filters as they were given to us.
+    // Any filters that were not pushed down to any children are marked as unsupported.
+ let parent_pushdown_result = PredicateSupports::new( + parent_predicates_pushdown_states + .into_iter() + .zip(parent_predicates) + .map(|(state, filter)| match state { + ParentPredicateStates::NoChildren => { + PredicateSupport::Unsupported(filter) + } + ParentPredicateStates::Unsupported => { + PredicateSupport::Unsupported(filter) + } + ParentPredicateStates::Supported => PredicateSupport::Supported(filter), + }) + .collect(), + ); + // Check what the current node wants to do given the result of pushdown to it's children + let mut res = updated_node.handle_child_pushdown_result( + ChildPushdownResult { + parent_filters: parent_pushdown_result, + self_filters: self_filters_pushdown_supports, + }, + config, + )?; + // Compare pointers for new_node and node, if they are different we must replace + // ourselves because of changes in our children. + if res.updated_node.is_none() && !Arc::ptr_eq(&updated_node, &node) { + res.updated_node = Some(updated_node) + } + Ok(res) +} diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index 5a772ccdd249..05758e5dfdf1 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -65,8 +65,8 @@ pub(crate) fn should_swap_join_order( // Get the left and right table's total bytes // If both the left and right tables contain total_byte_size statistics, // use `total_byte_size` to determine `should_swap_join_order`, else use `num_rows` - let left_stats = left.statistics()?; - let right_stats = right.statistics()?; + let left_stats = left.partition_statistics(None)?; + let right_stats = right.partition_statistics(None)?; // First compare `total_byte_size` of left and right side, // if information in this field is insufficient fallback to the `num_rows` match ( @@ -91,7 +91,7 @@ fn supports_collect_by_thresholds( ) -> bool { // Currently we do not trust the 0 value from stats, due to stats collection might have bug // TODO check the logic in datasource::get_statistics_with_limit() - let Ok(stats) = plan.statistics() else { + let Ok(stats) = plan.partition_statistics(None) else { return false; }; diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index 35503f3b0b5f..5a43d7118d63 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -29,6 +29,7 @@ pub mod coalesce_batches; pub mod combine_partial_final_agg; pub mod enforce_distribution; pub mod enforce_sorting; +pub mod filter_pushdown; pub mod join_selection; pub mod limit_pushdown; pub mod limited_distinct_aggregation; diff --git a/datafusion/physical-optimizer/src/limit_pushdown.rs b/datafusion/physical-optimizer/src/limit_pushdown.rs index 5887cb51a727..7469c3af9344 100644 --- a/datafusion/physical-optimizer/src/limit_pushdown.rs +++ b/datafusion/physical-optimizer/src/limit_pushdown.rs @@ -246,16 +246,7 @@ pub fn pushdown_limit_helper( Ok((Transformed::no(pushdown_plan), global_state)) } } else { - // Add fetch or a `LimitExec`: - // If the plan's children have limit and the child's limit < parent's limit, we shouldn't change the global state to true, - // because the children limit will be overridden if the global state is changed. 
- if !pushdown_plan - .children() - .iter() - .any(|&child| extract_limit(child).is_some()) - { - global_state.satisfied = true; - } + global_state.satisfied = true; pushdown_plan = if let Some(plan_with_fetch) = maybe_fetchable { if global_skip > 0 { add_global_limit(plan_with_fetch, global_skip, Some(global_fetch)) diff --git a/datafusion/physical-optimizer/src/optimizer.rs b/datafusion/physical-optimizer/src/optimizer.rs index bab31150e250..432ac35ebc23 100644 --- a/datafusion/physical-optimizer/src/optimizer.rs +++ b/datafusion/physical-optimizer/src/optimizer.rs @@ -25,6 +25,7 @@ use crate::coalesce_batches::CoalesceBatches; use crate::combine_partial_final_agg::CombinePartialFinalAggregate; use crate::enforce_distribution::EnforceDistribution; use crate::enforce_sorting::EnforceSorting; +use crate::filter_pushdown::FilterPushdown; use crate::join_selection::JoinSelection; use crate::limit_pushdown::LimitPushdown; use crate::limited_distinct_aggregation::LimitedDistinctAggregation; @@ -94,6 +95,10 @@ impl PhysicalOptimizer { // as that rule may inject other operations in between the different AggregateExecs. // Applying the rule early means only directly-connected AggregateExecs must be examined. Arc::new(LimitedDistinctAggregation::new()), + // The FilterPushdown rule tries to push down filters as far as it can. + // For example, it will push down filtering from a `FilterExec` to + // a `DataSourceExec`, or from a `TopK`'s current state to a `DataSourceExec`. + Arc::new(FilterPushdown::new()), // The EnforceDistribution rule is for adding essential repartitioning to satisfy distribution // requirements. Please make sure that the whole plan tree is determined before this rule. // This rule increases parallelism if doing so is beneficial to the physical plan; i.e. at diff --git a/datafusion/physical-optimizer/src/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs index 3ca0547aa11d..0488b3fd49a8 100644 --- a/datafusion/physical-optimizer/src/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -200,7 +200,11 @@ impl ExecutionPlan for OutputRequirementExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition) } fn try_swapping_with_projection( diff --git a/datafusion/physical-optimizer/src/pruning.rs b/datafusion/physical-optimizer/src/pruning.rs index 1dd168f18167..e2378b5f42df 100644 --- a/datafusion/physical-optimizer/src/pruning.rs +++ b/datafusion/physical-optimizer/src/pruning.rs @@ -28,7 +28,9 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::{RecordBatch, RecordBatchOptions}, }; -use log::trace; +// pub use for backwards compatibility +pub use datafusion_common::pruning::PruningStatistics; +use log::{debug, trace}; use datafusion_common::error::{DataFusionError, Result}; use datafusion_common::tree_node::TransformedResult; @@ -44,106 +46,6 @@ use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; -/// A source of runtime statistical information to [`PruningPredicate`]s. -/// -/// # Supported Information -/// -/// 1. Minimum and maximum values for columns -/// -/// 2. Null counts and row counts for columns -/// -/// 3. 
Whether the values in a column are contained in a set of literals -/// -/// # Vectorized Interface -/// -/// Information for containers / files are returned as Arrow [`ArrayRef`], so -/// the evaluation happens once on a single `RecordBatch`, which amortizes the -/// overhead of evaluating the predicate. This is important when pruning 1000s -/// of containers which often happens in analytic systems that have 1000s of -/// potential files to consider. -/// -/// For example, for the following three files with a single column `a`: -/// ```text -/// file1: column a: min=5, max=10 -/// file2: column a: No stats -/// file2: column a: min=20, max=30 -/// ``` -/// -/// PruningStatistics would return: -/// -/// ```text -/// min_values("a") -> Some([5, Null, 20]) -/// max_values("a") -> Some([10, Null, 30]) -/// min_values("X") -> None -/// ``` -pub trait PruningStatistics { - /// Return the minimum values for the named column, if known. - /// - /// If the minimum value for a particular container is not known, the - /// returned array should have `null` in that row. If the minimum value is - /// not known for any row, return `None`. - /// - /// Note: the returned array must contain [`Self::num_containers`] rows - fn min_values(&self, column: &Column) -> Option; - - /// Return the maximum values for the named column, if known. - /// - /// See [`Self::min_values`] for when to return `None` and null values. - /// - /// Note: the returned array must contain [`Self::num_containers`] rows - fn max_values(&self, column: &Column) -> Option; - - /// Return the number of containers (e.g. Row Groups) being pruned with - /// these statistics. - /// - /// This value corresponds to the size of the [`ArrayRef`] returned by - /// [`Self::min_values`], [`Self::max_values`], [`Self::null_counts`], - /// and [`Self::row_counts`]. - fn num_containers(&self) -> usize; - - /// Return the number of null values for the named column as an - /// [`UInt64Array`] - /// - /// See [`Self::min_values`] for when to return `None` and null values. - /// - /// Note: the returned array must contain [`Self::num_containers`] rows - /// - /// [`UInt64Array`]: arrow::array::UInt64Array - fn null_counts(&self, column: &Column) -> Option; - - /// Return the number of rows for the named column in each container - /// as an [`UInt64Array`]. - /// - /// See [`Self::min_values`] for when to return `None` and null values. - /// - /// Note: the returned array must contain [`Self::num_containers`] rows - /// - /// [`UInt64Array`]: arrow::array::UInt64Array - fn row_counts(&self, column: &Column) -> Option; - - /// Returns [`BooleanArray`] where each row represents information known - /// about specific literal `values` in a column. - /// - /// For example, Parquet Bloom Filters implement this API to communicate - /// that `values` are known not to be present in a Row Group. - /// - /// The returned array has one row for each container, with the following - /// meanings: - /// * `true` if the values in `column` ONLY contain values from `values` - /// * `false` if the values in `column` are NOT ANY of `values` - /// * `null` if the neither of the above holds or is unknown. - /// - /// If these statistics can not determine column membership for any - /// container, return `None` (the default). 
- /// - /// Note: the returned array must contain [`Self::num_containers`] rows - fn contained( - &self, - column: &Column, - values: &HashSet, - ) -> Option; -} - /// Used to prove that arbitrary predicates (boolean expression) can not /// possibly evaluate to `true` given information about a column provided by /// [`PruningStatistics`]. @@ -751,6 +653,13 @@ fn is_always_true(expr: &Arc) -> bool { .unwrap_or_default() } +fn is_always_false(expr: &Arc) -> bool { + expr.as_any() + .downcast_ref::() + .map(|l| matches!(l.value(), ScalarValue::Boolean(Some(false)))) + .unwrap_or_default() +} + /// Describes which columns statistics are necessary to evaluate a /// [`PruningPredicate`]. /// @@ -984,11 +893,7 @@ fn build_statistics_record_batch( let mut options = RecordBatchOptions::default(); options.row_count = Some(statistics.num_containers()); - trace!( - "Creating statistics batch for {:#?} with {:#?}", - required_columns, - arrays - ); + trace!("Creating statistics batch for {required_columns:#?} with {arrays:#?}"); RecordBatch::try_new_with_options(schema, arrays, &options).map_err(|err| { plan_datafusion_err!("Can not create statistics record batch: {err}") @@ -1210,23 +1115,35 @@ fn is_compare_op(op: Operator) -> bool { ) } +fn is_string_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) +} + // The pruning logic is based on the comparing the min/max bounds. // Must make sure the two type has order. // For example, casts from string to numbers is not correct. // Because the "13" is less than "3" with UTF8 comparison order. fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Result<()> { - // TODO: support other data type for prunable cast or try cast - if matches!( - from_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Decimal128(_, _) - ) && matches!( - to_type, - DataType::Int8 | DataType::Int32 | DataType::Int64 | DataType::Decimal128(_, _) - ) { + // Dictionary casts are always supported as long as the value types are supported + let from_type = match from_type { + DataType::Dictionary(_, t) => { + return verify_support_type_for_prune(t.as_ref(), to_type) + } + _ => from_type, + }; + let to_type = match to_type { + DataType::Dictionary(_, t) => { + return verify_support_type_for_prune(from_type, t.as_ref()) + } + _ => to_type, + }; + // If both types are strings or both are not strings (number, timestamp, etc) + // then we can compare them. + // PruningPredicate does not support casting of strings to numbers and such. + if is_string_type(from_type) == is_string_type(to_type) { Ok(()) } else { plan_err!( @@ -1427,6 +1344,11 @@ fn build_predicate_expression( required_columns: &mut RequiredColumns, unhandled_hook: &Arc, ) -> Arc { + if is_always_false(expr) { + // Shouldn't return `unhandled_hook.handle(expr)` + // Because it will transfer false to true. 
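The reworked `verify_support_type_for_prune` above replaces the integer-only allow-list with two rules: look through dictionary types to their value types, then accept the cast only when both sides are strings or both are non-strings, since string/number casts do not preserve min/max ordering (`"13" < "3"` under UTF-8 comparison). A condensed, simplified mirror of that check using only `arrow`'s `DataType` (an illustration, not the exact DataFusion function):

```rust
use arrow::datatypes::DataType;

fn is_string_type(dt: &DataType) -> bool {
    matches!(dt, DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View)
}

/// Does min/max pruning survive a cast from `from` to `to`?
fn prunable_cast(from: &DataType, to: &DataType) -> bool {
    // Dictionary types are transparent: look through to the value type.
    if let DataType::Dictionary(_, value) = from {
        return prunable_cast(value, to);
    }
    if let DataType::Dictionary(_, value) = to {
        return prunable_cast(from, value);
    }
    // Ordering is only preserved when both sides are strings or both are not.
    is_string_type(from) == is_string_type(to)
}

fn main() {
    assert!(prunable_cast(&DataType::Int32, &DataType::Int64));
    assert!(!prunable_cast(&DataType::Utf8, &DataType::Int32));
    let dict = DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8));
    assert!(prunable_cast(&dict, &DataType::Utf8View));
}
```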
+ return Arc::clone(expr); + } // predicate expression can only be a binary expression let expr_any = expr.as_any(); if let Some(is_null) = expr_any.downcast_ref::() { @@ -1526,6 +1448,11 @@ fn build_predicate_expression( build_predicate_expression(&right, schema, required_columns, unhandled_hook); // simplify boolean expression if applicable let expr = match (&left_expr, op, &right_expr) { + (left, Operator::And, right) + if is_always_false(left) || is_always_false(right) => + { + Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(false)))) + } (left, Operator::And, _) if is_always_true(left) => right_expr, (_, Operator::And, right) if is_always_true(right) => left_expr, (left, Operator::Or, right) @@ -1533,6 +1460,9 @@ fn build_predicate_expression( { Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) } + (left, Operator::Or, _) if is_always_false(left) => right_expr, + (_, Operator::Or, right) if is_always_false(right) => left_expr, + _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)), }; return expr; @@ -1544,7 +1474,10 @@ fn build_predicate_expression( Ok(builder) => builder, // allow partial failure in predicate expression generation // this can still produce a useful predicate when multiple conditions are joined using AND - Err(_) => return unhandled_hook.handle(expr), + Err(e) => { + debug!("Error building pruning expression: {e}"); + return unhandled_hook.handle(expr); + } }; build_statistics_expr(&mut expr_builder) @@ -1889,7 +1822,7 @@ mod tests { use super::*; use datafusion_common::test_util::batches_to_string; - use datafusion_expr::{col, lit}; + use datafusion_expr::{and, col, lit, or}; use insta::assert_snapshot; use arrow::array::Decimal128Array; @@ -2305,8 +2238,7 @@ mod tests { let was_new = fields.insert(field); if !was_new { panic!( - "Duplicate field in required schema: {:?}. Previous fields:\n{:#?}", - field, fields + "Duplicate field in required schema: {field:?}. 
Previous fields:\n{fields:#?}" ); } } @@ -2811,8 +2743,8 @@ mod tests { let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut required_columns); assert_eq!(predicate_expr.to_string(), expected_expr); - println!("required_columns: {:#?}", required_columns); // for debugging assertions below - // c1 < 1 should add c1_min + println!("required_columns: {required_columns:#?}"); // for debugging assertions below + // c1 < 1 should add c1_min let c1_min_field = Field::new("c1_min", DataType::Int32, false); assert_eq!( required_columns.columns[0], @@ -3006,7 +2938,7 @@ mod tests { } #[test] - fn row_group_predicate_cast() -> Result<()> { + fn row_group_predicate_cast_int_int() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)"; @@ -3043,6 +2975,291 @@ mod tests { Ok(()) } + #[test] + fn row_group_predicate_cast_string_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Utf8) <= 1 AND 1 <= CAST(c1_max@1 AS Utf8)"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("1".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("1".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_cast_string_int() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Int32).eq(lit(ScalarValue::Int32(Some(1)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Int32(Some(1))).eq(cast(col("c1"), DataType::Int32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_cast_int_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("1".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("1".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_date() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 
AS Date64) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Date64)"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date64).eq(lit(ScalarValue::Date64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date64(Some(123))).eq(cast(col("c1"), DataType::Date64)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_string_date() -> Result<()> { + // Test with Dictionary for the literal + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + ) + .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))).eq(cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + )); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_dict_string() -> Result<()> { + // Test with Dictionary for the column + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )]); + let expected_expr = "true"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_dict_same_value_type() -> Result<()> { + // Test with Dictionary types that have the same value type but different key types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )]); + + // Direct comparison with no cast + let expr = col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected_expr = + "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1"; + assert_eq!(predicate_expr.to_string(), expected_expr); + + // Test with column cast to a dictionary with different key type + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + ) + .eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected_expr = "c1_null_count@2 != 
row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Utf8)) <= test AND test <= CAST(c1_max@1 AS Dictionary(UInt16, Utf8))"; + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_dict_different_value_type() -> Result<()> { + // Test with Dictionary types that have different value types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Int32)), + false, + )]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 123 AND 123 <= CAST(c1_max@1 AS Int64)"; + + // Test with literal of a different type + let expr = + cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_nested_dict() -> Result<()> { + // Test with nested Dictionary types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary( + Box::new(DataType::UInt8), + Box::new(DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + )), + ), + false, + )]); + let expected_expr = + "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1"; + + // Test with a simple literal + let expr = col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_date_dict_date() -> Result<()> { + // Test with dictionary-wrapped date types for both sides + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Date32)), + false, + )]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Date64)) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Dictionary(UInt16, Date64))"; + + // Test with a cast to a different date type + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Date64)), + ) + .eq(lit(ScalarValue::Date64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_string_date() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))); + 
let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + #[test] fn row_group_predicate_cast_list() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); @@ -3285,12 +3502,10 @@ mod tests { prune_with_expr( // false - // constant literals that do NOT refer to any columns are currently not evaluated at all, hence the result is - // "all true" lit(false), &schema, &statistics, - &[true, true, true, true, true], + &[false, false, false, false, false], ); } @@ -4855,7 +5070,7 @@ mod tests { statistics: &TestStatistics, expected: &[bool], ) { - println!("Pruning with expr: {}", expr); + println!("Pruning with expr: {expr}"); let expr = logical2physical(&expr, schema); let p = PruningPredicate::try_new(expr, Arc::::clone(schema)).unwrap(); let result = p.prune(statistics).unwrap(); @@ -4871,4 +5086,42 @@ mod tests { let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; build_predicate_expression(&expr, schema, required_columns, &unhandled_hook) } + + #[test] + fn test_build_predicate_expression_with_false() { + let expr = lit(ScalarValue::Boolean(Some(false))); + let schema = Schema::empty(); + let res = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected = logical2physical(&expr, &schema); + assert_eq!(&res, &expected); + } + + #[test] + fn test_build_predicate_expression_with_and_false() { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let expr = and( + col("c1").eq(lit("a")), + lit(ScalarValue::Boolean(Some(false))), + ); + let res = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected = logical2physical(&lit(ScalarValue::Boolean(Some(false))), &schema); + assert_eq!(&res, &expected); + } + + #[test] + fn test_build_predicate_expression_with_or_false() { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let left_expr = col("c1").eq(lit("a")); + let right_expr = lit(ScalarValue::Boolean(Some(false))); + let res = test_build_predicate_expression( + &or(left_expr.clone(), right_expr.clone()), + &schema, + &mut RequiredColumns::new(), + ); + let expected = + "c1_null_count@2 != row_count@3 AND c1_min@0 <= a AND a <= c1_max@1"; + assert_eq!(res.to_string(), expected); + } } diff --git a/datafusion/physical-optimizer/src/update_aggr_exprs.rs b/datafusion/physical-optimizer/src/update_aggr_exprs.rs index 6228ed10ec34..ae1a38230d04 100644 --- a/datafusion/physical-optimizer/src/update_aggr_exprs.rs +++ b/datafusion/physical-optimizer/src/update_aggr_exprs.rs @@ -24,10 +24,11 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_datafusion_err, Result}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::LexRequirement; use datafusion_physical_expr::{ reverse_order_bys, EquivalenceProperties, PhysicalSortRequirement, }; -use datafusion_physical_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; 
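Editorial note: the `is_always_false` arms added to `build_predicate_expression` in the hunks above mirror the existing `is_always_true` short-circuits, so a conjunct that can never be true collapses the whole pruning predicate to `false`, and a `false` disjunct simply drops out; this is also why `prune_with_expr(lit(false), ...)` now expects every container to be pruned. The following is a minimal, self-contained sketch of those rewrite rules using a toy expression enum instead of DataFusion's `PhysicalExpr` types; the names here are illustrative only.

```rust
/// Toy model of the AND/OR constant folding applied while assembling a
/// pruning predicate. `Expr::Other` stands in for any non-literal predicate.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Lit(bool),
    Other(String),
    And(Box<Expr>, Box<Expr>),
    Or(Box<Expr>, Box<Expr>),
}

fn is_true(e: &Expr) -> bool {
    matches!(e, Expr::Lit(true))
}

fn is_false(e: &Expr) -> bool {
    matches!(e, Expr::Lit(false))
}

/// Apply the same simplifications as the pruning-expression builder:
/// `false AND x -> false`, `true AND x -> x`, `true OR x -> true`,
/// `false OR x -> x`; everything else stays a binary expression.
fn simplify(left: Expr, op: &str, right: Expr) -> Expr {
    match op {
        "AND" if is_false(&left) || is_false(&right) => Expr::Lit(false),
        "AND" if is_true(&left) => right,
        "AND" if is_true(&right) => left,
        "AND" => Expr::And(Box::new(left), Box::new(right)),
        "OR" if is_true(&left) || is_true(&right) => Expr::Lit(true),
        "OR" if is_false(&left) => right,
        "OR" if is_false(&right) => left,
        "OR" => Expr::Or(Box::new(left), Box::new(right)),
        _ => unreachable!("only AND/OR are simplified here"),
    }
}

fn main() {
    let pred = Expr::Other("c1_min <= 1".to_string());
    // `pred AND false` can never be true, so the row-group filter is false.
    assert_eq!(
        simplify(pred.clone(), "AND", Expr::Lit(false)),
        Expr::Lit(false)
    );
    // `pred OR false` reduces to just `pred`.
    assert_eq!(simplify(pred.clone(), "OR", Expr::Lit(false)), pred);
}
```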
use datafusion_physical_plan::aggregates::concat_slices; use datafusion_physical_plan::windows::get_ordered_partition_by_indices; use datafusion_physical_plan::{ @@ -159,7 +160,9 @@ fn try_convert_aggregate_if_better( aggr_exprs .into_iter() .map(|aggr_expr| { - let aggr_sort_exprs = aggr_expr.order_bys().unwrap_or(LexOrdering::empty()); + let aggr_sort_exprs = aggr_expr + .order_bys() + .unwrap_or_else(|| LexOrdering::empty()); let reverse_aggr_sort_exprs = reverse_order_bys(aggr_sort_exprs); let aggr_sort_reqs = LexRequirement::from(aggr_sort_exprs.clone()); let reverse_aggr_req = LexRequirement::from(reverse_aggr_sort_exprs); diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 5210ee26755c..4f58b575f3a0 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -36,6 +36,7 @@ workspace = true [features] force_hash_collisions = [] +bench = [] [lib] name = "datafusion_physical_plan" @@ -86,3 +87,8 @@ name = "partial_ordering" [[bench]] harness = false name = "spill_io" + +[[bench]] +harness = false +name = "sort_preserving_merge" +required-features = ["bench"] diff --git a/datafusion/physical-plan/benches/partial_ordering.rs b/datafusion/physical-plan/benches/partial_ordering.rs index 422826abcc8b..22d18dd24891 100644 --- a/datafusion/physical-plan/benches/partial_ordering.rs +++ b/datafusion/physical-plan/benches/partial_ordering.rs @@ -40,7 +40,7 @@ fn bench_new_groups(c: &mut Criterion) { // Test with 1, 2, 4, and 8 order indices for num_columns in [1, 2, 4, 8] { let fields: Vec = (0..num_columns) - .map(|i| Field::new(format!("col{}", i), DataType::Int32, false)) + .map(|i| Field::new(format!("col{i}"), DataType::Int32, false)) .collect(); let schema = Schema::new(fields); @@ -49,14 +49,14 @@ fn bench_new_groups(c: &mut Criterion) { (0..num_columns) .map(|i| { PhysicalSortExpr::new( - col(&format!("col{}", i), &schema).unwrap(), + col(&format!("col{i}"), &schema).unwrap(), SortOptions::default(), ) }) .collect(), ); - group.bench_function(format!("order_indices_{}", num_columns), |b| { + group.bench_function(format!("order_indices_{num_columns}"), |b| { let batch_group_values = create_test_arrays(num_columns); let group_indices: Vec = (0..BATCH_SIZE).collect(); diff --git a/datafusion/physical-plan/benches/sort_preserving_merge.rs b/datafusion/physical-plan/benches/sort_preserving_merge.rs new file mode 100644 index 000000000000..9586dbf94727 --- /dev/null +++ b/datafusion/physical-plan/benches/sort_preserving_merge.rs @@ -0,0 +1,202 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
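Editorial note: the benchmark introduced below is registered behind the new `bench` feature (see the `[[bench]]` entry with `required-features` above) and relies on Criterion's `iter_batched` so that building the input partitions is excluded from the measured time. For readers unfamiliar with that pattern, here is a small, hypothetical Criterion benchmark showing the setup/measurement split; the function names and workload are illustrative only and are not part of this patch.

```rust
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};

/// Hypothetical expensive setup that we do NOT want to measure.
fn build_input() -> Vec<u64> {
    (0..1_000_000u64).rev().collect()
}

/// The operation actually under test.
fn sort_input(mut v: Vec<u64>) -> Vec<u64> {
    v.sort_unstable();
    v
}

fn bench_sort(c: &mut Criterion) {
    c.bench_function("sort_1m_u64", |b| {
        // `iter_batched` regenerates the input outside the timed section,
        // mirroring how the merge benchmark below rebuilds the exec node
        // for every iteration.
        b.iter_batched(build_input, sort_input, BatchSize::LargeInput)
    });
}

criterion_group!(benches, bench_sort);
criterion_main!(benches);
```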
+ +use arrow::{ + array::{ArrayRef, StringArray, UInt64Array}, + record_batch::RecordBatch, +}; +use arrow_schema::{SchemaRef, SortOptions}; +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::test::TestMemoryExec; +use datafusion_physical_plan::{ + collect, sorts::sort_preserving_merge::SortPreservingMergeExec, +}; + +use std::sync::Arc; + +const BENCH_ROWS: usize = 1_000_000; // 1 million rows + +fn get_large_string(idx: usize) -> String { + let base_content = [ + concat!( + "# Advanced Topics in Computer Science\n\n", + "## Summary\nThis article explores complex system design patterns and...\n\n", + "```rust\nfn process_data(data: &mut [i32]) {\n // Parallel processing example\n data.par_iter_mut().for_each(|x| *x *= 2);\n}\n```\n\n", + "## Performance Considerations\nWhen implementing concurrent systems...\n" + ), + concat!( + "## API Documentation\n\n", + "```json\n{\n \"endpoint\": \"/api/v2/users\",\n \"methods\": [\"GET\", \"POST\"],\n \"parameters\": {\n \"page\": \"number\"\n }\n}\n```\n\n", + "# Authentication Guide\nSecure your API access using OAuth 2.0...\n" + ), + concat!( + "# Data Processing Pipeline\n\n", + "```python\nfrom multiprocessing import Pool\n\ndef main():\n with Pool(8) as p:\n results = p.map(process_item, data)\n```\n\n", + "## Summary of Optimizations\n1. Batch processing\n2. Memory pooling\n3. Concurrent I/O operations\n" + ), + concat!( + "# System Architecture Overview\n\n", + "## Components\n- Load Balancer\n- Database Cluster\n- Cache Service\n\n", + "```go\nfunc main() {\n router := gin.Default()\n router.GET(\"/api/health\", healthCheck)\n router.Run(\":8080\")\n}\n```\n" + ), + concat!( + "## Configuration Reference\n\n", + "```yaml\nserver:\n port: 8080\n max_threads: 32\n\ndatabase:\n url: postgres://user@prod-db:5432/main\n```\n\n", + "# Deployment Strategies\nBlue-green deployment patterns with...\n" + ), + ]; + base_content[idx % base_content.len()].to_string() +} + +fn generate_sorted_string_column(rows: usize) -> ArrayRef { + let mut values = Vec::with_capacity(rows); + for i in 0..rows { + values.push(get_large_string(i)); + } + values.sort(); + Arc::new(StringArray::from(values)) +} + +fn generate_sorted_u64_column(rows: usize) -> ArrayRef { + Arc::new(UInt64Array::from((0_u64..rows as u64).collect::>())) +} + +fn create_partitions( + num_partitions: usize, + num_columns: usize, + num_rows: usize, +) -> Vec> { + (0..num_partitions) + .map(|_| { + let rows = (0..num_columns) + .map(|i| { + ( + format!("col-{i}"), + if IS_LARGE_COLUMN_TYPE { + generate_sorted_string_column(num_rows) + } else { + generate_sorted_u64_column(num_rows) + }, + ) + }) + .collect::>(); + + let batch = RecordBatch::try_from_iter(rows).unwrap(); + vec![batch] + }) + .collect() +} + +struct BenchData { + bench_name: String, + partitions: Vec>, + schema: SchemaRef, + sort_order: LexOrdering, +} + +fn get_bench_data() -> Vec { + let mut ret = Vec::new(); + let mut push_bench_data = |bench_name: &str, partitions: Vec>| { + let schema = partitions[0][0].schema(); + // Define sort order (col1 ASC, col2 ASC, col3 ASC) + let sort_order = LexOrdering::new( + schema + .fields() + .iter() + .map(|field| { + PhysicalSortExpr::new( + col(field.name(), &schema).unwrap(), + SortOptions::default(), + ) + }) + .collect(), + ); + ret.push(BenchData { + bench_name: bench_name.to_string(), + partitions, + schema, + 
sort_order, + }); + }; + // 1. single large string column + { + let partitions = create_partitions::(3, 1, BENCH_ROWS); + push_bench_data("single_large_string_column_with_1m_rows", partitions); + } + // 2. single u64 column + { + let partitions = create_partitions::(3, 1, BENCH_ROWS); + push_bench_data("single_u64_column_with_1m_rows", partitions); + } + // 3. multiple large string columns + { + let partitions = create_partitions::(3, 3, BENCH_ROWS); + push_bench_data("multiple_large_string_columns_with_1m_rows", partitions); + } + // 4. multiple u64 columns + { + let partitions = create_partitions::(3, 3, BENCH_ROWS); + push_bench_data("multiple_u64_columns_with_1m_rows", partitions); + } + ret +} + +/// Add a benchmark to test the optimization effect of reusing Rows. +/// Run this benchmark with: +/// ```sh +/// cargo bench --features="bench" --bench sort_preserving_merge -- --sample-size=10 +/// ``` +fn bench_merge_sorted_preserving(c: &mut Criterion) { + let task_ctx = Arc::new(TaskContext::default()); + let bench_data = get_bench_data(); + for data in bench_data.into_iter() { + let BenchData { + bench_name, + partitions, + schema, + sort_order, + } = data; + c.bench_function( + &format!("bench_merge_sorted_preserving/{}", bench_name), + |b| { + b.iter_batched( + || { + let exec = TestMemoryExec::try_new_exec( + &partitions, + schema.clone(), + None, + ) + .unwrap(); + Arc::new(SortPreservingMergeExec::new(sort_order.clone(), exec)) + }, + |merge_exec| { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + collect(merge_exec, task_ctx.clone()).await.unwrap(); + }); + }, + BatchSize::LargeInput, + ) + }, + ); + } +} + +criterion_group!(benches, bench_merge_sorted_preserving); +criterion_main!(benches); diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs index c4525256dbae..be1f68ea453f 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs @@ -24,6 +24,7 @@ use arrow::array::{ use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::datatypes::{ByteArrayType, DataType, GenericBinaryType}; use datafusion_common::utils::proxy::VecAllocExt; +use datafusion_common::{DataFusionError, Result}; use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; use itertools::izip; use std::mem::size_of; @@ -50,6 +51,8 @@ where offsets: Vec, /// Nulls nulls: MaybeNullBufferBuilder, + /// The maximum size of the buffer for `0` + max_buffer_size: usize, } impl ByteGroupValueBuilder @@ -62,6 +65,11 @@ where buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), offsets: vec![O::default()], nulls: MaybeNullBufferBuilder::new(), + max_buffer_size: if O::IS_LARGE { + i64::MAX as usize + } else { + i32::MAX as usize + }, } } @@ -73,7 +81,7 @@ where self.do_equal_to_inner(lhs_row, array, rhs_row) } - fn append_val_inner(&mut self, array: &ArrayRef, row: usize) + fn append_val_inner(&mut self, array: &ArrayRef, row: usize) -> Result<()> where B: ByteArrayType, { @@ -85,8 +93,10 @@ where self.offsets.push(O::usize_as(offset)); } else { self.nulls.append(false); - self.do_append_val_inner(arr, row); + self.do_append_val_inner(arr, row)?; } + + Ok(()) } fn vectorized_equal_to_inner( @@ -116,7 +126,11 @@ where } } - fn vectorized_append_inner(&mut self, array: &ArrayRef, rows: &[usize]) + fn vectorized_append_inner( + &mut self, + 
array: &ArrayRef, + rows: &[usize], + ) -> Result<()> where B: ByteArrayType, { @@ -134,22 +148,14 @@ where match all_null_or_non_null { None => { for &row in rows { - if arr.is_null(row) { - self.nulls.append(true); - // nulls need a zero length in the offset buffer - let offset = self.buffer.len(); - self.offsets.push(O::usize_as(offset)); - } else { - self.nulls.append(false); - self.do_append_val_inner(arr, row); - } + self.append_val_inner::(array, row)? } } Some(true) => { self.nulls.append_n(rows.len(), false); for &row in rows { - self.do_append_val_inner(arr, row); + self.do_append_val_inner(arr, row)?; } } @@ -161,6 +167,8 @@ where self.offsets.resize(new_len, O::usize_as(offset)); } } + + Ok(()) } fn do_equal_to_inner( @@ -181,13 +189,26 @@ where self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8]) } - fn do_append_val_inner(&mut self, array: &GenericByteArray, row: usize) + fn do_append_val_inner( + &mut self, + array: &GenericByteArray, + row: usize, + ) -> Result<()> where B: ByteArrayType, { let value: &[u8] = array.value(row).as_ref(); self.buffer.append_slice(value); + + if self.buffer.len() > self.max_buffer_size { + return Err(DataFusionError::Execution(format!( + "offset overflow, buffer size > {}", + self.max_buffer_size + ))); + } + self.offsets.push(O::usize_as(self.buffer.len())); + Ok(()) } /// return the current value of the specified row irrespective of null @@ -224,7 +245,7 @@ where } } - fn append_val(&mut self, column: &ArrayRef, row: usize) { + fn append_val(&mut self, column: &ArrayRef, row: usize) -> Result<()> { // Sanity array type match self.output_type { OutputType::Binary => { @@ -232,17 +253,19 @@ where column.data_type(), DataType::Binary | DataType::LargeBinary )); - self.append_val_inner::>(column, row) + self.append_val_inner::>(column, row)? } OutputType::Utf8 => { debug_assert!(matches!( column.data_type(), DataType::Utf8 | DataType::LargeUtf8 )); - self.append_val_inner::>(column, row) + self.append_val_inner::>(column, row)? } _ => unreachable!("View types should use `ArrowBytesViewMap`"), }; + + Ok(()) } fn vectorized_equal_to( @@ -282,24 +305,26 @@ where } } - fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) { + fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) -> Result<()> { match self.output_type { OutputType::Binary => { debug_assert!(matches!( column.data_type(), DataType::Binary | DataType::LargeBinary )); - self.vectorized_append_inner::>(column, rows) + self.vectorized_append_inner::>(column, rows)? } OutputType::Utf8 => { debug_assert!(matches!( column.data_type(), DataType::Utf8 | DataType::LargeUtf8 )); - self.vectorized_append_inner::>(column, rows) + self.vectorized_append_inner::>(column, rows)? } _ => unreachable!("View types should use `ArrowBytesViewMap`"), }; + + Ok(()) } fn len(&self) -> usize { @@ -318,6 +343,7 @@ where mut buffer, offsets, nulls, + .. 
} = *self; let null_buffer = nulls.build(); @@ -406,27 +432,50 @@ mod tests { use crate::aggregates::group_values::multi_group_by::bytes::ByteGroupValueBuilder; use arrow::array::{ArrayRef, NullBufferBuilder, StringArray}; + use datafusion_common::DataFusionError; use datafusion_physical_expr::binary_map::OutputType; use super::GroupColumn; + #[test] + fn test_byte_group_value_builder_overflow() { + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + + let large_string = "a".repeat(1024 * 1024); + + let array = + Arc::new(StringArray::from(vec![Some(large_string.as_str())])) as ArrayRef; + + // Append items until our buffer length is i32::MAX as usize + for _ in 0..2047 { + builder.append_val(&array, 0).unwrap(); + } + + assert!(matches!( + builder.append_val(&array, 0), + Err(DataFusionError::Execution(e)) if e.contains("offset overflow") + )); + + assert_eq!(builder.value(2046), large_string.as_bytes()); + } + #[test] fn test_byte_take_n() { let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef; // a, null, null - builder.append_val(&array, 0); - builder.append_val(&array, 1); - builder.append_val(&array, 1); + builder.append_val(&array, 0).unwrap(); + builder.append_val(&array, 1).unwrap(); + builder.append_val(&array, 1).unwrap(); // (a, null) remaining: null let output = builder.take_n(2); assert_eq!(&output, &array); // null, a, null, a - builder.append_val(&array, 0); - builder.append_val(&array, 1); - builder.append_val(&array, 0); + builder.append_val(&array, 0).unwrap(); + builder.append_val(&array, 1).unwrap(); + builder.append_val(&array, 0).unwrap(); // (null, a) remaining: (null, a) let output = builder.take_n(2); @@ -440,9 +489,9 @@ mod tests { ])) as ArrayRef; // null, a, longstringfortest, null, null - builder.append_val(&array, 2); - builder.append_val(&array, 1); - builder.append_val(&array, 1); + builder.append_val(&array, 2).unwrap(); + builder.append_val(&array, 1).unwrap(); + builder.append_val(&array, 1).unwrap(); // (null, a, longstringfortest, null) remaining: (null) let output = builder.take_n(4); @@ -461,7 +510,7 @@ mod tests { builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; @@ -484,7 +533,9 @@ mod tests { let append = |builder: &mut ByteGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; let equal_to = |builder: &ByteGroupValueBuilder, @@ -518,7 +569,9 @@ mod tests { None, None, ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -542,7 +595,9 @@ mod tests { Some("string4"), Some("string5"), ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; builder.vectorized_equal_to( diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs index b6d97b5d788d..63018874a1e4 
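Editorial note: the `max_buffer_size` field and the check in `do_append_val_inner` above exist because all group values are concatenated into one byte buffer addressed by `i32` (or `i64`) offsets; once the buffer grows past `i32::MAX` the next stored offset would silently wrap, which is what `test_byte_group_value_builder_overflow` exercises. Below is a standalone sketch of the same guard, independent of Arrow's builders and DataFusion's error type, to make the invariant explicit; it is an illustration, not the shipped code.

```rust
/// Minimal model of an offset-indexed byte buffer with an overflow guard,
/// analogous to the check added to `ByteGroupValueBuilder`.
struct ByteBuilder {
    buffer: Vec<u8>,
    offsets: Vec<i32>,
    /// Upper bound on the concatenated buffer: i32::MAX for 32-bit offsets.
    max_buffer_size: usize,
}

impl ByteBuilder {
    fn new() -> Self {
        Self {
            buffer: Vec::new(),
            offsets: vec![0],
            max_buffer_size: i32::MAX as usize,
        }
    }

    fn append(&mut self, value: &[u8]) -> Result<(), String> {
        self.buffer.extend_from_slice(value);
        // Fail the append before recording an offset that no longer fits in i32.
        if self.buffer.len() > self.max_buffer_size {
            return Err(format!(
                "offset overflow, buffer size > {}",
                self.max_buffer_size
            ));
        }
        self.offsets.push(self.buffer.len() as i32);
        Ok(())
    }
}

fn main() {
    let mut b = ByteBuilder::new();
    assert!(b.append(b"hello").is_ok());
    assert_eq!(b.offsets, vec![0, 5]);
}
```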
100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs @@ -20,6 +20,7 @@ use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow::array::{make_view, Array, ArrayRef, AsArray, ByteView, GenericByteViewArray}; use arrow::buffer::{Buffer, ScalarBuffer}; use arrow::datatypes::ByteViewType; +use datafusion_common::Result; use itertools::izip; use std::marker::PhantomData; use std::mem::{replace, size_of}; @@ -148,14 +149,7 @@ impl ByteViewGroupValueBuilder { match all_null_or_non_null { None => { for &row in rows { - // Null row case, set and return - if arr.is_valid(row) { - self.nulls.append(false); - self.do_append_val_inner(arr, row); - } else { - self.nulls.append(true); - self.views.push(0); - } + self.append_val_inner(array, row); } } @@ -493,8 +487,9 @@ impl GroupColumn for ByteViewGroupValueBuilder { self.equal_to_inner(lhs_row, array, rhs_row) } - fn append_val(&mut self, array: &ArrayRef, row: usize) { - self.append_val_inner(array, row) + fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> { + self.append_val_inner(array, row); + Ok(()) } fn vectorized_equal_to( @@ -507,8 +502,9 @@ impl GroupColumn for ByteViewGroupValueBuilder { self.vectorized_equal_to_inner(group_indices, array, rows, equal_to_results); } - fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) { + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> { self.vectorized_append_inner(array, rows); + Ok(()) } fn len(&self) -> usize { @@ -563,7 +559,7 @@ mod tests { ]); let builder_array: ArrayRef = Arc::new(builder_array); for row in 0..builder_array.len() { - builder.append_val(&builder_array, row); + builder.append_val(&builder_array, row).unwrap(); } let output = Box::new(builder).build(); @@ -578,7 +574,7 @@ mod tests { builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; @@ -601,7 +597,9 @@ mod tests { let append = |builder: &mut ByteViewGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; let equal_to = |builder: &ByteViewGroupValueBuilder, @@ -636,7 +634,9 @@ mod tests { None, None, ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -660,7 +660,9 @@ mod tests { Some("stringview4"), Some("stringview5"), ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -841,7 +843,7 @@ mod tests { // ####### Test situation 1~5 ####### for row in 0..first_ones_to_append { - builder.append_val(&input_array, row); + builder.append_val(&input_array, row).unwrap(); } assert_eq!(builder.completed.len(), 2); @@ -879,7 +881,7 @@ mod tests { assert!(builder.views.is_empty()); for row in first_ones_to_append..first_ones_to_append + second_ones_to_append { - builder.append_val(&input_array, row); + 
builder.append_val(&input_array, row).unwrap(); } assert!(builder.completed.is_empty()); @@ -894,7 +896,7 @@ mod tests { ByteViewGroupValueBuilder::::new().with_max_block_size(60); for row in 0..final_ones_to_append { - builder.append_val(&input_array, row); + builder.append_val(&input_array, row).unwrap(); } assert_eq!(builder.completed.len(), 3); diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index ac96a98edfe1..2ac0389454de 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -65,7 +65,7 @@ pub trait GroupColumn: Send + Sync { fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool; /// Appends the row at `row` in `array` to this builder - fn append_val(&mut self, array: &ArrayRef, row: usize); + fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()>; /// The vectorized version equal to /// @@ -86,7 +86,7 @@ pub trait GroupColumn: Send + Sync { ); /// The vectorized version `append_val` - fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]); + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()>; /// Returns the number of rows stored in this builder fn len(&self) -> usize; @@ -270,7 +270,7 @@ impl GroupValuesColumn { map_size: 0, group_values: vec![], hashes_buffer: Default::default(), - random_state: Default::default(), + random_state: crate::aggregates::AGGREGATION_HASH_SEED, }) } @@ -384,7 +384,7 @@ impl GroupValuesColumn { let mut checklen = 0; let group_idx = self.group_values[0].len(); for (i, group_value) in self.group_values.iter_mut().enumerate() { - group_value.append_val(&cols[i], row); + group_value.append_val(&cols[i], row)?; let len = group_value.len(); if i == 0 { checklen = len; @@ -460,14 +460,14 @@ impl GroupValuesColumn { self.collect_vectorized_process_context(&batch_hashes, groups); // 2. Perform `vectorized_append` - self.vectorized_append(cols); + self.vectorized_append(cols)?; // 3. Perform `vectorized_equal_to` self.vectorized_equal_to(cols, groups); // 4. 
Perform scalarized inter for remaining rows // (about remaining rows, can see comments for `remaining_row_indices`) - self.scalarized_intern_remaining(cols, &batch_hashes, groups); + self.scalarized_intern_remaining(cols, &batch_hashes, groups)?; self.hashes_buffer = batch_hashes; @@ -563,13 +563,13 @@ impl GroupValuesColumn { } /// Perform `vectorized_append`` for `rows` in `vectorized_append_row_indices` - fn vectorized_append(&mut self, cols: &[ArrayRef]) { + fn vectorized_append(&mut self, cols: &[ArrayRef]) -> Result<()> { if self .vectorized_operation_buffers .append_row_indices .is_empty() { - return; + return Ok(()); } let iter = self.group_values.iter_mut().zip(cols.iter()); @@ -577,8 +577,10 @@ impl GroupValuesColumn { group_column.vectorized_append( col, &self.vectorized_operation_buffers.append_row_indices, - ); + )?; } + + Ok(()) } /// Perform `vectorized_equal_to` @@ -719,13 +721,13 @@ impl GroupValuesColumn { cols: &[ArrayRef], batch_hashes: &[u64], groups: &mut [usize], - ) { + ) -> Result<()> { if self .vectorized_operation_buffers .remaining_row_indices .is_empty() { - return; + return Ok(()); } let mut map = mem::take(&mut self.map); @@ -758,7 +760,7 @@ impl GroupValuesColumn { let group_idx = self.group_values[0].len(); let mut checklen = 0; for (i, group_value) in self.group_values.iter_mut().enumerate() { - group_value.append_val(&cols[i], row); + group_value.append_val(&cols[i], row)?; let len = group_value.len(); if i == 0 { checklen = len; @@ -795,6 +797,7 @@ impl GroupValuesColumn { } self.map = map; + Ok(()) } fn scalarized_equal_to_remaining( @@ -1756,11 +1759,9 @@ mod tests { (i, actual_line), (i, expected_line), "Inconsistent result\n\n\ - Actual batch:\n{}\n\ - Expected batch:\n{}\n\ + Actual batch:\n{formatted_actual_batch}\n\ + Expected batch:\n{formatted_expected_batch}\n\ ", - formatted_actual_batch, - formatted_expected_batch, ); } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs index e9c3c42e632b..afec25fd3d66 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs @@ -17,9 +17,11 @@ use crate::aggregates::group_values::multi_group_by::{nulls_equal_to, GroupColumn}; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow::array::ArrowNativeTypeOp; use arrow::array::{cast::AsArray, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; use arrow::buffer::ScalarBuffer; use arrow::datatypes::DataType; +use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use itertools::izip; use std::iter; @@ -71,7 +73,7 @@ impl GroupColumn self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) } - fn append_val(&mut self, array: &ArrayRef, row: usize) { + fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> { // Perf: skip null check if input can't have nulls if NULLABLE { if array.is_null(row) { @@ -84,6 +86,8 @@ impl GroupColumn } else { self.group_values.push(array.as_primitive::().value(row)); } + + Ok(()) } fn vectorized_equal_to( @@ -118,11 +122,11 @@ impl GroupColumn // Otherwise, we need to check their values } - *equal_to_result = self.group_values[lhs_row] == array.value(rhs_row); + *equal_to_result = self.group_values[lhs_row].is_eq(array.value(rhs_row)); } } - fn vectorized_append(&mut self, array: &ArrayRef, rows: 
&[usize]) { + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> { let arr = array.as_primitive::(); let null_count = array.null_count(); @@ -167,6 +171,8 @@ impl GroupColumn } } } + + Ok(()) } fn len(&self) -> usize { @@ -222,7 +228,7 @@ mod tests { builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; @@ -245,7 +251,9 @@ mod tests { let append = |builder: &mut PrimitiveGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; let equal_to = |builder: &PrimitiveGroupValueBuilder, @@ -335,7 +343,7 @@ mod tests { builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; @@ -358,7 +366,9 @@ mod tests { let append = |builder: &mut PrimitiveGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; let equal_to = |builder: &PrimitiveGroupValueBuilder, @@ -432,7 +442,9 @@ mod tests { None, None, ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -456,7 +468,9 @@ mod tests { Some(4), Some(5), ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; builder.vectorized_equal_to( diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 63751d470313..34893fcc4ed9 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -82,7 +82,7 @@ impl GroupValuesRows { pub fn try_new(schema: SchemaRef) -> Result { // Print a debugging message, so it is clear when the (slower) fallback // GroupValuesRows is used. 
- debug!("Creating GroupValuesRows for schema: {}", schema); + debug!("Creating GroupValuesRows for schema: {schema}"); let row_converter = RowConverter::new( schema .fields() @@ -106,7 +106,7 @@ impl GroupValuesRows { group_values: None, hashes_buffer: Default::default(), rows_buffer, - random_state: Default::default(), + random_state: crate::aggregates::AGGREGATION_HASH_SEED, }) } } @@ -202,6 +202,7 @@ impl GroupValues for GroupValuesRows { EmitTo::All => { let output = self.row_converter.convert_rows(&group_values)?; group_values.clear(); + self.map.clear(); output } EmitTo::First(n) => { diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index d945d3ddcbf5..8b1905e54041 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -81,11 +81,14 @@ hash_float!(f16, f32, f64); pub struct GroupValuesPrimitive { /// The data type of the output array data_type: DataType, - /// Stores the group index based on the hash of its value + /// Stores the `(group_index, hash)` based on the hash of its value /// - /// We don't store the hashes as hashing fixed width primitives - /// is fast enough for this not to benefit performance - map: HashTable, + /// We also store `hash` is for reducing cost of rehashing. Such cost + /// is obvious in high cardinality group by situation. + /// More details can see: + /// + /// + map: HashTable<(usize, u64)>, /// The group index of the null value if any null_group: Option, /// The values for each group index @@ -102,7 +105,7 @@ impl GroupValuesPrimitive { map: HashTable::with_capacity(128), values: Vec::with_capacity(128), null_group: None, - random_state: Default::default(), + random_state: crate::aggregates::AGGREGATION_HASH_SEED, } } } @@ -127,15 +130,15 @@ where let hash = key.hash(state); let insert = self.map.entry( hash, - |g| unsafe { self.values.get_unchecked(*g).is_eq(key) }, - |g| unsafe { self.values.get_unchecked(*g).hash(state) }, + |&(g, _)| unsafe { self.values.get_unchecked(g).is_eq(key) }, + |&(_, h)| h, ); match insert { - hashbrown::hash_table::Entry::Occupied(o) => *o.get(), + hashbrown::hash_table::Entry::Occupied(o) => o.get().0, hashbrown::hash_table::Entry::Vacant(v) => { let g = self.values.len(); - v.insert(g); + v.insert((g, hash)); self.values.push(key); g } @@ -148,7 +151,7 @@ where } fn size(&self) -> usize { - self.map.capacity() * size_of::() + self.values.allocated_size() + self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size() } fn is_empty(&self) -> bool { @@ -181,12 +184,13 @@ where build_primitive(std::mem::take(&mut self.values), self.null_group.take()) } EmitTo::First(n) => { - self.map.retain(|group_idx| { + self.map.retain(|entry| { // Decrement group index by n + let group_idx = entry.0; match group_idx.checked_sub(n) { // Group index was >= n, shift value down Some(sub) => { - *group_idx = sub; + entry.0 = sub; true } // Group index was < n, so remove from table diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 8906468f68db..656c9a2cd5cb 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -27,7 +27,6 @@ use crate::aggregates::{ }; use crate::execution_plan::{CardinalityEffect, EmissionType}; use 
crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::projection::get_field_metadata; use crate::windows::get_ordered_partition_by_indices; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, @@ -37,6 +36,7 @@ use crate::{ use arrow::array::{ArrayRef, UInt16Array, UInt32Array, UInt64Array, UInt8Array}; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_schema::FieldRef; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Constraint, Constraints, Result}; use datafusion_execution::TaskContext; @@ -58,6 +58,10 @@ mod row_hash; mod topk; mod topk_stream; +/// Hard-coded seed for aggregations to ensure hash values differ from `RepartitionExec`, avoiding collisions. +const AGGREGATION_HASH_SEED: ahash::RandomState = + ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64); + /// Aggregation modes /// /// See [`Accumulator::state`] for background information on multi-phase @@ -274,7 +278,7 @@ impl PhysicalGroupBy { } /// Returns the fields that are used as the grouping keys. - fn group_fields(&self, input_schema: &Schema) -> Result> { + fn group_fields(&self, input_schema: &Schema) -> Result> { let mut fields = Vec::with_capacity(self.num_group_exprs()); for ((expr, name), group_expr_nullable) in self.expr.iter().zip(self.exprs_nullable().into_iter()) @@ -285,17 +289,19 @@ impl PhysicalGroupBy { expr.data_type(input_schema)?, group_expr_nullable || expr.nullable(input_schema)?, ) - .with_metadata( - get_field_metadata(expr, input_schema).unwrap_or_default(), - ), + .with_metadata(expr.return_field(input_schema)?.metadata().clone()) + .into(), ); } if !self.is_single() { - fields.push(Field::new( - Aggregate::INTERNAL_GROUPING_ID, - Aggregate::grouping_id_type(self.expr.len()), - false, - )); + fields.push( + Field::new( + Aggregate::INTERNAL_GROUPING_ID, + Aggregate::grouping_id_type(self.expr.len()), + false, + ) + .into(), + ); } Ok(fields) } @@ -304,7 +310,7 @@ impl PhysicalGroupBy { /// /// This might be different from the `group_fields` that might contain internal expressions that /// should not be part of the output schema. - fn output_fields(&self, input_schema: &Schema) -> Result> { + fn output_fields(&self, input_schema: &Schema) -> Result> { let mut fields = self.group_fields(input_schema)?; fields.truncate(self.num_output_exprs()); Ok(fields) @@ -349,6 +355,7 @@ impl PartialEq for PhysicalGroupBy { } } +#[allow(clippy::large_enum_variant)] enum StreamType { AggregateStream(AggregateStream), GroupedHash(GroupedHashAggregateStream), @@ -623,7 +630,7 @@ impl AggregateExec { } /// Finds the DataType and SortDirection for this Aggregate, if there is one - pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { + pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; agg_expr.get_minmax_desc() } @@ -735,6 +742,69 @@ impl AggregateExec { pub fn input_order_mode(&self) -> &InputOrderMode { &self.input_order_mode } + + fn statistics_inner(&self, child_statistics: Statistics) -> Result { + // TODO stats: group expressions: + // - once expressions will be able to compute their own stats, use it here + // - case where we group by on a column for which with have the `distinct` stat + // TODO stats: aggr expression: + // - aggregations sometimes also preserve invariants such as min, max... 
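Editorial note: `AGGREGATION_HASH_SEED` above gives the aggregation hash tables a `RandomState` that deliberately differs from the one used when rows are hash-repartitioned, so keys routed to a partition by one hash function do not land in correlated buckets of that partition's group table. A tiny sketch of constructing fixed `ahash` seeds and comparing the resulting hashes is shown below; the zeroed seed only stands in for whatever state the repartitioner uses and is an assumption of this example.

```rust
use std::hash::BuildHasher;

fn main() {
    // Fixed seeds make the hash values reproducible across runs and processes.
    let aggregation_state =
        ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64);
    let repartition_state = ahash::RandomState::with_seeds(0, 0, 0, 0);

    let key = 42u64;
    // The same key hashes differently under the two states, which is what
    // breaks the correlation between partition routing and group lookup.
    let agg_hash = aggregation_state.hash_one(key);
    let part_hash = repartition_state.hash_one(key);
    println!("aggregation hash = {agg_hash:#x}, repartition hash = {part_hash:#x}");
}
```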
+ + let column_statistics = { + // self.schema: [, ] + let mut column_statistics = Statistics::unknown_column(&self.schema()); + + for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() { + if let Some(col) = expr.as_any().downcast_ref::() { + column_statistics[idx].max_value = child_statistics.column_statistics + [col.index()] + .max_value + .clone(); + + column_statistics[idx].min_value = child_statistics.column_statistics + [col.index()] + .min_value + .clone(); + } + } + + column_statistics + }; + match self.mode { + AggregateMode::Final | AggregateMode::FinalPartitioned + if self.group_by.expr.is_empty() => + { + Ok(Statistics { + num_rows: Precision::Exact(1), + column_statistics, + total_byte_size: Precision::Absent, + }) + } + _ => { + // When the input row count is 1, we can adopt that statistic keeping its reliability. + // When it is larger than 1, we degrade the precision since it may decrease after aggregation. + let num_rows = if let Some(value) = child_statistics.num_rows.get_value() + { + if *value > 1 { + child_statistics.num_rows.to_inexact() + } else if *value == 0 { + child_statistics.num_rows + } else { + // num_rows = 1 case + let grouping_set_num = self.group_by.groups.len(); + child_statistics.num_rows.map(|x| x * grouping_set_num) + } + } else { + Precision::Absent + }; + Ok(Statistics { + num_rows, + column_statistics, + total_byte_size: Precision::Absent, + }) + } + } + } } impl DisplayAs for AggregateExec { @@ -941,50 +1011,11 @@ impl ExecutionPlan for AggregateExec { } fn statistics(&self) -> Result { - // TODO stats: group expressions: - // - once expressions will be able to compute their own stats, use it here - // - case where we group by on a column for which with have the `distinct` stat - // TODO stats: aggr expression: - // - aggregations sometimes also preserve invariants such as min, max... - let column_statistics = Statistics::unknown_column(&self.schema()); - match self.mode { - AggregateMode::Final | AggregateMode::FinalPartitioned - if self.group_by.expr.is_empty() => - { - Ok(Statistics { - num_rows: Precision::Exact(1), - column_statistics, - total_byte_size: Precision::Absent, - }) - } - _ => { - // When the input row count is 0 or 1, we can adopt that statistic keeping its reliability. - // When it is larger than 1, we degrade the precision since it may decrease after aggregation. - let num_rows = if let Some(value) = - self.input().statistics()?.num_rows.get_value() - { - if *value > 1 { - self.input().statistics()?.num_rows.to_inexact() - } else if *value == 0 { - // Aggregation on an empty table creates a null row. - self.input() - .statistics()? - .num_rows - .add(&Precision::Exact(1)) - } else { - // num_rows = 1 case - self.input().statistics()?.num_rows - } - } else { - Precision::Absent - }; - Ok(Statistics { - num_rows, - column_statistics, - total_byte_size: Precision::Absent, - }) - } - } + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.statistics_inner(self.input().partition_statistics(partition)?) 
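Editorial note: the `num_rows` handling in `statistics_inner` above encodes three cases: an exact input count greater than one is downgraded to an inexact estimate because grouping can only merge rows, a count of exactly one is scaled by the number of grouping sets, and zero stays zero. The sketch below restates that decision table with a simplified stand-in for `Precision`; it is meant to illustrate the reasoning, not to mirror the exact `datafusion_common` API.

```rust
/// Simplified stand-in for `Precision<usize>` row-count statistics.
#[derive(Clone, Copy, Debug, PartialEq)]
enum RowCount {
    Exact(usize),
    Inexact(usize),
    Absent,
}

/// Mirror of the row-count reasoning in `AggregateExec::statistics_inner`:
/// grouping never increases the input row count, except that a single input
/// row is emitted once per grouping set.
fn aggregated_rows(input: RowCount, grouping_sets: usize) -> RowCount {
    let scale = |n: usize| match n {
        0 => 0,             // zero input rows stay zero
        1 => grouping_sets, // one input row appears once per grouping set
        n => n,             // otherwise keep n, but only as an upper bound
    };
    match input {
        // More than one row: the count may shrink after grouping.
        RowCount::Exact(n) if n > 1 => RowCount::Inexact(n),
        RowCount::Exact(n) => RowCount::Exact(scale(n)),
        RowCount::Inexact(n) => RowCount::Inexact(scale(n)),
        RowCount::Absent => RowCount::Absent,
    }
}

fn main() {
    assert_eq!(aggregated_rows(RowCount::Exact(1), 3), RowCount::Exact(3));
    assert_eq!(aggregated_rows(RowCount::Exact(0), 3), RowCount::Exact(0));
    assert_eq!(aggregated_rows(RowCount::Exact(1000), 3), RowCount::Inexact(1000));
}
```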
} fn cardinality_effect(&self) -> CardinalityEffect { @@ -1924,6 +1955,13 @@ mod tests { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(self.schema().as_ref())); + } let (_, batches) = some_data(); Ok(common::compute_record_batch_statistics( &[batches], diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 232565a04466..62f541443068 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -529,7 +529,12 @@ impl GroupedHashAggregateStream { }) .collect(); - let name = format!("GroupedHashAggregateStream[{partition}]"); + let agg_fn_names = aggregate_exprs + .iter() + .map(|expr| expr.human_display()) + .collect::>() + .join(", "); + let name = format!("GroupedHashAggregateStream[{partition}] ({agg_fn_names})"); let reservation = MemoryConsumer::new(name) .with_can_spill(true) .register(context.memory_pool()); diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index ae44eb35e6d0..47052fd52511 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -461,7 +461,7 @@ mod tests { let (_heap_idxs, map_idxs): (Vec<_>, Vec<_>) = heap_to_map.into_iter().unzip(); let ids = unsafe { map.take_all(map_idxs) }; assert_eq!( - format!("{:?}", ids), + format!("{ids:?}"), r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"# ); assert_eq!(map.len(), 0, "Map should have been cleared!"); diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index 8b4b07d211a0..ce47504daf03 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -348,7 +348,7 @@ impl TopKHeap { prefix, connector, hi.val, idx, hi.map_idx )); let new_prefix = if is_tail { "" } else { "│ " }; - let child_prefix = format!("{}{}", prefix, new_prefix); + let child_prefix = format!("{prefix}{new_prefix}"); let left_idx = idx * 2 + 1; let right_idx = idx * 2 + 2; @@ -372,7 +372,7 @@ impl Display for TopKHeap { if !self.heap.is_empty() { self._tree_print(0, String::new(), true, &mut output); } - write!(f, "{}", output) + write!(f, "{output}") } } diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index 5244038b9ae2..f35231fb6a99 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -32,9 +32,14 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalExpr; use crate::coalesce::{BatchCoalescer, CoalescerState}; use crate::execution_plan::CardinalityEffect; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPropagation, +}; +use datafusion_common::config::ConfigOptions; use futures::ready; use futures::stream::{Stream, StreamExt}; @@ -192,7 +197,16 @@ impl ExecutionPlan for CoalesceBatchesExec { } fn statistics(&self) -> Result { - Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1) + self.partition_statistics(None) + } + + fn 
partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition)?.with_fetch( + self.schema(), + self.fetch, + 0, + 1, + ) } fn with_fetch(&self, limit: Option) -> Option> { @@ -212,6 +226,25 @@ impl ExecutionPlan for CoalesceBatchesExec { fn cardinality_effect(&self) -> CardinalityEffect { CardinalityEffect::Equal } + + fn gather_filters_for_pushdown( + &self, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + Ok(FilterDescription::new_with_child_count(1) + .all_parent_filters_supported(parent_filters)) + } + + fn handle_child_pushdown_result( + &self, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::transparent( + child_pushdown_result, + )) + } } /// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details. @@ -321,6 +354,7 @@ impl CoalesceBatchesStream { } } CoalesceBatchesStreamState::ReturnBuffer => { + let _timer = cloned_time.timer(); // Combine buffered batches into one batch and return it. let batch = self.coalescer.finish_batch()?; // Set to pull state for the next iteration. @@ -333,6 +367,7 @@ impl CoalesceBatchesStream { // If buffer is empty, return None indicating the stream is fully consumed. Poll::Ready(None) } else { + let _timer = cloned_time.timer(); // If the buffer still contains batches, prepare to return them. let batch = self.coalescer.finish_batch()?; Poll::Ready(Some(Ok(batch))) diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 95a0c8f6ce83..114f830688c9 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -59,6 +59,12 @@ impl CoalescePartitionsExec { } } + /// Update fetch with the argument + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + /// Input execution plan pub fn input(&self) -> &Arc { &self.input @@ -190,7 +196,13 @@ impl ExecutionPlan for CoalescePartitionsExec { } fn statistics(&self) -> Result { - Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1) + self.partition_statistics(None) + } + + fn partition_statistics(&self, _partition: Option) -> Result { + self.input + .partition_statistics(None)? + .with_fetch(self.schema(), self.fetch, 0, 1) } fn supports_limit_pushdown(&self) -> bool { diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index a8d4a3ddf3d1..35f3e8d16e22 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -18,8 +18,7 @@ //! Defines common code used in execution plans use std::fs; -use std::fs::{metadata, File}; -use std::path::{Path, PathBuf}; +use std::fs::metadata; use std::sync::Arc; use super::SendableRecordBatchStream; @@ -28,10 +27,9 @@ use crate::{ColumnStatistics, Statistics}; use arrow::array::Array; use arrow::datatypes::Schema; -use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::{plan_err, Result}; use datafusion_execution::memory_pool::MemoryReservation; use futures::{StreamExt, TryStreamExt}; @@ -180,77 +178,6 @@ pub fn compute_record_batch_statistics( } } -/// Write in Arrow IPC File format. 
-pub struct IPCWriter { - /// Path - pub path: PathBuf, - /// Inner writer - pub writer: FileWriter, - /// Batches written - pub num_batches: usize, - /// Rows written - pub num_rows: usize, - /// Bytes written - pub num_bytes: usize, -} - -impl IPCWriter { - /// Create new writer - pub fn new(path: &Path, schema: &Schema) -> Result { - let file = File::create(path).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to create partition file at {path:?}: {e:?}" - )) - })?; - Ok(Self { - num_batches: 0, - num_rows: 0, - num_bytes: 0, - path: path.into(), - writer: FileWriter::try_new(file, schema)?, - }) - } - - /// Create new writer with IPC write options - pub fn new_with_options( - path: &Path, - schema: &Schema, - write_options: IpcWriteOptions, - ) -> Result { - let file = File::create(path).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to create partition file at {path:?}: {e:?}" - )) - })?; - Ok(Self { - num_batches: 0, - num_rows: 0, - num_bytes: 0, - path: path.into(), - writer: FileWriter::try_new_with_options(file, schema, write_options)?, - }) - } - /// Write one single batch - pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { - self.writer.write(batch)?; - self.num_batches += 1; - self.num_rows += batch.num_rows(); - let num_bytes: usize = batch.get_array_memory_size(); - self.num_bytes += num_bytes; - Ok(()) - } - - /// Finish the writer - pub fn finish(&mut self) -> Result<()> { - self.writer.finish().map_err(Into::into) - } - - /// Path write to - pub fn path(&self) -> &Path { - &self.path - } -} - /// Checks if the given projection is valid for the given schema. pub fn can_project( schema: &arrow::datatypes::SchemaRef, diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index e247f5ad9d19..f555755dd20a 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -394,8 +394,8 @@ impl ExecutionPlanVisitor for IndentVisitor<'_, '_> { } } if self.show_statistics { - let stats = plan.statistics().map_err(|_e| fmt::Error)?; - write!(self.f, ", statistics=[{}]", stats)?; + let stats = plan.partition_statistics(None).map_err(|_e| fmt::Error)?; + write!(self.f, ", statistics=[{stats}]")?; } if self.show_schema { write!( @@ -479,8 +479,8 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { }; let statistics = if self.show_statistics { - let stats = plan.statistics().map_err(|_e| fmt::Error)?; - format!("statistics=[{}]", stats) + let stats = plan.partition_statistics(None).map_err(|_e| fmt::Error)?; + format!("statistics=[{stats}]") } else { "".to_string() }; @@ -495,7 +495,7 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { self.f, id, &label, - Some(&format!("{}{}{}", metrics, delimiter, statistics)), + Some(&format!("{metrics}{delimiter}{statistics}")), )?; if let Some(parent_node_id) = self.parents.last() { @@ -686,7 +686,7 @@ impl TreeRenderVisitor<'_, '_> { &render_text, Self::NODE_RENDER_WIDTH - 2, ); - write!(self.f, "{}", render_text)?; + write!(self.f, "{render_text}")?; if render_y == halfway_point && node.child_positions.len() > 1 { write!(self.f, "{}", Self::LMIDDLE)?; @@ -856,10 +856,10 @@ impl TreeRenderVisitor<'_, '_> { if str.is_empty() { str = key.to_string(); } else if !is_multiline && total_size < available_width { - str = format!("{}: {}", key, str); + str = format!("{key}: {str}"); is_inlined = true; } else { - str = format!("{}:\n{}", key, str); + str = format!("{key}:\n{str}"); } if is_inlined && was_inlined { @@ -902,7 
+902,7 @@ impl TreeRenderVisitor<'_, '_> { let render_width = source.chars().count(); if render_width > max_render_width { let truncated = &source[..max_render_width - 3]; - format!("{}...", truncated) + format!("{truncated}...") } else { let total_spaces = max_render_width - render_width; let half_spaces = total_spaces / 2; @@ -1041,17 +1041,17 @@ pub fn display_orderings(f: &mut Formatter, orderings: &[LexOrdering]) -> fmt::R } else { ", output_orderings=[" }; - write!(f, "{}", start)?; + write!(f, "{start}")?; for (idx, ordering) in orderings.iter().enumerate().filter(|(_, o)| !o.is_empty()) { match idx { - 0 => write!(f, "[{}]", ordering)?, - _ => write!(f, ", [{}]", ordering)?, + 0 => write!(f, "[{ordering}]")?, + _ => write!(f, ", [{ordering}]")?, } } let end = if orderings.len() == 1 { "" } else { "]" }; - write!(f, "{}", end)?; + write!(f, "{end}")?; } } @@ -1120,6 +1120,13 @@ mod tests { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(self.schema().as_ref())); + } match self { Self::Panic => panic!("expected panic"), Self::Error => { diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 3fdde39df6f1..36634fbe6d7e 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -150,6 +150,20 @@ impl ExecutionPlan for EmptyExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(partition) = partition { + if partition >= self.partitions { + return internal_err!( + "EmptyExec invalid partition {} (expected less than {})", + partition, + self.partitions + ); + } + } + let batch = self .data() .expect("Create empty RecordBatch should not fail"); diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index 2bc5706ee0e1..b81b3c8beeac 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -16,6 +16,9 @@ // under the License. pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPropagation, +}; pub use crate::metrics::Metric; pub use crate::ordering::InputOrderMode; pub use crate::stream::EmptyRecordBatchStream; @@ -423,10 +426,30 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// /// For TableScan executors, which supports filter pushdown, special attention /// needs to be paid to whether the stats returned by this method are exact or not + #[deprecated(since = "48.0.0", note = "Use `partition_statistics` method instead")] fn statistics(&self) -> Result { Ok(Statistics::new_unknown(&self.schema())) } + /// Returns statistics for a specific partition of this `ExecutionPlan` node. + /// If statistics are not available, should return [`Statistics::new_unknown`] + /// (the default), not an error. + /// If `partition` is `None`, it returns statistics for the entire plan. 
+ fn partition_statistics(&self, partition: Option) -> Result { + if let Some(idx) = partition { + // Validate partition index + let partition_count = self.properties().partitioning.partition_count(); + if idx >= partition_count { + return internal_err!( + "Invalid partition index: {}, the partition count is {}", + idx, + partition_count + ); + } + } + Ok(Statistics::new_unknown(&self.schema())) + } + + /// Returns `true` if a limit can be safely pushed down through this + /// `ExecutionPlan` node. + /// + /// @@ -467,6 +490,62 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { ) -> Result>> { Ok(None) } + + /// Collect filters that this node can push down to its children. + /// Filters that are being pushed down from parents are passed in, + /// and the node may generate additional filters to push down. + /// For example, given the plan FilterExec -> HashJoinExec -> DataSourceExec, + /// what will happen is that we recurse down the plan calling `ExecutionPlan::gather_filters_for_pushdown`: + /// 1. `FilterExec::gather_filters_for_pushdown` is called with no parent + /// filters so it only returns that `FilterExec` wants to push down its own predicate. + /// 2. `HashJoinExec::gather_filters_for_pushdown` is called with the filter from + /// `FilterExec`, which it only allows to push down to one side of the join (unless it's on the join key) + /// but it also adds its own filters (e.g. pushing down a bloom filter of the hash table to the scan side of the join). + /// 3. `DataSourceExec::gather_filters_for_pushdown` is called with both filters from `HashJoinExec` + /// and `FilterExec`, however `DataSourceExec::gather_filters_for_pushdown` doesn't actually do anything + /// since it has no children and no additional filters to push down. + /// It's only once [`ExecutionPlan::handle_child_pushdown_result`] is called on `DataSourceExec` as we recurse + /// up the plan that `DataSourceExec` can actually bind the filters. + /// + /// The default implementation bars all parent filters from being pushed down and adds no new filters. + /// This is the safest option, making filter pushdown opt-in on a per-node basis. + fn gather_filters_for_pushdown( + &self, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + Ok( + FilterDescription::new_with_child_count(self.children().len()) + .all_parent_filters_unsupported(parent_filters), + ) + } + + /// Handle the result of a child pushdown. + /// This is called as we recurse back up the plan tree after recursing down and calling [`ExecutionPlan::gather_filters_for_pushdown`]. + /// Once we know what the result of pushing down filters into children is we ask the current node what it wants to do with that result. + /// For a `DataSourceExec` that may be absorbing the filters to apply them during the scan phase + /// (also known as late materialization). + /// A `FilterExec` may absorb any filters its children could not absorb, or if there are no filters left it + /// may remove itself from the plan altogether. + /// It combines both [`ChildPushdownResult::parent_filters`] and [`ChildPushdownResult::self_filters`] into a single + /// predicate and replaces its own predicate. + /// Then it passes [`PredicateSupport::Supported`] for each parent predicate to the parent. + /// A `HashJoinExec` may ignore the pushdown result since it needs to apply the filters as part of the join anyhow.
+ /// It passes [`ChildPushdownResult::parent_filters`] back up to its parents wrapped in [`FilterPushdownPropagation::transparent`] + /// and [`ChildPushdownResult::self_filters`] is discarded. + /// + /// The default implementation is a no-op that passes the result of pushdown from the children to its parent. + /// + /// [`PredicateSupport::Supported`]: crate::filter_pushdown::PredicateSupport::Supported + fn handle_child_pushdown_result( + &self, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::transparent( + child_pushdown_result, + )) + } } /// [`ExecutionPlan`] Invariant Level @@ -519,13 +598,15 @@ pub trait ExecutionPlanProperties { /// If this ExecutionPlan makes no changes to the schema of the rows flowing /// through it or how columns within each row relate to each other, it /// should return the equivalence properties of its input. For - /// example, since `FilterExec` may remove rows from its input, but does not + /// example, since [`FilterExec`] may remove rows from its input, but does not /// otherwise modify them, it preserves its input equivalence properties. /// However, since `ProjectionExec` may calculate derived expressions, it /// needs special handling. /// /// See also [`ExecutionPlan::maintains_input_order`] and [`Self::output_ordering`] /// for related concepts. + /// + /// [`FilterExec`]: crate::filter::FilterExec fn equivalence_properties(&self) -> &EquivalenceProperties; } @@ -1137,6 +1218,10 @@ mod tests { fn statistics(&self) -> Result { unimplemented!() } + + fn partition_statistics(&self, _partition: Option) -> Result { + unimplemented!() + } } #[derive(Debug)] @@ -1200,6 +1285,10 @@ fn statistics(&self) -> Result { unimplemented!() } + + fn partition_statistics(&self, _partition: Option) -> Result { + unimplemented!() + } } #[test] diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index a8a9973ea043..13129e382dec 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -16,6 +16,7 @@ // under the License.
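// A minimal, self-contained sketch of the two-phase pushdown handshake that the
// `gather_filters_for_pushdown` / `handle_child_pushdown_result` hooks above describe.
// This is NOT the DataFusion API: filters are plain strings and `PassThrough` is a
// hypothetical single-child operator, but the top-down "gather" phase and the
// bottom-up "handle" phase mirror the trait additions in this diff.

#[derive(Debug, Clone)]
enum Support {
    Supported(String),
    Unsupported(String),
}

struct PassThrough;

impl PassThrough {
    // Top-down phase: declare which parent filters may flow to the child.
    // A pass-through node (like `CoalesceBatchesExec` earlier in this diff) forwards all of them.
    fn gather_filters_for_pushdown(&self, parent_filters: Vec<String>) -> Vec<String> {
        parent_filters
    }

    // Bottom-up phase: echo the child's verdicts to the parent unchanged,
    // i.e. the moral equivalent of `FilterPushdownPropagation::transparent`.
    fn handle_child_pushdown_result(&self, child_result: Vec<Support>) -> Vec<Support> {
        child_result
    }
}

fn main() {
    let node = PassThrough;
    let to_child =
        node.gather_filters_for_pushdown(vec!["a > 5".to_string(), "b = 1".to_string()]);
    // Pretend the child (e.g. a scan) could absorb only the first filter;
    // whatever comes back as `Unsupported` must be re-applied by an ancestor.
    let child_result = vec![
        Support::Supported(to_child[0].clone()),
        Support::Unsupported(to_child[1].clone()),
    ];
    println!("{:?}", node.handle_child_pushdown_result(child_result));
}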
use std::any::Any; +use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; @@ -26,6 +27,9 @@ use super::{ }; use crate::common::can_project; use crate::execution_plan::CardinalityEffect; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPropagation, +}; use crate::projection::{ make_with_child, try_embed_projection, update_expr, EmbeddedProjection, ProjectionExec, @@ -39,25 +43,32 @@ use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; +use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{ internal_err, plan_err, project_schema, DataFusionError, Result, ScalarValue, }; use datafusion_execution::TaskContext; use datafusion_expr::Operator; use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::BinaryExpr; +use datafusion_physical_expr::expressions::{lit, BinaryExpr, Column}; use datafusion_physical_expr::intervals::utils::check_support; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - analyze, split_conjunction, AcrossPartitions, AnalysisContext, ConstExpr, - ExprBoundaries, PhysicalExpr, + analyze, conjunction, split_conjunction, AcrossPartitions, AnalysisContext, + ConstExpr, ExprBoundaries, PhysicalExpr, }; use datafusion_physical_expr_common::physical_expr::fmt_sql; use futures::stream::{Stream, StreamExt}; +use itertools::Itertools; use log::trace; +const FILTER_EXEC_DEFAULT_SELECTIVITY: u8 = 20; + /// FilterExec evaluates a boolean predicate against all input batches to determine which rows to /// include in its output batches. #[derive(Debug, Clone)] @@ -84,7 +95,7 @@ impl FilterExec { ) -> Result { match predicate.data_type(input.schema().as_ref())? { DataType::Boolean => { - let default_selectivity = 20; + let default_selectivity = FILTER_EXEC_DEFAULT_SELECTIVITY; let cache = Self::compute_properties( &input, &predicate, @@ -170,12 +181,11 @@ impl FilterExec { /// Calculates `Statistics` for `FilterExec`, by applying selectivity (either default, or estimated) to input statistics. 
fn statistics_helper( - input: &Arc, + schema: SchemaRef, + input_stats: Statistics, predicate: &Arc, default_selectivity: u8, ) -> Result { - let input_stats = input.statistics()?; - let schema = input.schema(); if !check_support(predicate, &schema) { let selectivity = default_selectivity as f64 / 100.0; let mut stats = input_stats.to_inexact(); @@ -189,7 +199,7 @@ impl FilterExec { let num_rows = input_stats.num_rows; let total_byte_size = input_stats.total_byte_size; let input_analysis_ctx = AnalysisContext::try_from_statistics( - &input.schema(), + &schema, &input_stats.column_statistics, )?; @@ -256,7 +266,12 @@ impl FilterExec { ) -> Result { // Combine the equal predicates with the input equivalence properties // to construct the equivalence properties: - let stats = Self::statistics_helper(input, predicate, default_selectivity)?; + let stats = Self::statistics_helper( + input.schema(), + input.partition_statistics(None)?, + predicate, + default_selectivity, + )?; let mut eq_properties = input.equivalence_properties().clone(); let (equal_pairs, _) = collect_columns_from_predicate(predicate); for (lhs, rhs) in equal_pairs { @@ -396,8 +411,14 @@ impl ExecutionPlan for FilterExec { /// The output statistics of a filtering operation can be estimated if the /// predicate's selectivity value can be determined for the incoming data. fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stats = self.input.partition_statistics(partition)?; let stats = Self::statistics_helper( - &self.input, + self.schema(), + input_stats, self.predicate(), self.default_selectivity, )?; @@ -433,6 +454,130 @@ impl ExecutionPlan for FilterExec { } try_embed_projection(projection, self) } + + fn gather_filters_for_pushdown( + &self, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + let self_filter = split_conjunction(&self.predicate) + .into_iter() + .cloned() + .collect_vec(); + + let parent_filters = if let Some(projection_indices) = self.projection.as_ref() { + // We need to invert the projection on any referenced columns in the filter + // Create a mapping from the output columns to the input columns (the inverse of the projection) + let inverse_projection = projection_indices + .iter() + .enumerate() + .map(|(i, &p)| (p, i)) + .collect::>(); + parent_filters + .into_iter() + .map(|f| { + f.transform_up(|expr| { + let mut res = + if let Some(col) = expr.as_any().downcast_ref::() { + let index = col.index(); + let index_in_input_schema = + inverse_projection.get(&index).ok_or_else(|| { + DataFusionError::Internal(format!( + "Column {index} not found in projection" + )) + })?; + Transformed::yes(Arc::new(Column::new( + col.name(), + *index_in_input_schema, + )) as _) + } else { + Transformed::no(expr) + }; + // Columns can only exist in the leaves, no need to try all nodes + res.tnr = TreeNodeRecursion::Jump; + Ok(res) + }) + .data() + }) + .collect::>>()? 
+ } else { + parent_filters + }; + + Ok(FilterDescription::new_with_child_count(1) + .all_parent_filters_supported(parent_filters) + .with_self_filters_for_children(vec![self_filter])) + } + + fn handle_child_pushdown_result( + &self, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + // We absorb any parent filters that were not handled by our children + let mut unhandled_filters = + child_pushdown_result.parent_filters.collect_unsupported(); + assert_eq!( + child_pushdown_result.self_filters.len(), + 1, + "FilterExec should only have one child" + ); + let unsupported_self_filters = + child_pushdown_result.self_filters[0].collect_unsupported(); + unhandled_filters.extend(unsupported_self_filters); + + // If we have unhandled filters, we need to create a new FilterExec + let filter_input = Arc::clone(self.input()); + let new_predicate = conjunction(unhandled_filters); + let updated_node = if new_predicate.eq(&lit(true)) { + // FilterExec is no longer needed, but we may need to leave a projection in place + match self.projection() { + Some(projection_indices) => { + let filter_child_schema = filter_input.schema(); + let proj_exprs = projection_indices + .iter() + .map(|p| { + let field = filter_child_schema.field(*p).clone(); + ( + Arc::new(Column::new(field.name(), *p)) + as Arc, + field.name().to_string(), + ) + }) + .collect::>(); + Some(Arc::new(ProjectionExec::try_new(proj_exprs, filter_input)?) + as Arc) + } + None => { + // No projection needed, just return the input + Some(filter_input) + } + } + } else if new_predicate.eq(&self.predicate) { + // The new predicate is the same as our current predicate + None + } else { + // Create a new FilterExec with the new predicate + let new = FilterExec { + predicate: Arc::clone(&new_predicate), + input: Arc::clone(&filter_input), + metrics: self.metrics.clone(), + default_selectivity: self.default_selectivity, + cache: Self::compute_properties( + &filter_input, + &new_predicate, + self.default_selectivity, + self.projection.as_ref(), + )?, + projection: None, + }; + Some(Arc::new(new) as _) + }; + Ok(FilterPushdownPropagation { + filters: child_pushdown_result.parent_filters.make_supported(), + updated_node, + }) + } } impl EmbeddedProjection for FilterExec { @@ -703,7 +848,7 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(25)); assert_eq!( statistics.total_byte_size, @@ -753,7 +898,7 @@ mod tests { sub_filter, )?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(16)); assert_eq!( statistics.column_statistics, @@ -813,7 +958,7 @@ mod tests { binary(col("a", &schema)?, Operator::GtEq, lit(10i32), &schema)?, b_gt_5, )?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; // On a uniform distribution, only fifteen rows will satisfy the // filter that 'a' proposed (a >= 10 AND a <= 25) (15/100) and only // 5 rows will satisfy the filter that 'b' proposed (b > 45) (5/50). 
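// A back-of-the-envelope check of the selectivity arithmetic the statistics tests in
// this area rely on. Pure integer arithmetic with a made-up 1000-row input; no
// DataFusion types are involved.

fn main() {
    let input_rows: u64 = 1_000; // hypothetical input cardinality

    // Uniform column `a` over 1..=100: `a >= 10 AND a <= 25` keeps 15 of 100 values.
    // Uniform column `b` over 1..=50: `b > 45` keeps 5 of 50 values.
    // Independent predicates multiply: 1000 * 15/100 * 5/50 = 15 estimated rows.
    let estimate = input_rows * 15 / 100 * 5 / 50;
    assert_eq!(estimate, 15);

    // When the predicate cannot be analysed, `FilterExec` falls back to
    // FILTER_EXEC_DEFAULT_SELECTIVITY (20%), so the same input is estimated at
    // 200 rows; `with_default_selectivity(40)` would yield 400 instead.
    assert_eq!(input_rows * 20 / 100, 200);
    assert_eq!(input_rows * 40 / 100, 400);
}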
@@ -858,7 +1003,7 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Absent); Ok(()) @@ -931,7 +1076,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; // 0.5 (from a) * 0.333333... (from b) * 0.798387... (from c) ≈ 0.1330... // num_rows after ceil => 133.0... => 134 // total_byte_size after ceil => 532.0... => 533 @@ -1027,10 +1172,10 @@ mod tests { )), )); // Since filter predicate passes all entries, statistics after filter shouldn't change. - let expected = input.statistics()?.column_statistics; + let expected = input.partition_statistics(None)?.column_statistics; let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(1000)); assert_eq!(statistics.total_byte_size, Precision::Inexact(4000)); @@ -1083,7 +1228,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(0)); assert_eq!(statistics.total_byte_size, Precision::Inexact(0)); @@ -1143,7 +1288,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(490)); assert_eq!(statistics.total_byte_size, Precision::Inexact(1960)); @@ -1193,7 +1338,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let filter_statistics = filter.statistics()?; + let filter_statistics = filter.partition_statistics(None)?; let expected_filter_statistics = Statistics { num_rows: Precision::Absent, @@ -1227,7 +1372,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let filter_statistics = filter.statistics()?; + let filter_statistics = filter.partition_statistics(None)?; // First column is "a", and it is a column with only one value after the filter. 
assert!(filter_statistics.column_statistics[0].is_singleton()); @@ -1274,11 +1419,11 @@ mod tests { Arc::new(Literal::new(ScalarValue::Decimal128(Some(10), 10, 10))), )); let filter = FilterExec::try_new(predicate, input)?; - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(200)); assert_eq!(statistics.total_byte_size, Precision::Inexact(800)); let filter = filter.with_default_selectivity(40)?; - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(400)); assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); Ok(()) @@ -1312,7 +1457,7 @@ mod tests { Arc::new(EmptyExec::new(Arc::clone(&schema))), )?; - exec.statistics().unwrap(); + exec.partition_statistics(None).unwrap(); Ok(()) } diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs new file mode 100644 index 000000000000..4e84fe36f98f --- /dev/null +++ b/datafusion/physical-plan/src/filter_pushdown.rs @@ -0,0 +1,340 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::vec::IntoIter; + +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +/// The result of a plan for pushing down a filter into a child node. +/// This contains references to filters so that nodes can mutate a filter +/// before pushing it down to a child node (e.g. to adjust a projection) +/// or can directly take ownership of `Unsupported` filters that their children +/// could not handle. +#[derive(Debug, Clone)] +pub enum PredicateSupport { + Supported(Arc), + Unsupported(Arc), +} + +/// A thin wrapper around [`PredicateSupport`]s that allows for easy collection of +/// supported and unsupported filters. Inner vector stores each predicate for one node. +#[derive(Debug, Clone)] +pub struct PredicateSupports(Vec); + +impl PredicateSupports { + /// Create a new FilterPushdowns with the given filters and their pushdown status. + pub fn new(pushdowns: Vec) -> Self { + Self(pushdowns) + } + + /// Create a new [`PredicateSupport`] with all filters as supported. + pub fn all_supported(filters: Vec>) -> Self { + let pushdowns = filters + .into_iter() + .map(PredicateSupport::Supported) + .collect(); + Self::new(pushdowns) + } + + /// Create a new [`PredicateSupport`] with all filters as unsupported. + pub fn all_unsupported(filters: Vec>) -> Self { + let pushdowns = filters + .into_iter() + .map(PredicateSupport::Unsupported) + .collect(); + Self::new(pushdowns) + } + + /// Transform all filters to supported, returning a new [`PredicateSupports`] + /// with all filters as [`PredicateSupport::Supported`]. 
+ /// This does not modify the original [`PredicateSupport`]. + pub fn make_supported(self) -> Self { + let pushdowns = self + .0 + .into_iter() + .map(|f| match f { + PredicateSupport::Supported(expr) => PredicateSupport::Supported(expr), + PredicateSupport::Unsupported(expr) => PredicateSupport::Supported(expr), + }) + .collect(); + Self::new(pushdowns) + } + + /// Transform all filters to unsupported, returning a new [`PredicateSupports`] + /// with all filters as [`PredicateSupport::Unsupported`]. + /// This does not modify the original [`PredicateSupport`]. + pub fn make_unsupported(self) -> Self { + let pushdowns = self + .0 + .into_iter() + .map(|f| match f { + PredicateSupport::Supported(expr) => PredicateSupport::Unsupported(expr), + u @ PredicateSupport::Unsupported(_) => u, + }) + .collect(); + Self::new(pushdowns) + } + + /// Collect unsupported filters into a Vec, without removing them from the original + /// [`PredicateSupport`]. + pub fn collect_unsupported(&self) -> Vec> { + self.0 + .iter() + .filter_map(|f| match f { + PredicateSupport::Unsupported(expr) => Some(Arc::clone(expr)), + PredicateSupport::Supported(_) => None, + }) + .collect() + } + + /// Collect all filters into a Vec, without removing them from the original + /// FilterPushdowns. + pub fn collect_all(self) -> Vec> { + self.0 + .into_iter() + .map(|f| match f { + PredicateSupport::Supported(expr) + | PredicateSupport::Unsupported(expr) => expr, + }) + .collect() + } + + pub fn into_inner(self) -> Vec { + self.0 + } + + /// Return an iterator over the inner `Vec`. + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + /// Return the number of filters in the inner `Vec`. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Check if the inner `Vec` is empty. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl IntoIterator for PredicateSupports { + type Item = PredicateSupport; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +/// The result of pushing down filters into a child node. +/// This is the result provided to nodes in [`ExecutionPlan::handle_child_pushdown_result`]. +/// Nodes process this result and convert it into a [`FilterPushdownPropagation`] +/// that is returned to their parent. +/// +/// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result +#[derive(Debug, Clone)] +pub struct ChildPushdownResult { + /// The combined result of pushing down each parent filter into each child. + /// For example, given the filters `[a, b]` and children `[1, 2, 3]` the matrix of responses: + /// + // | filter | child 1 | child 2 | child 3 | result | + // |--------|-------------|-----------|-----------|-------------| + // | a | Supported | Supported | Supported | Supported | + // | b | Unsupported | Supported | Supported | Unsupported | + /// + /// That is: if any child marks a filter as unsupported or if the filter was not pushed + /// down into any child then the result is unsupported. + /// If at least one child received the filter and all children that received it mark it as supported + /// then the result is supported. + pub parent_filters: PredicateSupports, + /// The result of pushing down each filter this node provided into each of its children. + /// This is not combined with the parent filters so that nodes can treat each child independently. + pub self_filters: Vec, +} + +/// The result of pushing down filters into a node that it returns to its parent.
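// A standalone illustration of the combination rule in the table above. Here
// `Option<bool>` stands in for "did this child receive the filter, and did it
// support it?"; this is not the DataFusion API, only the rule itself.

fn combined_verdict(per_child: &[Option<bool>]) -> bool {
    // Supported only if at least one child received the filter and every child
    // that received it reported support.
    let received: Vec<bool> = per_child.iter().copied().flatten().collect();
    !received.is_empty() && received.into_iter().all(|supported| supported)
}

fn main() {
    // Filter `a`: all three children support it -> supported overall.
    assert!(combined_verdict(&[Some(true), Some(true), Some(true)]));
    // Filter `b`: child 1 cannot handle it -> unsupported overall.
    assert!(!combined_verdict(&[Some(false), Some(true), Some(true)]));
    // A filter that reached no child at all is unsupported as well.
    assert!(!combined_verdict(&[None, None, None]));
}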
+/// This is what nodes return from [`ExecutionPlan::handle_child_pushdown_result`] to communicate +/// to the optimizer: +/// +/// 1. What to do with any parent filters that were not completely handled by the children. +/// 2. If the node needs to be replaced in the execution plan with a new node or not. +/// +/// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result +#[derive(Debug, Clone)] +pub struct FilterPushdownPropagation { + pub filters: PredicateSupports, + pub updated_node: Option, +} + +impl FilterPushdownPropagation { + /// Create a new [`FilterPushdownPropagation`] that echoes back up to the parent + /// the result of pushing down the filters + /// into the children. + pub fn transparent(child_pushdown_result: ChildPushdownResult) -> Self { + Self { + filters: child_pushdown_result.parent_filters, + updated_node: None, + } + } + + /// Create a new [`FilterPushdownPropagation`] that tells the parent node + /// that none of the parent filters were pushed down. + pub fn unsupported(parent_filters: Vec>) -> Self { + let unsupported = PredicateSupports::all_unsupported(parent_filters); + Self { + filters: unsupported, + updated_node: None, + } + } + + /// Create a new [`FilterPushdownPropagation`] with the specified filter support. + pub fn with_filters(filters: PredicateSupports) -> Self { + Self { + filters, + updated_node: None, + } + } + + /// Bind an updated node to the [`FilterPushdownPropagation`]. + pub fn with_updated_node(mut self, updated_node: T) -> Self { + self.updated_node = Some(updated_node); + self + } +} + +#[derive(Debug, Clone)] +struct ChildFilterDescription { + /// Description of which parent filters can be pushed down into this node. + /// Since we need to transmit filter pushdown results back to this node's parent + /// we need to track each parent filter for each child, even those that are unsupported / won't be pushed down. + /// We do this using a [`PredicateSupport`] which simplifies manipulating supported/unsupported filters. + parent_filters: PredicateSupports, + /// Description of which filters this node is pushing down to its children. + /// Since this is not transmitted back to the parents we can have variable sized inner arrays + /// instead of having to track supported/unsupported. + self_filters: Vec>, +} + +impl ChildFilterDescription { + fn new() -> Self { + Self { + parent_filters: PredicateSupports::new(vec![]), + self_filters: vec![], + } + } +} + +#[derive(Debug, Clone)] +pub struct FilterDescription { + /// A filter description for each child. + /// This includes which parent filters and which self filters (from the node in question) + /// will get pushed down to each child. + child_filter_descriptions: Vec, +} + +impl FilterDescription { + pub fn new_with_child_count(num_children: usize) -> Self { + Self { + child_filter_descriptions: vec![ChildFilterDescription::new(); num_children], + } + } + + pub fn parent_filters(&self) -> Vec { + self.child_filter_descriptions + .iter() + .map(|d| &d.parent_filters) + .cloned() + .collect() + } + + pub fn self_filters(&self) -> Vec>> { + self.child_filter_descriptions + .iter() + .map(|d| &d.self_filters) + .cloned() + .collect() + } + + /// Mark all parent filters as supported for all children. + /// This is the case if the node allows filters to be pushed down through it + /// without any modification. + /// This broadcasts the parent filters to all children.
+ /// If handling of parent filters is different for each child then you should set the + /// field directly. + /// For example, nodes like [`RepartitionExec`] that let filters pass through them transparently + /// use this to mark all parent filters as supported. + /// + /// [`RepartitionExec`]: crate::repartition::RepartitionExec + pub fn all_parent_filters_supported( + mut self, + parent_filters: Vec>, + ) -> Self { + let supported = PredicateSupports::all_supported(parent_filters); + for child in &mut self.child_filter_descriptions { + child.parent_filters = supported.clone(); + } + self + } + + /// Mark all parent filters as unsupported for all children. + /// This is the case if the node does not allow filters to be pushed down through it. + /// This broadcasts the parent filters to all children. + /// If handling of parent filters is different for each child then you should set the + /// field directly. + /// For example, the default implementation of filter pushdown in [`ExecutionPlan`] + /// assumes that filters cannot be pushed down to children. + /// + /// [`ExecutionPlan`]: crate::ExecutionPlan + pub fn all_parent_filters_unsupported( + mut self, + parent_filters: Vec>, + ) -> Self { + let unsupported = PredicateSupports::all_unsupported(parent_filters); + for child in &mut self.child_filter_descriptions { + child.parent_filters = unsupported.clone(); + } + self + } + + /// Add a filter generated / owned by the current node to be pushed down to all children. + /// This assumes that there is a single filter that gets pushed down to all children + /// equally. + /// If there are multiple filters or pushdown to children is not homogeneous then + /// you should set the field directly. + /// For example: + /// - `TopK` uses this to push down a single filter to all children, so it can use this method. + /// - `HashJoinExec` pushes down a filter only to the probe side, so it cannot use this method. + pub fn with_self_filter(mut self, predicate: Arc) -> Self { + for child in &mut self.child_filter_descriptions { + child.self_filters = vec![Arc::clone(&predicate)]; + } + self + } + + pub fn with_self_filters_for_children( + mut self, + filters: Vec>>, + ) -> Self { + for (child, filters) in self.child_filter_descriptions.iter_mut().zip(filters) { + child.self_filters = filters; + } + self + } +} diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 8dd1addff15c..4d8c48c659ef 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -337,10 +337,15 @@ impl ExecutionPlan for CrossJoinExec { } fn statistics(&self) -> Result { - Ok(stats_cartesian_product( - self.left.statistics()?, - self.right.statistics()?, - )) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + // Get the statistics for all partitions of the left side + let left_stats = self.left.partition_statistics(None)?; + let right_stats = self.right.partition_statistics(partition)?; + + Ok(stats_cartesian_product(left_stats, right_stats)) } /// Tries to swap the projection with its input [`CrossJoinExec`].
If it can be done, @@ -869,7 +874,7 @@ mod tests { assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n CrossJoinExec" ); Ok(()) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index e8904db0f3ea..398c2fed7cdf 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -86,6 +86,10 @@ use datafusion_physical_expr_common::physical_expr::fmt_sql; use futures::{ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; +/// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. +const HASH_JOIN_SEED: RandomState = + RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); + /// HashTable and input data for the left (build side) of a join struct JoinLeftData { /// The hash table with indices into `batch` @@ -385,7 +389,7 @@ impl HashJoinExec { let (join_schema, column_indices) = build_join_schema(&left_schema, &right_schema, join_type); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = HASH_JOIN_SEED; let join_schema = Arc::new(join_schema); @@ -660,7 +664,7 @@ impl DisplayAs for HashJoinExec { let on = self .on .iter() - .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .map(|(c1, c2)| format!("({c1}, {c2})")) .collect::>() .join(", "); write!( @@ -682,7 +686,7 @@ impl DisplayAs for HashJoinExec { if *self.join_type() != JoinType::Inner { writeln!(f, "join_type={:?}", self.join_type)?; } - writeln!(f, "on={}", on) + writeln!(f, "on={on}") } } } @@ -879,12 +883,19 @@ impl ExecutionPlan for HashJoinExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` let stats = estimate_join_statistics( - Arc::clone(&self.left), - Arc::clone(&self.right), + self.left.partition_statistics(None)?, + self.right.partition_statistics(None)?, self.on.clone(), &self.join_type, &self.join_schema, @@ -1296,8 +1307,8 @@ fn lookup_join_hashmap( limit: usize, offset: JoinHashMapOffset, ) -> Result<(UInt64Array, UInt32Array, Option)> { - let (probe_indices, build_indices, next_offset) = build_hashmap - .get_matched_indices_with_limit_offset(hashes_buffer, None, limit, offset); + let (probe_indices, build_indices, next_offset) = + build_hashmap.get_matched_indices_with_limit_offset(hashes_buffer, limit, offset); let build_indices: UInt64Array = build_indices.into(); let probe_indices: UInt32Array = probe_indices.into(); @@ -3322,7 +3333,7 @@ mod tests { #[test] fn join_with_hash_collision() -> Result<()> { - let mut hashmap_left = HashTable::with_capacity(2); + let mut hashmap_left = HashTable::with_capacity(4); let left = build_table_i32( ("a", &vec![10, 20]), ("x", &vec![100, 200]), @@ -3337,9 +3348,15 @@ mod tests { hashes_buff, )?; - // Create hash collisions (same hashes) + // Maps both values to both indices (1 and 2, representing input 0 and 1) + // 0 -> (0, 1) + // 1 -> (0, 2) + // The equality check will make 
sure only hashes[0] maps to 0 and hashes[1] maps to 1 hashmap_left.insert_unique(hashes[0], (hashes[0], 1), |(h, _)| *h); + hashmap_left.insert_unique(hashes[0], (hashes[0], 2), |(h, _)| *h); + hashmap_left.insert_unique(hashes[1], (hashes[1], 1), |(h, _)| *h); + hashmap_left.insert_unique(hashes[1], (hashes[1], 2), |(h, _)| *h); let next = vec![2, 0]; @@ -3990,10 +4007,7 @@ mod tests { assert_eq!( batches.len(), expected_batch_count, - "expected {} output batches for {} join with batch_size = {}", - expected_batch_count, - join_type, - batch_size + "expected {expected_batch_count} output batches for {join_type} join with batch_size = {batch_size}" ); let expected = match join_type { @@ -4058,12 +4072,12 @@ mod tests { // Asserting that operator-level reservation attempting to overallocate assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n HashJoinInput" ); assert_contains!( err.to_string(), - "Failed to allocate additional 120 bytes for HashJoinInput" + "Failed to allocate additional 120.0 B for HashJoinInput" ); } @@ -4139,13 +4153,13 @@ mod tests { // Asserting that stream-level reservation attempting to overallocate assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput[1]" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n HashJoinInput[1]" ); assert_contains!( err.to_string(), - "Failed to allocate additional 120 bytes for HashJoinInput[1]" + "Failed to allocate additional 120.0 B for HashJoinInput[1]" ); } diff --git a/datafusion/physical-plan/src/joins/join_hash_map.rs b/datafusion/physical-plan/src/joins/join_hash_map.rs index 7af0aeca0fd6..521e19d7bf44 100644 --- a/datafusion/physical-plan/src/joins/join_hash_map.rs +++ b/datafusion/physical-plan/src/joins/join_hash_map.rs @@ -116,20 +116,10 @@ pub(crate) type JoinHashMapOffset = (usize, Option); macro_rules! chain_traverse { ( $input_indices:ident, $match_indices:ident, $hash_values:ident, $next_chain:ident, - $input_idx:ident, $chain_idx:ident, $deleted_offset:ident, $remaining_output:ident + $input_idx:ident, $chain_idx:ident, $remaining_output:ident ) => { - let mut i = $chain_idx - 1; + let mut match_row_idx = $chain_idx - 1; loop { - let match_row_idx = if let Some(offset) = $deleted_offset { - // This arguments means that we prune the next index way before here. - if i < offset as u64 { - // End of the list due to pruning - break; - } - i - offset as u64 - } else { - i - }; $match_indices.push(match_row_idx); $input_indices.push($input_idx as u32); $remaining_output -= 1; @@ -150,7 +140,7 @@ macro_rules! chain_traverse { // end of list break; } - i = next - 1; + match_row_idx = next - 1; } }; } @@ -168,6 +158,11 @@ pub trait JoinHashMapType { /// Returns a reference to the next. fn get_list(&self) -> &Self::NextType; + // Whether values in the hashmap are distinct (no duplicate keys) + fn is_distinct(&self) -> bool { + false + } + /// Updates hashmap from iterator of row indices & row hashes pairs. 
fn update_from_iter<'a>( &mut self, @@ -257,17 +252,35 @@ pub trait JoinHashMapType { fn get_matched_indices_with_limit_offset( &self, hash_values: &[u64], - deleted_offset: Option, limit: usize, offset: JoinHashMapOffset, ) -> (Vec, Vec, Option) { - let mut input_indices = vec![]; - let mut match_indices = vec![]; - - let mut remaining_output = limit; + let mut input_indices = Vec::with_capacity(limit); + let mut match_indices = Vec::with_capacity(limit); let hash_map: &HashTable<(u64, u64)> = self.get_map(); let next_chain = self.get_list(); + // Check if hashmap consists of unique values + // If so, we can skip the chain traversal + if self.is_distinct() { + let start = offset.0; + let end = (start + limit).min(hash_values.len()); + for (row_idx, &hash_value) in hash_values[start..end].iter().enumerate() { + if let Some((_, index)) = + hash_map.find(hash_value, |(hash, _)| hash_value == *hash) + { + input_indices.push(start as u32 + row_idx as u32); + match_indices.push(*index - 1); + } + } + if end == hash_values.len() { + // No more values to process + return (input_indices, match_indices, None); + } + return (input_indices, match_indices, Some((end, None))); + } + + let mut remaining_output = limit; // Calculate initial `hash_values` index before iterating let to_skip = match offset { @@ -286,7 +299,6 @@ pub trait JoinHashMapType { next_chain, initial_idx, initial_next_idx, - deleted_offset, remaining_output ); @@ -295,6 +307,7 @@ pub trait JoinHashMapType { }; let mut row_idx = to_skip; + for hash_value in &hash_values[to_skip..] { if let Some((_, index)) = hash_map.find(*hash_value, |(hash, _)| *hash_value == *hash) @@ -306,7 +319,6 @@ pub trait JoinHashMapType { next_chain, row_idx, index, - deleted_offset, remaining_output ); } @@ -338,6 +350,11 @@ impl JoinHashMapType for JoinHashMap { fn get_list(&self) -> &Self::NextType { &self.next } + + /// Check if the values in the hashmap are distinct. 
+ fn is_distinct(&self) -> bool { + self.map.len() == self.next.len() + } } impl Debug for JoinHashMap { diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index b90279595096..f87cf3d8864c 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -567,9 +567,16 @@ impl ExecutionPlan for NestedLoopJoinExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } estimate_join_statistics( - Arc::clone(&self.left), - Arc::clone(&self.right), + self.left.partition_statistics(None)?, + self.right.partition_statistics(None)?, vec![], &self.join_type, &self.join_schema, @@ -1506,7 +1513,7 @@ pub(crate) mod tests { assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]" ); } @@ -1663,11 +1670,7 @@ pub(crate) mod tests { .into_iter() .zip(prev_values) .all(|(current, prev)| current >= prev), - "batch_index: {} row: {} current: {:?}, prev: {:?}", - batch_index, - row, - current_values, - prev_values + "batch_index: {batch_index} row: {row} current: {current_values:?}, prev: {prev_values:?}" ); prev_values = current_values; } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 89f2e3c911f8..cadd2b53ab11 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -16,9 +16,8 @@ // under the License. //! Defines the Sort-Merge join execution plan. -//! A Sort-Merge join plan consumes two sorted children plan and produces +//! A Sort-Merge join plan consumes two sorted children plans and produces //! joined output by given join type and other options. -//! Sort-Merge join feature is currently experimental. 
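// A toy model (not the real JoinHashMap implementation) of the map / next-chain
// layout that the `is_distinct` fast path above reasons about: `map` keeps one
// entry per distinct key pointing at the last row with that key (stored as row + 1
// so that 0 terminates a chain), and `next` has one slot per build row linking
// earlier duplicates. When `map.len() == next.len()` every key was unique, so the
// probe side can skip the chain walk entirely.
use std::collections::HashMap;

fn build(keys: &[u64]) -> (HashMap<u64, u64>, Vec<u64>) {
    let mut map = HashMap::new();
    let mut next = vec![0u64; keys.len()];
    for (row, &key) in keys.iter().enumerate() {
        // Remember the previous row (if any) with this key, then point the map
        // at the current row.
        let prev = map.insert(key, row as u64 + 1).unwrap_or(0);
        next[row] = prev;
    }
    (map, next)
}

fn main() {
    let (map, next) = build(&[10, 20, 30]);
    assert_eq!(map.len(), next.len()); // all keys distinct -> fast path applies

    let (map, next) = build(&[10, 20, 10]);
    assert!(map.len() < next.len()); // duplicate key -> chain traversal needed
}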
use std::any::Any; use std::cmp::Ordering; @@ -170,12 +169,6 @@ impl SortMergeJoinExec { let left_schema = left.schema(); let right_schema = right.schema(); - if join_type == JoinType::RightSemi { - return not_impl_err!( - "SortMergeJoinExec does not support JoinType::RightSemi" - ); - } - check_join_is_valid(&left_schema, &right_schema, &on)?; if sort_options.len() != on.len() { return plan_err!( @@ -358,7 +351,7 @@ impl DisplayAs for SortMergeJoinExec { let on = self .on .iter() - .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .map(|(c1, c2)| format!("({c1}, {c2})")) .collect::>() .join(", "); write!( @@ -385,7 +378,7 @@ impl DisplayAs for SortMergeJoinExec { if self.join_type() != JoinType::Inner { writeln!(f, "join_type={:?}", self.join_type)?; } - writeln!(f, "on={}", on) + writeln!(f, "on={on}") } } } @@ -514,12 +507,19 @@ impl ExecutionPlan for SortMergeJoinExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` estimate_join_statistics( - Arc::clone(&self.left), - Arc::clone(&self.right), + self.left.partition_statistics(None)?, + self.right.partition_statistics(None)?, self.on.clone(), &self.join_type, &self.schema, @@ -916,7 +916,7 @@ struct JoinedRecordBatches { pub batches: Vec, /// Filter match mask for each row(matched/non-matched) pub filter_mask: BooleanBuilder, - /// Row indices to glue together rows in `batches` and `filter_mask` + /// Left row indices to glue together rows in `batches` and `filter_mask` pub row_indices: UInt64Builder, /// Which unique batch id the row belongs to /// It is necessary to differentiate rows that are distributed the way when they point to the same @@ -1016,7 +1016,7 @@ fn get_corrected_filter_mask( corrected_mask.append_n(expected_size - corrected_mask.len(), false); Some(corrected_mask.finish()) } - JoinType::LeftSemi => { + JoinType::LeftSemi | JoinType::RightSemi => { for i in 0..row_indices_length { let last_index = last_index_for_row(i, row_indices, batch_ids, row_indices_length); @@ -1145,6 +1145,7 @@ impl Stream for SortMergeJoinStream { | JoinType::LeftSemi | JoinType::LeftMark | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti | JoinType::Full @@ -1250,6 +1251,7 @@ impl Stream for SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftMark @@ -1275,6 +1277,7 @@ impl Stream for SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti | JoinType::Full @@ -1597,7 +1600,6 @@ impl SortMergeJoinStream { self.join_type, JoinType::Left | JoinType::Right - | JoinType::RightSemi | JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti @@ -1607,7 +1609,10 @@ impl SortMergeJoinStream { } } Ordering::Equal => { - if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftMark) { + if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftMark | JoinType::RightSemi + ) { mark_row_as_match = matches!(self.join_type, JoinType::LeftMark); // if the join filter is specified then its needed to output the streamed index // only if it has 
not been emitted before @@ -1827,7 +1832,10 @@ impl SortMergeJoinStream { vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef] } else if matches!( self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightAnti + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::RightSemi ) { vec![] } else if let Some(buffered_idx) = chunk.buffered_batch_idx { @@ -1861,7 +1869,10 @@ impl SortMergeJoinStream { )?; get_filter_column(&self.filter, &left_columns, &right_cols) - } else if matches!(self.join_type, JoinType::RightAnti) { + } else if matches!( + self.join_type, + JoinType::RightAnti | JoinType::RightSemi + ) { let right_cols = fetch_right_columns_by_idxs( &self.buffered_data, chunk.buffered_batch_idx.unwrap(), @@ -1922,6 +1933,7 @@ impl SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftMark @@ -2019,6 +2031,7 @@ impl SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftMark @@ -2128,7 +2141,7 @@ impl SortMergeJoinStream { let output_column_indices = (0..left_columns_length).collect::>(); filtered_record_batch = filtered_record_batch.project(&output_column_indices)?; - } else if matches!(self.join_type, JoinType::RightAnti) { + } else if matches!(self.join_type, JoinType::RightAnti | JoinType::RightSemi) { let output_column_indices = (0..right_columns_length).collect::>(); filtered_record_batch = filtered_record_batch.project(&output_column_indices)?; @@ -2574,7 +2587,7 @@ mod tests { JoinSide, }; use datafusion_execution::config::SessionConfig; - use datafusion_execution::disk_manager::DiskManagerConfig; + use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; use datafusion_expr::Operator; @@ -3467,7 +3480,7 @@ mod tests { } #[tokio::test] - async fn join_semi() -> Result<()> { + async fn join_left_semi() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 2, 3]), ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right @@ -3497,6 +3510,255 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_right_semi_one() -> Result<()> { + let left = build_table( + ("a1", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a2", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a2 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) 
as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_two_with_filter() -> Result<()> { + let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30])); + let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c2", 1)), + Operator::Lt, + Arc::new(Column::new("c1", 0)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ])), + ); + let (_, batches) = + join_collect_with_filter(left, right, on, filter, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 10 | 20 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_with_nulls() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]), + ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field + ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 3 | 6 | |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_with_nulls_with_options() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(1), Some(0), Some(2)]), + ("b1", &vec![None, Some(5), Some(4), None, Some(5)]), + ("c2", &vec![Some(90), Some(80), Some(70), Some(60), None]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]), + ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field + ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) 
as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect_with_options( + left, + right, + on, + RightSemi, + vec![ + SortOptions { + descending: true, + nulls_first: false, + }; + 2 + ], + true, + ) + .await?; + + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 3 | | 9 |", + "| 2 | 5 | |", + "| 2 | 5 | 8 |", + "| 1 | 4 | 7 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_output_two_batches() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = + join_collect_batch_size_equals_two(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 1); + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_left_mark() -> Result<()> { let left = build_table( @@ -3856,12 +4118,16 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = vec![ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() .with_memory_limit(100, 1.0) - .with_disk_manager(DiskManagerConfig::Disabled) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), + ) .build_arc()?; let session_config = SessionConfig::default().with_batch_size(50); @@ -3934,12 +4200,16 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = vec![ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() .with_memory_limit(100, 1.0) - .with_disk_manager(DiskManagerConfig::Disabled) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), + ) .build_arc()?; let session_config = SessionConfig::default().with_batch_size(50); @@ -3990,12 +4260,16 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = [ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() .with_memory_limit(100, 1.0) - .with_disk_manager(DiskManagerConfig::NewOs) + 
.with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) .build_arc()?; for batch_size in [1, 50] { @@ -4091,12 +4365,16 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = [ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() .with_memory_limit(500, 1.0) - .with_disk_manager(DiskManagerConfig::NewOs) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) .build_arc()?; for batch_size in [1, 50] { @@ -4475,169 +4753,177 @@ mod tests { } #[tokio::test] - async fn test_left_semi_join_filtered_mask() -> Result<()> { - let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); + async fn test_semi_join_filtered_mask() -> Result<()> { + for join_type in [LeftSemi, RightSemi] { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); - let output = concat_batches(&schema, &joined_batches.batches)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![true]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0]), - &[0usize; 2], - &BooleanArray::from(vec![true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, false, true]), - output.num_rows() - ) - .unwrap(), - 
BooleanArray::from(vec![Some(true), None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, None, Some(true),]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true),]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, Some(true), None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, Some(true), None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); - let corrected_mask = get_corrected_filter_mask( - LeftSemi, - &out_indices, - &joined_batches.batch_ids, - &out_mask, - output.num_rows(), - ) - .unwrap(); + let corrected_mask = get_corrected_filter_mask( + join_type, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); - assert_eq!( - corrected_mask, - BooleanArray::from(vec![ - Some(true), - None, - Some(true), - None, - Some(true), - None, - None, - None - ]) - ); + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + None, + None, + None + ]) + ); - let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; - assert_snapshot!(batches_to_string(&[filtered_rb]), @r#" - +---+----+---+----+ - | a | b | x | y | - +---+----+---+----+ - | 1 | 10 | 1 | 11 | - | 1 | 11 | 1 | 12 | - | 1 | 12 | 1 | 13 | - +---+----+---+----+ - "#); + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); - // output null rows - let null_mask = arrow::compute::not(&corrected_mask)?; - assert_eq!( - null_mask, - BooleanArray::from(vec![ - Some(false), - None, - Some(false), - None, - Some(false), - None, - None, - None - ]) - ); + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + None, + None, + None + ]) + ); - let null_joined_batch = filter_record_batch(&output, &null_mask)?; + let null_joined_batch = 
filter_record_batch(&output, &null_mask)?; - assert_snapshot!(batches_to_string(&[null_joined_batch]), @r#" - +---+---+---+---+ - | a | b | x | y | - +---+---+---+---+ - +---+---+---+---+ - "#); + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); + } Ok(()) } diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 0dcb42169e00..819a3302b062 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -372,7 +372,7 @@ impl DisplayAs for SymmetricHashJoinExec { let on = self .on .iter() - .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .map(|(c1, c2)| format!("({c1}, {c2})")) .collect::>() .join(", "); write!( @@ -395,7 +395,7 @@ impl DisplayAs for SymmetricHashJoinExec { if *self.join_type() != JoinType::Inner { writeln!(f, "join_type={:?}", self.join_type)?; } - writeln!(f, "on={}", on) + writeln!(f, "on={on}") } } } diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index d38637dae028..81f56c865f04 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -195,7 +195,7 @@ struct AscendingRandomFloatIterator { impl AscendingRandomFloatIterator { fn new(min: f64, max: f64) -> Self { let mut rng = StdRng::seed_from_u64(42); - let initial = rng.gen_range(min..max); + let initial = rng.random_range(min..max); AscendingRandomFloatIterator { prev: initial, max, @@ -208,7 +208,7 @@ impl Iterator for AscendingRandomFloatIterator { type Item = f64; fn next(&mut self) -> Option { - let value = self.rng.gen_range(self.prev..self.max); + let value = self.rng.random_range(self.prev..self.max); self.prev = value; Some(value) } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 5516f172d510..3abeff6621d2 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -314,12 +314,18 @@ pub fn build_join_schema( JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(), }; - let metadata = left + let (schema1, schema2) = match join_type { + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (left, right), + _ => (right, left), + }; + + let metadata = schema1 .metadata() .clone() .into_iter() - .chain(right.metadata().clone()) + .chain(schema2.metadata().clone()) .collect(); + (fields.finish().with_metadata(metadata), column_indices) } @@ -403,15 +409,12 @@ struct PartialJoinStatistics { /// Estimate the statistics for the given join's output. 
pub(crate) fn estimate_join_statistics( - left: Arc, - right: Arc, + left_stats: Statistics, + right_stats: Statistics, on: JoinOn, join_type: &JoinType, schema: &Schema, ) -> Result { - let left_stats = left.statistics()?; - let right_stats = right.statistics()?; - let join_stats = estimate_join_cardinality(join_type, left_stats, right_stats, &on); let (num_rows, column_statistics) = match join_stats { Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics), @@ -1501,6 +1504,7 @@ pub(super) fn swap_join_projection( #[cfg(test)] mod tests { use super::*; + use std::collections::HashMap; use std::pin::Pin; use arrow::array::Int32Array; @@ -2245,8 +2249,7 @@ mod tests { assert_eq!( output_cardinality, expected, - "failure for join_type: {}", - join_type + "failure for join_type: {join_type}" ); } @@ -2499,4 +2502,28 @@ mod tests { assert_eq!(col.name(), name); assert_eq!(col.index(), index); } + + #[test] + fn test_join_metadata() -> Result<()> { + let left_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]) + .with_metadata(HashMap::from([("key".to_string(), "left".to_string())])); + + let right_schema = Schema::new(vec![Field::new("b", DataType::Int32, false)]) + .with_metadata(HashMap::from([("key".to_string(), "right".to_string())])); + + let (join_schema, _) = + build_join_schema(&left_schema, &right_schema, &JoinType::Left); + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "left".to_string())]) + ); + let (join_schema, _) = + build_join_schema(&left_schema, &right_schema, &JoinType::Right); + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "right".to_string())]) + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index b256e615b232..ba423f958c78 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -67,6 +67,7 @@ pub mod empty; pub mod execution_plan; pub mod explain; pub mod filter; +pub mod filter_pushdown; pub mod joins; pub mod limit; pub mod memory; @@ -91,5 +92,5 @@ pub mod udaf { } pub mod coalesce; -#[cfg(test)] +#[cfg(any(test, feature = "bench"))] pub mod test; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 89cf47a6d650..2224f85cc122 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -110,7 +110,7 @@ impl DisplayAs for GlobalLimitExec { } DisplayFormatType::TreeRender => { if let Some(fetch) = self.fetch { - writeln!(f, "limit={}", fetch)?; + writeln!(f, "limit={fetch}")?; } write!(f, "skip={}", self.skip) } @@ -164,10 +164,7 @@ impl ExecutionPlan for GlobalLimitExec { partition: usize, context: Arc, ) -> Result { - trace!( - "Start GlobalLimitExec::execute for partition: {}", - partition - ); + trace!("Start GlobalLimitExec::execute for partition: {partition}"); // GlobalLimitExec has a single output partition if 0 != partition { return internal_err!("GlobalLimitExec invalid partition {partition}"); @@ -193,8 +190,11 @@ impl ExecutionPlan for GlobalLimitExec { } fn statistics(&self) -> Result { - Statistics::with_fetch( - self.input.statistics()?, + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition)?.with_fetch( self.schema(), self.fetch, self.skip, @@ -334,8 +334,11 @@ impl ExecutionPlan for LocalLimitExec { } fn statistics(&self) -> Result { - Statistics::with_fetch( - self.input.statistics()?, + 
self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition)?.with_fetch( self.schema(), Some(self.fetch), 0, @@ -765,7 +768,7 @@ mod tests { let offset = GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch); - Ok(offset.statistics()?.num_rows) + Ok(offset.partition_statistics(None)?.num_rows) } pub fn build_group_by( @@ -805,7 +808,7 @@ mod tests { fetch, ); - Ok(offset.statistics()?.num_rows) + Ok(offset.partition_statistics(None)?.num_rows) } async fn row_number_statistics_for_local_limit( @@ -818,7 +821,7 @@ mod tests { let offset = LocalLimitExec::new(csv, fetch); - Ok(offset.statistics()?.num_rows) + Ok(offset.partition_statistics(None)?.num_rows) } /// Return a RecordBatch with a single array with row_count sz diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 1bc872a56e76..c232970b2188 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::execution_plan::{Boundedness, EmissionType}; +use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, @@ -146,6 +147,8 @@ pub struct LazyMemoryExec { batch_generators: Vec>>, /// Plan properties cache storing equivalence properties, partitioning, and execution mode cache: PlanProperties, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, } impl LazyMemoryExec { @@ -164,6 +167,7 @@ impl LazyMemoryExec { schema, batch_generators: generators, cache, + metrics: ExecutionPlanMetricsSet::new(), }) } } @@ -254,12 +258,18 @@ impl ExecutionPlan for LazyMemoryExec { ); } + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); Ok(Box::pin(LazyMemoryStream { schema: Arc::clone(&self.schema), generator: Arc::clone(&self.batch_generators[partition]), + baseline_metrics, })) } + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + fn statistics(&self) -> Result { Ok(Statistics::new_unknown(&self.schema)) } @@ -276,6 +286,8 @@ pub struct LazyMemoryStream { /// parallel execution. /// Sharing generators between streams should be used with caution. 
generator: Arc>, + /// Execution metrics + baseline_metrics: BaselineMetrics, } impl Stream for LazyMemoryStream { @@ -285,13 +297,16 @@ impl Stream for LazyMemoryStream { self: std::pin::Pin<&mut Self>, _: &mut Context<'_>, ) -> Poll> { + let _timer_guard = self.baseline_metrics.elapsed_compute().timer(); let batch = self.generator.write().generate_next_batch(); - match batch { + let poll = match batch { Ok(Some(batch)) => Poll::Ready(Some(Ok(batch))), Ok(None) => Poll::Ready(None), Err(e) => Poll::Ready(Some(Err(e))), - } + }; + + self.baseline_metrics.record_poll(poll) } } @@ -304,6 +319,7 @@ impl RecordBatchStream for LazyMemoryStream { #[cfg(test)] mod lazy_memory_tests { use super::*; + use crate::common::collect; use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field, Schema}; use futures::StreamExt; @@ -419,4 +435,45 @@ mod lazy_memory_tests { Ok(()) } + + #[tokio::test] + async fn test_generate_series_metrics_integration() -> Result<()> { + // Test LazyMemoryExec metrics with different configurations + let test_cases = vec![ + (10, 2, 10), // 10 rows, batch size 2, expected 10 rows + (100, 10, 100), // 100 rows, batch size 10, expected 100 rows + (5, 1, 5), // 5 rows, batch size 1, expected 5 rows + ]; + + for (total_rows, batch_size, expected_rows) in test_cases { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let generator = TestGenerator { + counter: 0, + max_batches: (total_rows + batch_size - 1) / batch_size, // ceiling division + batch_size: batch_size as usize, + schema: Arc::clone(&schema), + }; + + let exec = + LazyMemoryExec::try_new(schema, vec![Arc::new(RwLock::new(generator))])?; + let task_ctx = Arc::new(TaskContext::default()); + + let stream = exec.execute(0, task_ctx)?; + let batches = collect(stream).await?; + + // Verify metrics exist with actual expected numbers + let metrics = exec.metrics().unwrap(); + + // Count actual rows returned + let actual_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(actual_rows, expected_rows); + + // Verify metrics match actual output + assert_eq!(metrics.output_rows().unwrap(), expected_rows); + assert!(metrics.elapsed_compute().unwrap() > 0); + } + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/metrics/baseline.rs b/datafusion/physical-plan/src/metrics/baseline.rs index a4a83b84b655..de436d0e4f5c 100644 --- a/datafusion/physical-plan/src/metrics/baseline.rs +++ b/datafusion/physical-plan/src/metrics/baseline.rs @@ -117,9 +117,10 @@ impl BaselineMetrics { } } - /// Process a poll result of a stream producing output for an - /// operator, recording the output rows and stream done time and - /// returning the same poll result + /// Process a poll result of a stream producing output for an operator. + /// + /// Note: this method only updates `output_rows` and `end_time` metrics. + /// Remember to update `elapsed_compute` and other metrics manually. 
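Given that note, the pattern the `LazyMemoryStream` change above follows is: hold an `elapsed_compute` timer for the duration of the poll, then pass the result through `record_poll`. A minimal sketch of that shape, assuming the published `datafusion-physical-plan`, `datafusion-common` and `arrow` crates; the `poll_with_metrics` helper and the `fake_source` closure are illustrative, not DataFusion APIs:

    use std::task::Poll;

    use arrow::record_batch::RecordBatch;
    use datafusion_common::Result;
    use datafusion_physical_plan::metrics::BaselineMetrics;

    // Run one poll worth of work, timing it and recording its output rows.
    fn poll_with_metrics(
        metrics: &BaselineMetrics,
        fake_source: impl FnOnce() -> Result<Option<RecordBatch>>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        // Time the work done during this poll; per the note above,
        // `record_poll` does not track elapsed compute time for us.
        let _timer = metrics.elapsed_compute().timer();
        let poll = match fake_source() {
            Ok(Some(batch)) => Poll::Ready(Some(Ok(batch))),
            Ok(None) => Poll::Ready(None),
            Err(e) => Poll::Ready(Some(Err(e))),
        };
        // Records output rows and the stream end time, then hands the poll back.
        metrics.record_poll(poll)
    }
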
pub fn record_poll( &self, poll: Poll>>, diff --git a/datafusion/physical-plan/src/metrics/value.rs b/datafusion/physical-plan/src/metrics/value.rs index decf77369db4..249cd5edb133 100644 --- a/datafusion/physical-plan/src/metrics/value.rs +++ b/datafusion/physical-plan/src/metrics/value.rs @@ -29,6 +29,7 @@ use std::{ use chrono::{DateTime, Utc}; use datafusion_common::instant::Instant; +use datafusion_execution::memory_pool::human_readable_size; use parking_lot::Mutex; /// A counter to record things such as number of input or output rows @@ -554,11 +555,14 @@ impl Display for MetricValue { match self { Self::OutputRows(count) | Self::SpillCount(count) - | Self::SpilledBytes(count) | Self::SpilledRows(count) | Self::Count { count, .. } => { write!(f, "{count}") } + Self::SpilledBytes(count) => { + let readable_count = human_readable_size(count.value()); + write!(f, "{readable_count}") + } Self::CurrentMemoryUsage(gauge) | Self::Gauge { gauge, .. } => { write!(f, "{gauge}") } @@ -581,6 +585,7 @@ impl Display for MetricValue { #[cfg(test)] mod tests { use chrono::TimeZone; + use datafusion_execution::memory_pool::units::MB; use super::*; @@ -605,6 +610,20 @@ mod tests { } } + #[test] + fn test_display_spilled_bytes() { + let count = Count::new(); + let spilled_byte = MetricValue::SpilledBytes(count.clone()); + + assert_eq!("0.0 B", spilled_byte.to_string()); + + count.add((100 * MB) as usize); + assert_eq!("100.0 MB", spilled_byte.to_string()); + + count.add((0.5 * MB as f64) as usize); + assert_eq!("100.5 MB", spilled_byte.to_string()); + } + #[test] fn test_display_time() { let time = Time::new(); diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index eecd980d09f8..46847b2413c0 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -166,6 +166,13 @@ impl ExecutionPlan for PlaceholderRowExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } let batch = self .data() .expect("Create single row placeholder RecordBatch should not fail"); diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 72934c74446e..f1621acd0deb 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -26,7 +26,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use super::expressions::{CastExpr, Column, Literal}; +use super::expressions::{Column, Literal}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, @@ -79,14 +79,14 @@ impl ProjectionExec { let fields: Result> = expr .iter() .map(|(e, name)| { - let mut field = Field::new( + let metadata = e.return_field(&input_schema)?.metadata().clone(); + + let field = Field::new( name, e.data_type(&input_schema)?, e.nullable(&input_schema)?, - ); - field.set_metadata( - get_field_metadata(e, &input_schema).unwrap_or_default(), - ); + ) + .with_metadata(metadata); Ok(field) }) @@ -198,23 +198,11 @@ impl ExecutionPlan for ProjectionExec { &self.cache } - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - fn maintains_input_order(&self) -> Vec { // Tell optimizer this operator doesn't reorder its input vec![true] } - fn with_new_children( - 
self: Arc, - mut children: Vec>, - ) -> Result> { - ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0)) - .map(|p| Arc::new(p) as _) - } - fn benefits_from_input_partitioning(&self) -> Vec { let all_simple_exprs = self .expr @@ -225,6 +213,18 @@ impl ExecutionPlan for ProjectionExec { vec![!all_simple_exprs] } + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0)) + .map(|p| Arc::new(p) as _) + } + fn execute( &self, partition: usize, @@ -244,8 +244,13 @@ impl ExecutionPlan for ProjectionExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stats = self.input.partition_statistics(partition)?; Ok(stats_projection( - self.input.statistics()?, + input_stats, self.expr.iter().map(|(e, _)| Arc::clone(e)), Arc::clone(&self.schema), )) @@ -273,24 +278,6 @@ impl ExecutionPlan for ProjectionExec { } } -/// If 'e' is a direct column reference, returns the field level -/// metadata for that field, if any. Otherwise returns None -pub(crate) fn get_field_metadata( - e: &Arc, - input_schema: &Schema, -) -> Option> { - if let Some(cast) = e.as_any().downcast_ref::() { - return get_field_metadata(cast.expr(), input_schema); - } - - // Look up field by index in schema (not NAME as there can be more than one - // column with the same name) - e.as_any() - .downcast_ref::() - .map(|column| input_schema.field(column.index()).metadata()) - .cloned() -} - fn stats_projection( mut stats: Statistics, exprs: impl Iterator>, @@ -538,7 +525,7 @@ pub fn remove_unnecessary_projections( } else { return Ok(Transformed::no(plan)); }; - Ok(maybe_modified.map_or(Transformed::no(plan), Transformed::yes)) + Ok(maybe_modified.map_or_else(|| Transformed::no(plan), Transformed::yes)) } /// Compare the inputs and outputs of the projection. All expressions must be @@ -1093,13 +1080,11 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let exec = test::scan_partitioned(1); - let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?) - .await - .unwrap(); + let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?).await?; let projection = ProjectionExec::try_new(vec![], exec)?; let stream = projection.execute(0, Arc::clone(&task_ctx))?; - let output = collect(stream).await.unwrap(); + let output = collect(stream).await?; assert_eq!(output.len(), expected.len()); Ok(()) diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 7268735ea457..210db90c3c7f 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -184,8 +184,7 @@ impl ExecutionPlan for RecursiveQueryExec { // TODO: we might be able to handle multiple partitions in the future. if partition != 0 { return Err(DataFusionError::Internal(format!( - "RecursiveQueryExec got an invalid partition {} (expected 0)", - partition + "RecursiveQueryExec got an invalid partition {partition} (expected 0)" ))); } diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 71479ffa960d..d0ad50666416 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -19,6 +19,7 @@ //! partitions to M output partitions based on a partitioning scheme, optionally //! 
maintaining the order of the input rows in the output. +use std::fmt::{Debug, Formatter}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -43,8 +44,9 @@ use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Stat use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions}; use arrow::compute::take_arrays; use arrow::datatypes::{SchemaRef, UInt32Type}; +use datafusion_common::config::ConfigOptions; use datafusion_common::utils::transpose; -use datafusion_common::HashMap; +use datafusion_common::{internal_err, HashMap}; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::MemoryConsumer; @@ -52,6 +54,9 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPropagation, +}; use futures::stream::Stream; use futures::{FutureExt, StreamExt, TryStreamExt}; use log::trace; @@ -63,9 +68,8 @@ type MaybeBatch = Option>; type InputPartitionsToCurrentPartitionSender = Vec>; type InputPartitionsToCurrentPartitionReceiver = Vec>; -/// Inner state of [`RepartitionExec`]. #[derive(Debug)] -struct RepartitionExecState { +struct ConsumingInputStreamsState { /// Channels for sending batches from input partitions to output partitions. /// Key is the partition number. channels: HashMap< @@ -81,16 +85,97 @@ struct RepartitionExecState { abort_helper: Arc>>, } +/// Inner state of [`RepartitionExec`]. +enum RepartitionExecState { + /// Not initialized yet. This is the default state stored in the RepartitionExec node + /// upon instantiation. + NotInitialized, + /// Input streams are initialized, but they are still not being consumed. The node + /// transitions to this state when the arrow's RecordBatch stream is created in + /// RepartitionExec::execute(), but before any message is polled. + InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>), + /// The input streams are being consumed. The node transitions to this state when + /// the first message in the arrow's RecordBatch stream is consumed. 
+ ConsumingInputStreams(ConsumingInputStreamsState), +} + +impl Default for RepartitionExecState { + fn default() -> Self { + Self::NotInitialized + } +} + +impl Debug for RepartitionExecState { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + RepartitionExecState::NotInitialized => write!(f, "NotInitialized"), + RepartitionExecState::InputStreamsInitialized(v) => { + write!(f, "InputStreamsInitialized({:?})", v.len()) + } + RepartitionExecState::ConsumingInputStreams(v) => { + write!(f, "ConsumingInputStreams({v:?})") + } + } + } +} + impl RepartitionExecState { - fn new( + fn ensure_input_streams_initialized( + &mut self, input: Arc, - partitioning: Partitioning, metrics: ExecutionPlanMetricsSet, + output_partitions: usize, + ctx: Arc, + ) -> Result<()> { + if !matches!(self, RepartitionExecState::NotInitialized) { + return Ok(()); + } + + let num_input_partitions = input.output_partitioning().partition_count(); + let mut streams_and_metrics = Vec::with_capacity(num_input_partitions); + + for i in 0..num_input_partitions { + let metrics = RepartitionMetrics::new(i, output_partitions, &metrics); + + let timer = metrics.fetch_time.timer(); + let stream = input.execute(i, Arc::clone(&ctx))?; + timer.done(); + + streams_and_metrics.push((stream, metrics)); + } + *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics); + Ok(()) + } + + fn consume_input_streams( + &mut self, + input: Arc, + metrics: ExecutionPlanMetricsSet, + partitioning: Partitioning, preserve_order: bool, name: String, context: Arc, - ) -> Self { - let num_input_partitions = input.output_partitioning().partition_count(); + ) -> Result<&mut ConsumingInputStreamsState> { + let streams_and_metrics = match self { + RepartitionExecState::NotInitialized => { + self.ensure_input_streams_initialized( + input, + metrics, + partitioning.partition_count(), + Arc::clone(&context), + )?; + let RepartitionExecState::InputStreamsInitialized(value) = self else { + // This cannot happen, as ensure_input_streams_initialized() was just called, + // but the compiler does not know. 
+ return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized"); + }; + value + } + RepartitionExecState::ConsumingInputStreams(value) => return Ok(value), + RepartitionExecState::InputStreamsInitialized(value) => value, + }; + + let num_input_partitions = streams_and_metrics.len(); let num_output_partitions = partitioning.partition_count(); let (txs, rxs) = if preserve_order { @@ -117,7 +202,7 @@ impl RepartitionExecState { let mut channels = HashMap::with_capacity(txs.len()); for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() { let reservation = Arc::new(Mutex::new( - MemoryConsumer::new(format!("{}[{partition}]", name)) + MemoryConsumer::new(format!("{name}[{partition}]")) .register(context.memory_pool()), )); channels.insert(partition, (tx, rx, reservation)); @@ -125,7 +210,9 @@ impl RepartitionExecState { // launch one async task per *input* partition let mut spawned_tasks = Vec::with_capacity(num_input_partitions); - for i in 0..num_input_partitions { + for (i, (stream, metrics)) in + std::mem::take(streams_and_metrics).into_iter().enumerate() + { let txs: HashMap<_, _> = channels .iter() .map(|(partition, (tx, _rx, reservation))| { @@ -133,15 +220,11 @@ impl RepartitionExecState { }) .collect(); - let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics); - let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input( - Arc::clone(&input), - i, + stream, txs.clone(), partitioning.clone(), - r_metrics, - Arc::clone(&context), + metrics, )); // In a separate task, wait for each input to be done @@ -154,28 +237,17 @@ impl RepartitionExecState { )); spawned_tasks.push(wait_for_task); } - - Self { + *self = Self::ConsumingInputStreams(ConsumingInputStreamsState { channels, abort_helper: Arc::new(spawned_tasks), + }); + match self { + RepartitionExecState::ConsumingInputStreams(value) => Ok(value), + _ => unreachable!(), } } } -/// Lazily initialized state -/// -/// Note that the state is initialized ONCE for all partitions by a single task(thread). -/// This may take a short while. It is also like that multiple threads -/// call execute at the same time, because we have just started "target partitions" tasks -/// which is commonly set to the number of CPU cores and all call execute at the same time. -/// -/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles -/// in a mutex lock but instead allow other threads to do something useful. -/// -/// Uses a parking_lot `Mutex` to control other accesses as they are very short duration -/// (e.g. removing channels on completion) where the overhead of `await` is not warranted. -type LazyState = Arc>>; - /// A utility that can be used to partition batches based on [`Partitioning`] pub struct BatchPartitioner { state: BatchPartitionerState, @@ -402,8 +474,9 @@ impl BatchPartitioner { pub struct RepartitionExec { /// Input execution plan input: Arc, - /// Inner state that is initialized when the first output stream is created. - state: LazyState, + /// Inner state that is initialized when the parent calls .execute() on this node + /// and consumed as soon as the parent starts consuming this node. + state: Arc>, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Boolean flag to decide whether to preserve ordering. 
If true means @@ -482,11 +555,7 @@ impl RepartitionExec { } impl DisplayAs for RepartitionExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!( @@ -513,11 +582,10 @@ impl DisplayAs for RepartitionExec { self.input.output_partitioning().partition_count(); let output_partition_count = self.partitioning().partition_count(); let input_to_output_partition_str = - format!("{} -> {}", input_partition_count, output_partition_count); + format!("{input_partition_count} -> {output_partition_count}"); writeln!( f, - "partition_count(in->out)={}", - input_to_output_partition_str + "partition_count(in->out)={input_to_output_partition_str}" )?; if self.preserve_order { @@ -580,7 +648,6 @@ impl ExecutionPlan for RepartitionExec { partition ); - let lazy_state = Arc::clone(&self.state); let input = Arc::clone(&self.input); let partitioning = self.partitioning().clone(); let metrics = self.metrics.clone(); @@ -592,30 +659,31 @@ impl ExecutionPlan for RepartitionExec { // Get existing ordering to use for merging let sort_exprs = self.sort_exprs().cloned().unwrap_or_default(); + let state = Arc::clone(&self.state); + if let Some(mut state) = state.try_lock() { + state.ensure_input_streams_initialized( + Arc::clone(&input), + metrics.clone(), + partitioning.partition_count(), + Arc::clone(&context), + )?; + } + let stream = futures::stream::once(async move { let num_input_partitions = input.output_partitioning().partition_count(); - let input_captured = Arc::clone(&input); - let metrics_captured = metrics.clone(); - let name_captured = name.clone(); - let context_captured = Arc::clone(&context); - let state = lazy_state - .get_or_init(|| async move { - Mutex::new(RepartitionExecState::new( - input_captured, - partitioning, - metrics_captured, - preserve_order, - name_captured, - context_captured, - )) - }) - .await; - // lock scope let (mut rx, reservation, abort_helper) = { // lock mutexes let mut state = state.lock(); + let state = state.consume_input_streams( + Arc::clone(&input), + metrics.clone(), + partitioning, + preserve_order, + name.clone(), + Arc::clone(&context), + )?; // now return stream for the specified *output* partition which will // read from the channel @@ -628,9 +696,7 @@ impl ExecutionPlan for RepartitionExec { }; trace!( - "Before returning stream in {}::execute for partition: {}", - name, - partition + "Before returning stream in {name}::execute for partition: {partition}" ); if preserve_order { @@ -652,7 +718,7 @@ impl ExecutionPlan for RepartitionExec { // input partitions to this partition: let fetch = None; let merge_reservation = - MemoryConsumer::new(format!("{}[Merge {partition}]", name)) + MemoryConsumer::new(format!("{name}[Merge {partition}]")) .register(context.memory_pool()); StreamingMergeBuilder::new() .with_streams(input_streams) @@ -684,7 +750,15 @@ impl ExecutionPlan for RepartitionExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_none() { + self.input.partition_statistics(None) + } else { + Ok(Statistics::new_unknown(&self.schema())) + } } fn cardinality_effect(&self) -> CardinalityEffect { @@ -730,6 +804,25 @@ impl ExecutionPlan for RepartitionExec { new_partitioning, )?))) } + + fn gather_filters_for_pushdown( 
+ &self, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + Ok(FilterDescription::new_with_child_count(1) + .all_parent_filters_supported(parent_filters)) + } + + fn handle_child_pushdown_result( + &self, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::transparent( + child_pushdown_result, + )) + } } impl RepartitionExec { @@ -825,24 +918,17 @@ impl RepartitionExec { /// /// txs hold the output sending channels for each output partition async fn pull_from_input( - input: Arc, - partition: usize, + mut stream: SendableRecordBatchStream, mut output_channels: HashMap< usize, (DistributionSender, SharedMemoryReservation), >, partitioning: Partitioning, metrics: RepartitionMetrics, - context: Arc, ) -> Result<()> { let mut partitioner = BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?; - // execute the child operator - let timer = metrics.fetch_time.timer(); - let mut stream = input.execute(partition, context)?; - timer.done(); - // While there are still outputs to send to, keep pulling inputs let mut batches_until_yield = partitioner.num_partitions(); while !output_channels.is_empty() { @@ -1090,6 +1176,7 @@ mod tests { use datafusion_common_runtime::JoinSet; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use insta::assert_snapshot; + use itertools::Itertools; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1270,15 +1357,9 @@ mod tests { let partitioning = Partitioning::RoundRobinBatch(1); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - // Note: this should pass (the stream can be created) but the - // error when the input is executed should get passed back - let output_stream = exec.execute(0, task_ctx).unwrap(); - // Expect that an error is returned - let result_string = crate::common::collect(output_stream) - .await - .unwrap_err() - .to_string(); + let result_string = exec.execute(0, task_ctx).err().unwrap().to_string(); + assert!( result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"), "actual: {result_string}" @@ -1468,7 +1549,14 @@ mod tests { }); let batches_with_drop = crate::common::collect(output_stream1).await.unwrap(); - assert_eq!(batches_without_drop, batches_with_drop); + fn sort(batch: Vec) -> Vec { + batch + .into_iter() + .sorted_by_key(|b| format!("{b:?}")) + .collect() + } + + assert_eq!(sort(batches_without_drop), sort(batches_with_drop)); } fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> { diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 320fa21c8665..78f898a2d77a 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -296,10 +296,7 @@ impl ExecutionPlan for PartialSortExec { let input = self.input.execute(partition, Arc::clone(&context))?; - trace!( - "End PartialSortExec's input.execute for partition: {}", - partition - ); + trace!("End PartialSortExec's input.execute for partition: {partition}"); // Make sure common prefix length is larger than 0 // Otherwise, we should use SortExec. 
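The hunk below makes the same mechanical change that recurs throughout this patch (GlobalLimitExec, SortPreservingMergeExec, UnionExec, the test execs): `statistics()` becomes a thin wrapper over the new `partition_statistics(None)`, and a request for a specific partition degrades to unknown statistics wherever per-partition numbers are not tracked. A self-contained toy sketch of that contract, using stand-in types (`Stats`, `FixedStatsNode`) rather than the real `ExecutionPlan`/`Statistics` API:

    // Toy sketch of the convention only; `Stats` and `FixedStatsNode` are
    // stand-ins, not DataFusion types.
    #[derive(Clone, Debug, PartialEq)]
    struct Stats {
        num_rows: Option<usize>,
    }

    impl Stats {
        fn new_unknown() -> Self {
            Stats { num_rows: None }
        }
    }

    struct FixedStatsNode {
        per_partition: Vec<Stats>,
    }

    impl FixedStatsNode {
        // The old entry point becomes a thin wrapper over the partition-aware one.
        fn statistics(&self) -> Stats {
            self.partition_statistics(None)
        }

        // `None` asks for the statistics of all partitions combined; `Some(i)`
        // asks for one partition and degrades to "unknown" when it is not tracked.
        fn partition_statistics(&self, partition: Option<usize>) -> Stats {
            match partition {
                None => Stats {
                    // Unknown row counts propagate: if any partition is None,
                    // the combined count is None as well.
                    num_rows: self.per_partition.iter().map(|s| s.num_rows).sum(),
                },
                Some(i) => self
                    .per_partition
                    .get(i)
                    .cloned()
                    .unwrap_or_else(Stats::new_unknown),
            }
        }
    }

Falling back to unknown statistics for a specific partition is the conservative choice: a caller that asked for per-partition numbers an operator cannot provide gets "no estimate" rather than a whole-plan figure misattributed to one partition.
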
@@ -321,7 +318,11 @@ impl ExecutionPlan for PartialSortExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition) } } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 8c0c6a7e8ea9..683983d9e697 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -44,15 +44,10 @@ use crate::{ Statistics, }; -use arrow::array::{ - Array, RecordBatch, RecordBatchOptions, StringViewArray, UInt32Array, -}; -use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays, SortColumn}; -use arrow::datatypes::{DataType, SchemaRef}; -use arrow::row::{RowConverter, Rows, SortField}; -use datafusion_common::{ - exec_datafusion_err, internal_datafusion_err, internal_err, DataFusionError, Result, -}; +use arrow::array::{Array, RecordBatch, RecordBatchOptions, StringViewArray}; +use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays}; +use arrow::datatypes::SchemaRef; +use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; @@ -89,8 +84,9 @@ impl ExternalSorterMetrics { /// 1. get a non-empty new batch from input /// /// 2. check with the memory manager there is sufficient space to -/// buffer the batch in memory 2.1 if memory sufficient, buffer -/// batch in memory, go to 1. +/// buffer the batch in memory. +/// +/// 2.1 if memory is sufficient, buffer batch in memory, go to 1. /// /// 2.2 if no more memory is available, sort all buffered batches and /// spill to file. buffer the next batch in memory, go to 1. 
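Steps 1 through 2.2 above describe a simple buffer-or-spill loop. The sketch below is a self-contained simplification of that loop only (it stops before the final merge phase and fakes spilling by counting spill files); `MemoryBudget` and `buffer_or_spill` are illustrative names, not part of `ExternalSorter`:

    // Simplified stand-in for the memory manager interaction in step 2.
    struct MemoryBudget {
        used: usize,
        limit: usize,
    }

    impl MemoryBudget {
        fn try_grow(&mut self, bytes: usize) -> bool {
            if self.used + bytes <= self.limit {
                self.used += bytes;
                true
            } else {
                false
            }
        }
    }

    // Returns (batches still buffered in memory, number of simulated spill files).
    fn buffer_or_spill(
        batches: impl IntoIterator<Item = Vec<i64>>,
        budget: &mut MemoryBudget,
    ) -> (Vec<Vec<i64>>, usize) {
        let mut in_memory: Vec<Vec<i64>> = Vec::new();
        let mut spill_files = 0;
        for batch in batches {
            let bytes = batch.len() * std::mem::size_of::<i64>();
            // 2.1: the batch fits, buffer it and go back to step 1.
            if budget.try_grow(bytes) {
                in_memory.push(batch);
                continue;
            }
            // 2.2: no memory left, sort everything buffered and "spill" it ...
            let mut sorted: Vec<i64> = in_memory.drain(..).flatten().collect();
            sorted.sort_unstable();
            spill_files += 1; // stand-in for writing `sorted` to a temp file
            budget.used = 0;
            // ... then buffer the incoming batch and go back to step 1
            // (a real implementation would error if one batch exceeds the limit).
            let _ = budget.try_grow(bytes);
            in_memory.push(batch);
        }
        (in_memory, spill_files)
    }
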
@@ -205,8 +201,6 @@ struct ExternalSorter { schema: SchemaRef, /// Sort expressions expr: Arc<[PhysicalSortExpr]>, - /// RowConverter corresponding to the sort expressions - sort_keys_row_converter: Arc, /// The target number of rows for output batches batch_size: usize, /// If the in size of buffered memory batches is below this size, @@ -274,22 +268,6 @@ impl ExternalSorter { MemoryConsumer::new(format!("ExternalSorterMerge[{partition_id}]")) .register(&runtime.memory_pool); - // Construct RowConverter for sort keys - let sort_fields = expr - .iter() - .map(|e| { - let data_type = e - .expr - .data_type(&schema) - .map_err(|e| e.context("Resolving sort expression data type"))?; - Ok(SortField::new_with_options(data_type, e.options)) - }) - .collect::>>()?; - - let converter = RowConverter::new(sort_fields).map_err(|e| { - exec_datafusion_err!("Failed to create RowConverter: {:?}", e) - })?; - let spill_manager = SpillManager::new( Arc::clone(&runtime), metrics.spill_metrics.clone(), @@ -302,7 +280,6 @@ impl ExternalSorter { in_progress_spill_file: None, finished_spill_files: vec![], expr: expr.into(), - sort_keys_row_converter: Arc::new(converter), metrics, reservation, spill_manager, @@ -727,22 +704,10 @@ impl ExternalSorter { let schema = batch.schema(); let expressions: LexOrdering = self.expr.iter().cloned().collect(); - let row_converter = Arc::clone(&self.sort_keys_row_converter); let stream = futures::stream::once(async move { let _timer = metrics.elapsed_compute().timer(); - let sort_columns = expressions - .iter() - .map(|expr| expr.evaluate_to_sort_column(&batch)) - .collect::>>()?; - - let sorted = if is_multi_column_with_lists(&sort_columns) { - // lex_sort_to_indices doesn't support List with more than one column - // https://github.com/apache/arrow-rs/issues/5454 - sort_batch_row_based(&batch, &expressions, row_converter, None)? - } else { - sort_batch(&batch, &expressions, None)? - }; + let sorted = sort_batch(&batch, &expressions, None)?; metrics.record_output(sorted.num_rows()); drop(batch); @@ -833,45 +798,6 @@ impl Debug for ExternalSorter { } } -/// Converts rows into a sorted array of indices based on their order. -/// This function returns the indices that represent the sorted order of the rows. -fn rows_to_indices(rows: Rows, limit: Option) -> Result { - let mut sort: Vec<_> = rows.iter().enumerate().collect(); - sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); - - let mut len = rows.num_rows(); - if let Some(limit) = limit { - len = limit.min(len); - } - let indices = - UInt32Array::from_iter_values(sort.iter().take(len).map(|(i, _)| *i as u32)); - Ok(indices) -} - -/// Sorts a `RecordBatch` by converting its sort columns into Arrow Row Format for faster comparison. -fn sort_batch_row_based( - batch: &RecordBatch, - expressions: &LexOrdering, - row_converter: Arc, - fetch: Option, -) -> Result { - let sort_columns = expressions - .iter() - .map(|expr| expr.evaluate_to_sort_column(batch).map(|col| col.values)) - .collect::>>()?; - let rows = row_converter.convert_columns(&sort_columns)?; - let indices = rows_to_indices(rows, fetch)?; - let columns = take_arrays(batch.columns(), &indices, None)?; - - let options = RecordBatchOptions::new().with_row_count(Some(indices.len())); - - Ok(RecordBatch::try_new_with_options( - batch.schema(), - columns, - &options, - )?) 
-} - pub fn sort_batch( batch: &RecordBatch, expressions: &LexOrdering, @@ -882,14 +808,7 @@ pub fn sort_batch( .map(|expr| expr.evaluate_to_sort_column(batch)) .collect::>>()?; - let indices = if is_multi_column_with_lists(&sort_columns) { - // lex_sort_to_indices doesn't support List with more than one column - // https://github.com/apache/arrow-rs/issues/5454 - lexsort_to_indices_multi_columns(sort_columns, fetch)? - } else { - lexsort_to_indices(&sort_columns, fetch)? - }; - + let indices = lexsort_to_indices(&sort_columns, fetch)?; let mut columns = take_arrays(batch.columns(), &indices, None)?; // The columns may be larger than the unsorted columns in `batch` especially for variable length @@ -908,50 +827,6 @@ pub fn sort_batch( )?) } -#[inline] -fn is_multi_column_with_lists(sort_columns: &[SortColumn]) -> bool { - sort_columns.iter().any(|c| { - matches!( - c.values.data_type(), - DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) - ) - }) -} - -pub(crate) fn lexsort_to_indices_multi_columns( - sort_columns: Vec, - limit: Option, -) -> Result { - let (fields, columns) = sort_columns.into_iter().fold( - (vec![], vec![]), - |(mut fields, mut columns), sort_column| { - fields.push(SortField::new_with_options( - sort_column.values.data_type().clone(), - sort_column.options.unwrap_or_default(), - )); - columns.push(sort_column.values); - (fields, columns) - }, - ); - - // Note: row converter is reused through `sort_batch_row_based()`, this function - // is not used during normal sort execution, but it's kept temporarily because - // it's inside a public interface `sort_batch()`. - let converter = RowConverter::new(fields)?; - let rows = converter.convert_columns(&columns)?; - let mut sort: Vec<_> = rows.iter().enumerate().collect(); - sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); - - let mut len = rows.num_rows(); - if let Some(limit) = limit { - len = limit.min(len); - } - let indices = - UInt32Array::from_iter_values(sort.iter().take(len).map(|(i, _)| *i as u32)); - - Ok(indices) -} - /// Sort execution plan. /// /// Support sorting datasets that are larger than the memory allotted @@ -1222,7 +1097,7 @@ impl ExecutionPlan for SortExec { let execution_options = &context.session_config().options().execution; - trace!("End SortExec's input.execute for partition: {}", partition); + trace!("End SortExec's input.execute for partition: {partition}"); let requirement = &LexRequirement::from(self.expr.clone()); @@ -1296,7 +1171,24 @@ impl ExecutionPlan for SortExec { } fn statistics(&self) -> Result { - Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if !self.preserve_partitioning() { + return self.input.partition_statistics(None)?.with_fetch( + self.schema(), + self.fetch, + 0, + 1, + ); + } + self.input.partition_statistics(partition)?.with_fetch( + self.schema(), + self.fetch, + 0, + 1, + ) } fn with_fetch(&self, limit: Option) -> Option> { @@ -1642,15 +1534,13 @@ mod tests { let err = result.unwrap_err(); assert!( matches!(err, DataFusionError::Context(..)), - "Assertion failed: expected a Context error, but got: {:?}", - err + "Assertion failed: expected a Context error, but got: {err:?}" ); // Assert that the context error is wrapping a resources exhausted error. 
assert!( matches!(err.find_root(), DataFusionError::ResourcesExhausted(_)), - "Assertion failed: expected a ResourcesExhausted error, but got: {:?}", - err + "Assertion failed: expected a ResourcesExhausted error, but got: {err:?}" ); Ok(()) diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index b987dff36441..6930473360f0 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -267,10 +267,7 @@ impl ExecutionPlan for SortPreservingMergeExec { partition: usize, context: Arc, ) -> Result { - trace!( - "Start SortPreservingMergeExec::execute for partition: {}", - partition - ); + trace!("Start SortPreservingMergeExec::execute for partition: {partition}"); if 0 != partition { return internal_err!( "SortPreservingMergeExec invalid partition {partition}" @@ -279,8 +276,7 @@ impl ExecutionPlan for SortPreservingMergeExec { let input_partitions = self.input.output_partitioning().partition_count(); trace!( - "Number of input partitions of SortPreservingMergeExec::execute: {}", - input_partitions + "Number of input partitions of SortPreservingMergeExec::execute: {input_partitions}" ); let schema = self.schema(); @@ -343,7 +339,11 @@ impl ExecutionPlan for SortPreservingMergeExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, _partition: Option) -> Result { + self.input.partition_statistics(None) } fn supports_limit_pushdown(&self) -> bool { diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 18c472a7e187..6274995d04da 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -302,7 +302,7 @@ impl ExecutionPlan for StreamingTableExec { let new_projections = new_projections_for_columns( projection, &streaming_table_projections - .unwrap_or((0..self.schema().fields().len()).collect()), + .unwrap_or_else(|| (0..self.schema().fields().len()).collect()), ); let mut lex_orderings = vec![]; diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index a2dc1d778436..4d5244e0e1d4 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -87,9 +87,7 @@ impl DisplayAs for TestMemoryExec { let output_ordering = self .sort_information .first() - .map(|output_ordering| { - format!(", output_ordering={}", output_ordering) - }) + .map(|output_ordering| format!(", output_ordering={output_ordering}")) .unwrap_or_default(); let eq_properties = self.eq_properties(); @@ -97,12 +95,12 @@ impl DisplayAs for TestMemoryExec { let constraints = if constraints.is_empty() { String::new() } else { - format!(", {}", constraints) + format!(", {constraints}") }; let limit = self .fetch - .map_or(String::new(), |limit| format!(", fetch={}", limit)); + .map_or(String::new(), |limit| format!(", fetch={limit}")); if self.show_sizes { write!( f, @@ -170,7 +168,15 @@ impl ExecutionPlan for TestMemoryExec { } fn statistics(&self) -> Result { - self.statistics() + self.statistics_inner() + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + Ok(Statistics::new_unknown(&self.schema)) + } else { + self.statistics_inner() + } } fn fetch(&self) -> Option { @@ -214,7 +220,7 @@ impl TestMemoryExec { ) } - fn statistics(&self) -> Result { + fn statistics_inner(&self) -> 
Result { Ok(common::compute_record_batch_statistics( &self.partitions, &self.schema, @@ -450,7 +456,7 @@ pub fn make_partition_utf8(sz: i32) -> RecordBatch { let seq_start = 0; let seq_end = sz; let values = (seq_start..seq_end) - .map(|i| format!("test_long_string_that_is_roughly_42_bytes_{}", i)) + .map(|i| format!("test_long_string_that_is_roughly_42_bytes_{i}")) .collect::>(); let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Utf8, true)])); let mut string_array = arrow::array::StringArray::from(values); diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index d0a0d25779cc..12ffca871f07 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -255,6 +255,13 @@ impl ExecutionPlan for MockExec { // Panics if one of the batches is an error fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema)); + } let data: Result> = self .data .iter() @@ -405,6 +412,13 @@ impl ExecutionPlan for BarrierExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema)); + } Ok(common::compute_record_batch_statistics( &self.data, &self.schema, @@ -590,6 +604,14 @@ impl ExecutionPlan for StatisticsExec { fn statistics(&self) -> Result { Ok(self.stats.clone()) } + + fn partition_statistics(&self, partition: Option) -> Result { + Ok(if partition.is_some() { + Statistics::new_unknown(&self.schema) + } else { + self.stats.clone() + }) + } } /// Execution plan that emits streams that block forever. 
diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index 69b0a165315e..78ba984ed1a5 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -94,7 +94,7 @@ impl PlanContext { impl Display for PlanContext { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let node_string = displayable(self.plan.as_ref()).one_line(); - write!(f, "Node plan: {}", node_string)?; + write!(f, "Node plan: {node_string}")?; write!(f, "Node data: {}", self.data)?; write!(f, "") } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 2b666093f29e..930fe793d1d4 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -248,7 +248,7 @@ impl ExecutionPlan for UnionExec { } } - warn!("Error in Union: Partition {} not found", partition); + warn!("Error in Union: Partition {partition} not found"); exec_err!("Partition {partition} not found in Union") } @@ -258,16 +258,36 @@ impl ExecutionPlan for UnionExec { } fn statistics(&self) -> Result { - let stats = self - .inputs - .iter() - .map(|stat| stat.statistics()) - .collect::>>()?; + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(partition_idx) = partition { + // For a specific partition, find which input it belongs to + let mut remaining_idx = partition_idx; + for input in &self.inputs { + let input_partition_count = input.output_partitioning().partition_count(); + if remaining_idx < input_partition_count { + // This partition belongs to this input + return input.partition_statistics(Some(remaining_idx)); + } + remaining_idx -= input_partition_count; + } + // If we get here, the partition index is out of bounds + Ok(Statistics::new_unknown(&self.schema())) + } else { + // Collect statistics from all inputs + let stats = self + .inputs + .iter() + .map(|input_exec| input_exec.partition_statistics(None)) + .collect::>>()?; - Ok(stats - .into_iter() - .reduce(stats_union) - .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) + Ok(stats + .into_iter() + .reduce(stats_union) + .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) + } } fn benefits_from_input_partitioning(&self) -> Vec { @@ -461,7 +481,7 @@ impl ExecutionPlan for InterleaveExec { ))); } - warn!("Error in InterleaveExec: Partition {} not found", partition); + warn!("Error in InterleaveExec: Partition {partition} not found"); exec_err!("Partition {partition} not found in InterleaveExec") } @@ -471,10 +491,17 @@ impl ExecutionPlan for InterleaveExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } let stats = self .inputs .iter() - .map(|stat| stat.statistics()) + .map(|stat| stat.partition_statistics(None)) .collect::>>()?; Ok(stats @@ -513,7 +540,12 @@ fn union_schema(inputs: &[Arc]) -> SchemaRef { let fields = (0..first_schema.fields().len()) .map(|i| { - inputs + // We take the name from the left side of the union to match how names are coerced during logical planning, + // which also uses the left side names. 
+ let base_field = first_schema.field(i).clone(); + + // Coerce metadata and nullability across all inputs + let merged_field = inputs .iter() .enumerate() .map(|(input_idx, input)| { @@ -535,6 +567,9 @@ fn union_schema(inputs: &[Arc]) -> SchemaRef { // We can unwrap this because if inputs was empty, this would've already panic'ed when we // indexed into inputs[0]. .unwrap() + .with_name(base_field.name()); + + merged_field }) .collect::>(); @@ -897,7 +932,7 @@ mod tests { // Check whether orderings are same. let lhs_orderings = lhs.oeq_class(); let rhs_orderings = rhs.oeq_class(); - assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg); + assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{err_msg}"); for rhs_ordering in rhs_orderings.iter() { assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg); } diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index 6cb64bcb5d86..fb27ccf30179 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -308,8 +308,10 @@ mod tests { data, )?; + #[allow(deprecated)] + let stats = values.statistics()?; assert_eq!( - values.statistics()?, + stats, Statistics { num_rows: Precision::Exact(rows), total_byte_size: Precision::Exact(8), // not important diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 92138bf6a7a1..6751f9b20240 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -226,6 +226,23 @@ impl BoundedWindowAggExec { .unwrap_or_else(Vec::new) } } + + fn statistics_helper(&self, statistics: Statistics) -> Result { + let win_cols = self.window_expr.len(); + let input_cols = self.input.schema().fields().len(); + // TODO stats: some windowing function will maintain invariants such as min, max... + let mut column_statistics = Vec::with_capacity(win_cols + input_cols); + // copy stats of the input to the beginning of the schema. + column_statistics.extend(statistics.column_statistics); + for _ in 0..win_cols { + column_statistics.push(ColumnStatistics::new_unknown()) + } + Ok(Statistics { + num_rows: statistics.num_rows, + column_statistics, + total_byte_size: Precision::Absent, + }) + } } impl DisplayAs for BoundedWindowAggExec { @@ -261,7 +278,7 @@ impl DisplayAs for BoundedWindowAggExec { writeln!(f, "select_list={}", g.join(", "))?; let mode = &self.input_order_mode; - writeln!(f, "mode={:?}", mode)?; + writeln!(f, "mode={mode:?}")?; } } Ok(()) @@ -343,21 +360,12 @@ impl ExecutionPlan for BoundedWindowAggExec { } fn statistics(&self) -> Result { - let input_stat = self.input.statistics()?; - let win_cols = self.window_expr.len(); - let input_cols = self.input.schema().fields().len(); - // TODO stats: some windowing function will maintain invariants such as min, max... - let mut column_statistics = Vec::with_capacity(win_cols + input_cols); - // copy stats of the input to the beginning of the schema. 
- column_statistics.extend(input_stat.column_statistics); - for _ in 0..win_cols { - column_statistics.push(ColumnStatistics::new_unknown()) - } - Ok(Statistics { - num_rows: input_stat.num_rows, - column_statistics, - total_byte_size: Precision::Absent, - }) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stat = self.input.partition_statistics(partition)?; + self.statistics_helper(input_stat) } } @@ -1349,8 +1357,7 @@ mod tests { WindowFrameBound::Following(ScalarValue::UInt64(Some(n_future_range as u64))), ); let fn_name = format!( - "{}({:?}) PARTITION BY: [{:?}], ORDER BY: [{:?}]", - window_fn, args, partitionby_exprs, orderby_exprs + "{window_fn}({args:?}) PARTITION BY: [{partitionby_exprs:?}], ORDER BY: [{orderby_exprs:?}]" ); let input_order_mode = InputOrderMode::Linear; Ok(Arc::new(BoundedWindowAggExec::try_new( diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index d38bf2a186a8..d2b7e0a49e95 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -30,8 +30,8 @@ use crate::{ InputOrderMode, PhysicalExpr, }; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow_schema::SortOptions; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ PartitionEvaluator, ReversedUDWF, SetMonotonicity, WindowFrame, @@ -65,16 +65,16 @@ pub fn schema_add_window_field( window_fn: &WindowFunctionDefinition, fn_name: &str, ) -> Result> { - let data_types = args + let fields = args .iter() - .map(|e| Arc::clone(e).as_ref().data_type(schema)) + .map(|e| Arc::clone(e).as_ref().return_field(schema)) .collect::>>()?; let nullability = args .iter() .map(|e| Arc::clone(e).as_ref().nullable(schema)) .collect::>>()?; - let window_expr_return_type = - window_fn.return_type(&data_types, &nullability, fn_name)?; + let window_expr_return_field = + window_fn.return_field(&fields, &nullability, fn_name)?; let mut window_fields = schema .fields() .iter() @@ -84,11 +84,10 @@ pub fn schema_add_window_field( if let WindowFunctionDefinition::AggregateUDF(_) = window_fn { Ok(Arc::new(Schema::new(window_fields))) } else { - window_fields.extend_from_slice(&[Field::new( - fn_name, - window_expr_return_type, - false, - )]); + window_fields.extend_from_slice(&[window_expr_return_field + .as_ref() + .clone() + .with_name(fn_name)]); Ok(Arc::new(Schema::new(window_fields))) } } @@ -165,15 +164,15 @@ pub fn create_udwf_window_expr( ignore_nulls: bool, ) -> Result> { // need to get the types into an owned vec for some reason - let input_types: Vec<_> = args + let input_fields: Vec<_> = args .iter() - .map(|arg| arg.data_type(input_schema)) + .map(|arg| arg.return_field(input_schema)) .collect::>()?; let udwf_expr = Arc::new(WindowUDFExpr { fun: Arc::clone(fun), args: args.to_vec(), - input_types, + input_fields, name, is_reversed: false, ignore_nulls, @@ -202,8 +201,8 @@ pub struct WindowUDFExpr { args: Vec>, /// Display name name: String, - /// Types of input expressions - input_types: Vec, + /// Fields of input expressions + input_fields: Vec, /// This is set to `true` only if the user-defined window function /// expression supports evaluation in reverse order, and the /// evaluation order is reversed. 
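The windows/mod.rs hunks around this point replace `DataType`-based plumbing with `FieldRef`s: `schema_add_window_field` now clones the window expression's return field and renames it, and `WindowUDFExpr` carries `input_fields` instead of `input_types`, so nullability and field metadata survive planning. A rough, self-contained sketch of the renaming step, assuming only the `arrow` crate; `append_window_field` is a hypothetical helper, not the function in this patch:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, FieldRef, Schema};

// Append the window expression's output column to the input schema, keeping the
// return field's type, nullability, and metadata and changing only its name.
fn append_window_field(input: &Schema, window_return: &FieldRef, fn_name: &str) -> Schema {
    let mut fields: Vec<FieldRef> = input.fields().iter().cloned().collect();
    fields.push(Arc::new(window_return.as_ref().clone().with_name(fn_name)));
    Schema::new(fields)
}

fn main() {
    let input = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
    let ret: FieldRef = Arc::new(Field::new("out", DataType::UInt64, true));
    let with_window = append_window_field(&input, &ret, "my_udwf(a)");
    assert_eq!(with_window.fields().len(), 2);
    assert_eq!(with_window.field(1).name(), "my_udwf(a)");
    assert!(with_window.field(1).is_nullable());
}
```

As the hunk above shows, the aggregate-UDF branch of the real function returns the schema unchanged; only the non-aggregate case appends the renamed return field.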
@@ -223,21 +222,21 @@ impl StandardWindowFunctionExpr for WindowUDFExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { self.fun - .field(WindowUDFFieldArgs::new(&self.input_types, &self.name)) + .field(WindowUDFFieldArgs::new(&self.input_fields, &self.name)) } fn expressions(&self) -> Vec> { self.fun - .expressions(ExpressionArgs::new(&self.args, &self.input_types)) + .expressions(ExpressionArgs::new(&self.args, &self.input_fields)) } fn create_evaluator(&self) -> Result> { self.fun .partition_evaluator_factory(PartitionEvaluatorArgs::new( &self.args, - &self.input_types, + &self.input_fields, self.is_reversed, self.ignore_nulls, )) @@ -255,7 +254,7 @@ impl StandardWindowFunctionExpr for WindowUDFExpr { fun, args: self.args.clone(), name: self.name.clone(), - input_types: self.input_types.clone(), + input_fields: self.input_fields.clone(), is_reversed: !self.is_reversed, ignore_nulls: self.ignore_nulls, })), @@ -641,6 +640,7 @@ mod tests { use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use arrow::compute::SortOptions; + use arrow_schema::{DataType, Field}; use datafusion_execution::TaskContext; use datafusion_functions_aggregate::count::count_udaf; diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index 3c42d3032ed5..4c76e2230875 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -156,6 +156,24 @@ impl WindowAggExec { .unwrap_or_else(Vec::new) } } + + fn statistics_inner(&self) -> Result { + let input_stat = self.input.partition_statistics(None)?; + let win_cols = self.window_expr.len(); + let input_cols = self.input.schema().fields().len(); + // TODO stats: some windowing function will maintain invariants such as min, max... + let mut column_statistics = Vec::with_capacity(win_cols + input_cols); + // copy stats of the input to the beginning of the schema. + column_statistics.extend(input_stat.column_statistics); + for _ in 0..win_cols { + column_statistics.push(ColumnStatistics::new_unknown()) + } + Ok(Statistics { + num_rows: input_stat.num_rows, + column_statistics, + total_byte_size: Precision::Absent, + }) + } } impl DisplayAs for WindowAggExec { @@ -271,21 +289,15 @@ impl ExecutionPlan for WindowAggExec { } fn statistics(&self) -> Result { - let input_stat = self.input.statistics()?; - let win_cols = self.window_expr.len(); - let input_cols = self.input.schema().fields().len(); - // TODO stats: some windowing function will maintain invariants such as min, max... - let mut column_statistics = Vec::with_capacity(win_cols + input_cols); - // copy stats of the input to the beginning of the schema. 
- column_statistics.extend(input_stat.column_statistics); - for _ in 0..win_cols { - column_statistics.push(ColumnStatistics::new_unknown()) + self.statistics_inner() + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_none() { + self.statistics_inner() + } else { + Ok(Statistics::new_unknown(&self.schema())) } - Ok(Statistics { - num_rows: input_stat.num_rows, - column_statistics, - total_byte_size: Precision::Absent, - }) } } diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index 126a7d0bba29..eea1b9958633 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -227,6 +227,10 @@ impl ExecutionPlan for WorkTableExec { fn statistics(&self) -> Result { Ok(Statistics::new_unknown(&self.schema())) } + + fn partition_statistics(&self, _partition: Option) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } } #[cfg(test)] diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 82f1e91d9c9b..35f41155fa05 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -108,7 +108,6 @@ message Field { // for complex data types like structs, unions repeated Field children = 4; map metadata = 5; - bool dict_ordered = 6; } message Timestamp{ diff --git a/datafusion/proto-common/src/common.rs b/datafusion/proto-common/src/common.rs index 61711dcf8e08..9af63e3b0736 100644 --- a/datafusion/proto-common/src/common.rs +++ b/datafusion/proto-common/src/common.rs @@ -17,6 +17,7 @@ use datafusion_common::{internal_datafusion_err, DataFusionError}; +/// Return a `DataFusionError::Internal` with the given message pub fn proto_error>(message: S) -> DataFusionError { internal_datafusion_err!("{}", message.into()) } diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index b44b05e9ca29..1ac35742c73a 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -3107,9 +3107,6 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { len += 1; } - if self.dict_ordered { - len += 1; - } let mut struct_ser = serializer.serialize_struct("datafusion_common.Field", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; @@ -3126,9 +3123,6 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { struct_ser.serialize_field("metadata", &self.metadata)?; } - if self.dict_ordered { - struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; - } struct_ser.end() } } @@ -3145,8 +3139,6 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable", "children", "metadata", - "dict_ordered", - "dictOrdered", ]; #[allow(clippy::enum_variant_names)] @@ -3156,7 +3148,6 @@ impl<'de> serde::Deserialize<'de> for Field { Nullable, Children, Metadata, - DictOrdered, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3183,7 +3174,6 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable" => Ok(GeneratedField::Nullable), "children" => Ok(GeneratedField::Children), "metadata" => Ok(GeneratedField::Metadata), - "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3208,7 +3198,6 @@ impl<'de> serde::Deserialize<'de> for Field { let mut nullable__ = 
None; let mut children__ = None; let mut metadata__ = None; - let mut dict_ordered__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -3243,12 +3232,6 @@ impl<'de> serde::Deserialize<'de> for Field { map_.next_value::>()? ); } - GeneratedField::DictOrdered => { - if dict_ordered__.is_some() { - return Err(serde::de::Error::duplicate_field("dictOrdered")); - } - dict_ordered__ = Some(map_.next_value()?); - } } } Ok(Field { @@ -3257,7 +3240,6 @@ impl<'de> serde::Deserialize<'de> for Field { nullable: nullable__.unwrap_or_default(), children: children__.unwrap_or_default(), metadata: metadata__.unwrap_or_default(), - dict_ordered: dict_ordered__.unwrap_or_default(), }) } } diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index e029327d481d..a55714f190c5 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -106,8 +106,6 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, - #[prost(bool, tag = "6")] - pub dict_ordered: bool, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct Timestamp { diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 28927cad03b4..b6cbe5759cfc 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -97,7 +97,6 @@ impl TryFrom<&Field> for protobuf::Field { nullable: field.is_nullable(), children: Vec::new(), metadata: field.metadata().clone(), - dict_ordered: field.dict_is_ordered().unwrap_or(false), }) } } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 92e697ad2d9c..a1eeabdf87f4 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -60,5 +60,4 @@ datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window-common = { workspace = true } doc-comment = { workspace = true } -strum = { version = "0.27.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index f51e4664d5d9..f8930779db89 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -19,11 +19,11 @@ # `datafusion-proto`: Apache DataFusion Protobuf Serialization / Deserialization -This crate contains code to convert Apache [DataFusion] plans to and from +This crate contains code to convert [Apache DataFusion] plans to and from bytes, which can be useful for sending plans over the network, for example when building a distributed query engine. See [API Docs] for details and examples. 
-[datafusion]: https://datafusion.apache.org +[apache datafusion]: https://datafusion.apache.org [api docs]: http://docs.rs/datafusion-proto/latest diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 908b95ab56a4..4c8b6c588d94 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -90,7 +90,7 @@ message ListingTableScanNode { ProjectionColumns projection = 4; datafusion_common.Schema schema = 5; repeated LogicalExprNode filters = 6; - repeated string table_partition_cols = 7; + repeated PartitionColumn table_partition_cols = 7; bool collect_stat = 8; uint32 target_partitions = 9; oneof FileFormatType { @@ -1217,6 +1217,7 @@ message CoalesceBatchesExecNode { message CoalescePartitionsExecNode { PhysicalPlanNode input = 1; + optional uint32 fetch = 2; } message PhysicalHashRepartition { diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index e029327d481d..a55714f190c5 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -106,8 +106,6 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, - #[prost(bool, tag = "6")] - pub dict_ordered: bool, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct Timestamp { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 6166b6ec4796..932422944508 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -2050,10 +2050,16 @@ impl serde::Serialize for CoalescePartitionsExecNode { if self.input.is_some() { len += 1; } + if self.fetch.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CoalescePartitionsExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; } + if let Some(v) = self.fetch.as_ref() { + struct_ser.serialize_field("fetch", v)?; + } struct_ser.end() } } @@ -2065,11 +2071,13 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { { const FIELDS: &[&str] = &[ "input", + "fetch", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2092,6 +2100,7 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { { match value { "input" => Ok(GeneratedField::Input), + "fetch" => Ok(GeneratedField::Fetch), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2112,6 +2121,7 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { V: serde::de::MapAccess<'de>, { let mut input__ = None; + let mut fetch__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Input => { @@ -2120,10 +2130,19 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { } input__ = map_.next_value()?; } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; + } } } Ok(CoalescePartitionsExecNode { input: input__, + fetch: fetch__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d2165dad4850..c2f4e93cef6a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -115,8 +115,8 @@ pub struct ListingTableScanNode { pub schema: ::core::option::Option, #[prost(message, repeated, tag = "6")] pub filters: ::prost::alloc::vec::Vec, - #[prost(string, repeated, tag = "7")] - pub table_partition_cols: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(message, repeated, tag = "7")] + pub table_partition_cols: ::prost::alloc::vec::Vec, #[prost(bool, tag = "8")] pub collect_stat: bool, #[prost(uint32, tag = "9")] @@ -1824,6 +1824,8 @@ pub struct CoalesceBatchesExecNode { pub struct CoalescePartitionsExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(uint32, optional, tag = "2")] + pub fetch: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalHashRepartition { diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index 5c33277dc9f7..d3f6511ec98f 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -205,10 +205,7 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { _ctx: &SessionContext, ) -> datafusion_common::Result> { let proto = CsvOptionsProto::decode(buf).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to decode CsvOptionsProto: {:?}", - e - )) + DataFusionError::Execution(format!("Failed to decode CsvOptionsProto: {e:?}")) })?; let options: CsvOptions = (&proto).into(); Ok(Arc::new(CsvFormatFactory { @@ -233,7 +230,7 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { }); proto.encode(buf).map_err(|e| { - DataFusionError::Execution(format!("Failed to encode CsvOptions: {:?}", e)) + DataFusionError::Execution(format!("Failed to encode CsvOptions: {e:?}")) })?; Ok(()) @@ -316,8 +313,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { ) -> datafusion_common::Result> { let proto = JsonOptionsProto::decode(buf).map_err(|e| { DataFusionError::Execution(format!( - "Failed to decode JsonOptionsProto: {:?}", - e + "Failed to decode JsonOptionsProto: {e:?}" )) })?; let options: JsonOptions = (&proto).into(); @@ -346,7 +342,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { }); proto.encode(buf).map_err(|e| { - DataFusionError::Execution(format!("Failed to encode JsonOptions: {:?}", e)) + DataFusionError::Execution(format!("Failed to encode JsonOptions: {e:?}")) })?; Ok(()) @@ -632,8 +628,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { ) -> datafusion_common::Result> { let proto = TableParquetOptionsProto::decode(buf).map_err(|e| { DataFusionError::Execution(format!( - "Failed to decode TableParquetOptionsProto: {:?}", - e + "Failed to decode TableParquetOptionsProto: {e:?}" )) })?; let options: TableParquetOptions = (&proto).into(); @@ 
-663,8 +658,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { proto.encode(buf).map_err(|e| { DataFusionError::Execution(format!( - "Failed to encode TableParquetOptionsProto: {:?}", - e + "Failed to encode TableParquetOptionsProto: {e:?}" )) })?; diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index cac2f9db1645..1b5527c14a49 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -268,7 +268,7 @@ pub fn parse_expr( ExprType::Column(column) => Ok(Expr::Column(column.into())), ExprType::Literal(literal) => { let scalar_value: ScalarValue = literal.try_into()?; - Ok(Expr::Literal(scalar_value)) + Ok(Expr::Literal(scalar_value, None)) } ExprType::WindowExpr(expr) => { let window_function = expr @@ -296,11 +296,13 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => registry.udaf(udaf_name)?, + None => registry + .udaf(udaf_name) + .or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, )) @@ -313,11 +315,13 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => registry.udwf(udwf_name)?, + None => registry + .udwf(udwf_name) + .or_else(|_| codec.try_decode_udwf(udwf_name, &[]))?, }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, )) @@ -540,7 +544,9 @@ pub fn parse_expr( }) => { let scalar_fn = match fun_definition { Some(buf) => codec.try_decode_udf(fun_name, buf)?, - None => registry.udf(fun_name.as_str())?, + None => registry + .udf(fun_name.as_str()) + .or_else(|_| codec.try_decode_udf(fun_name, &[]))?, }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, @@ -550,7 +556,9 @@ pub fn parse_expr( ExprType::AggregateUdfExpr(pb) => { let agg_fn = match &pb.fun_definition { Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?, - None => registry.udaf(&pb.fun_name)?, + None => registry + .udaf(&pb.fun_name) + .or_else(|_| codec.try_decode_udaf(&pb.fun_name, &[]))?, }; Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index c65569ef1cfb..d934b24dc341 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -33,7 +33,7 @@ use crate::{ }; use crate::protobuf::{proto_error, ToProtoError}; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Schema, SchemaBuilder, SchemaRef}; use datafusion::datasource::cte_worktable::CteWorkTable; #[cfg(feature = "avro")] use datafusion::datasource::file_format::avro::AvroFormat; @@ -355,10 +355,7 @@ impl AsLogicalPlan for LogicalPlanNode { .as_ref() .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .transpose()? 
- .ok_or_else(|| { - DataFusionError::Internal("expression required".to_string()) - })?; - // .try_into()?; + .ok_or_else(|| proto_error("expression required"))?; LogicalPlanBuilder::from(input).filter(expr)?.build() } LogicalPlanType::Window(window) => { @@ -458,23 +455,25 @@ impl AsLogicalPlan for LogicalPlanNode { .map(ListingTableUrl::parse) .collect::, _>>()?; + let partition_columns = scan + .table_partition_cols + .iter() + .map(|col| { + let Some(arrow_type) = col.arrow_type.as_ref() else { + return Err(proto_error( + "Missing Arrow type in partition columns", + )); + }; + let arrow_type = DataType::try_from(arrow_type).map_err(|e| { + proto_error(format!("Received an unknown ArrowType: {e}")) + })?; + Ok((col.name.clone(), arrow_type)) + }) + .collect::>>()?; + let options = ListingOptions::new(file_format) .with_file_extension(&scan.file_extension) - .with_table_partition_cols( - scan.table_partition_cols - .iter() - .map(|col| { - ( - col.clone(), - schema - .field_with_name(col) - .unwrap() - .data_type() - .clone(), - ) - }) - .collect(), - ) + .with_table_partition_cols(partition_columns) .with_collect_stat(scan.collect_stat) .with_target_partitions(scan.target_partitions as usize) .with_file_sort_order(all_sort_orders); @@ -1046,7 +1045,6 @@ impl AsLogicalPlan for LogicalPlanNode { }) } }; - let schema: protobuf::Schema = schema.as_ref().try_into()?; let filters: Vec = serialize_exprs(filters, extension_codec)?; @@ -1099,6 +1097,21 @@ impl AsLogicalPlan for LogicalPlanNode { let options = listing_table.options(); + let mut builder = SchemaBuilder::from(schema.as_ref()); + for (idx, field) in schema.fields().iter().enumerate().rev() { + if options + .table_partition_cols + .iter() + .any(|(name, _)| name == field.name()) + { + builder.remove(idx); + } + } + + let schema = builder.finish(); + + let schema: protobuf::Schema = (&schema).try_into()?; + let mut exprs_vec: Vec = vec![]; for order in &options.file_sort_order { let expr_vec = SortExprNodeCollection { @@ -1107,6 +1120,23 @@ impl AsLogicalPlan for LogicalPlanNode { exprs_vec.push(expr_vec); } + let partition_columns = options + .table_partition_cols + .iter() + .map(|(name, arrow_type)| { + let arrow_type = protobuf::ArrowType::try_from(arrow_type) + .map_err(|e| { + proto_error(format!( + "Received an unknown ArrowType: {e}" + )) + })?; + Ok(protobuf::PartitionColumn { + name: name.clone(), + arrow_type: Some(arrow_type), + }) + }) + .collect::>>()?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ListingScan( protobuf::ListingTableScanNode { @@ -1114,11 +1144,7 @@ impl AsLogicalPlan for LogicalPlanNode { table_name: Some(table_name.clone().into()), collect_stat: options.collect_stat, file_extension: options.file_extension.clone(), - table_partition_cols: options - .table_partition_cols - .iter() - .map(|x| x.0.clone()) - .collect::>(), + table_partition_cols: partition_columns, paths: listing_table .table_paths() .iter() @@ -1133,6 +1159,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } else if let Some(view_table) = source.downcast_ref::() { + let schema: protobuf::Schema = schema.as_ref().try_into()?; Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ViewScan(Box::new( protobuf::ViewTableScanNode { @@ -1167,6 +1194,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } else { + let schema: protobuf::Schema = schema.as_ref().try_into()?; let mut bytes = vec![]; extension_codec .try_encode_table_provider(table_name, provider, &mut bytes) diff --git 
a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 841c31fa035f..7f089b1c8467 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -217,7 +217,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::Alias(alias)), } } - Expr::Literal(value) => { + Expr::Literal(value, _) => { let pb_value: protobuf::ScalarValue = value.try_into()?; protobuf::LogicalExprNode { expr_type: Some(ExprType::Literal(pb_value)), @@ -302,18 +302,19 @@ pub fn serialize_expr( expr_type: Some(ExprType::SimilarTo(pb)), } } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - params: - expr::WindowFunctionParams { - ref args, - ref partition_by, - ref order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let expr::WindowFunction { + ref fun, + params: + expr::WindowFunctionParams { + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, + }, + } = window_fun.as_ref(); let (window_function, fun_definition) = match fun { WindowFunctionDefinition::AggregateUDF(aggr_udf) => { let mut buf = Vec::new(); diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index a886fc242545..5024bb558a65 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use arrow::compute::SortOptions; +use arrow::datatypes::Field; use chrono::{TimeZone, Utc}; use datafusion_expr::dml::InsertOp; use object_store::path::Path; @@ -151,13 +152,13 @@ pub fn parse_physical_window_expr( protobuf::physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(udaf_name) => { WindowFunctionDefinition::AggregateUDF(match &proto.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => registry.udaf(udaf_name)? + None => registry.udaf(udaf_name).or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, }) } protobuf::physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(udwf_name) => { WindowFunctionDefinition::WindowUDF(match &proto.fun_definition { Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => registry.udwf(udwf_name)? + None => registry.udwf(udwf_name).or_else(|_| codec.try_decode_udwf(udwf_name, &[]))? 
}) } } @@ -354,7 +355,9 @@ pub fn parse_physical_expr( ExprType::ScalarUdf(e) => { let udf = match &e.fun_definition { Some(buf) => codec.try_decode_udf(&e.name, buf)?, - None => registry.udf(e.name.as_str())?, + None => registry + .udf(e.name.as_str()) + .or_else(|_| codec.try_decode_udf(&e.name, &[]))?, }; let scalar_fun_def = Arc::clone(&udf); @@ -365,7 +368,7 @@ pub fn parse_physical_expr( e.name.as_str(), scalar_fun_def, args, - convert_required!(e.return_type)?, + Field::new("f", convert_required!(e.return_type)?, true).into(), ) .with_nullable(e.nullable), ) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 90d071ab23f5..cc80a0a94fc2 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -792,7 +792,10 @@ impl protobuf::PhysicalPlanNode { ) -> Result> { let input: Arc = into_physical_plan(&merge.input, registry, runtime, extension_codec)?; - Ok(Arc::new(CoalescePartitionsExec::new(input))) + Ok(Arc::new( + CoalescePartitionsExec::new(input) + .with_fetch(merge.fetch.map(|f| f as usize)), + )) } fn try_into_repartition_physical_plan( @@ -1047,7 +1050,12 @@ impl protobuf::PhysicalPlanNode { let agg_udf = match &agg_node.fun_definition { Some(buf) => extension_codec .try_decode_udaf(udaf_name, buf)?, - None => registry.udaf(udaf_name)?, + None => { + registry.udaf(udaf_name).or_else(|_| { + extension_codec + .try_decode_udaf(udaf_name, &[]) + })? + } }; AggregateExprBuilder::new(agg_udf, input_phy_expr) @@ -2354,6 +2362,7 @@ impl protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Merge(Box::new( protobuf::CoalescePartitionsExecNode { input: Some(Box::new(input)), + fetch: exec.fetch().map(|f| f as u32), }, ))), }) diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index 92d961fc7556..4c7da2768e74 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::plan_err; use datafusion_expr::function::AccumulatorArgs; @@ -166,8 +166,11 @@ impl WindowUDFImpl for CustomUDWF { Ok(Box::new(CustomUDWFEvaluator {})) } - fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { - Ok(Field::new(field_args.name(), DataType::UInt64, false)) + fn field( + &self, + field_args: WindowUDFFieldArgs, + ) -> datafusion_common::Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9fa1f74ae188..993cc6f87ca3 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -19,12 +19,15 @@ use arrow::array::{ ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, StringBuilder, }; use arrow::datatypes::{ - DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, - IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, - DECIMAL256_MAX_PRECISION, + DataType, Field, FieldRef, Fields, Int32Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, + UnionMode, DECIMAL256_MAX_PRECISION, }; use arrow::util::pretty::pretty_format_batches; -use datafusion::datasource::file_format::json::JsonFormatFactory; +use datafusion::datasource::file_format::json::{JsonFormat, JsonFormatFactory}; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; use datafusion::optimizer::eliminate_nested_union::EliminateNestedUnion; use datafusion::optimizer::Optimizer; use datafusion_common::parsers::CompressionTypeVariant; @@ -110,15 +113,21 @@ fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { #[cfg(not(feature = "json"))] fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {} -// Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test -// equality. fn roundtrip_expr_test(initial_struct: Expr, ctx: SessionContext) { let extension_codec = DefaultLogicalExtensionCodec {}; - let proto: protobuf::LogicalExprNode = - serialize_expr(&initial_struct, &extension_codec) - .unwrap_or_else(|e| panic!("Error serializing expression: {:?}", e)); - let round_trip: Expr = - from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); + roundtrip_expr_test_with_codec(initial_struct, ctx, &extension_codec); +} + +// Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test +// equality. 
+fn roundtrip_expr_test_with_codec( + initial_struct: Expr, + ctx: SessionContext, + codec: &dyn LogicalExtensionCodec, +) { + let proto: protobuf::LogicalExprNode = serialize_expr(&initial_struct, codec) + .unwrap_or_else(|e| panic!("Error serializing expression: {e:?}")); + let round_trip: Expr = from_proto::parse_expr(&proto, &ctx, codec).unwrap(); assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); @@ -970,8 +979,8 @@ async fn roundtrip_expr_api() -> Result<()> { stddev_pop(lit(2.2)), approx_distinct(lit(2)), approx_median(lit(2)), - approx_percentile_cont(lit(2), lit(0.5), None), - approx_percentile_cont(lit(2), lit(0.5), Some(lit(50))), + approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None), + approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))), approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), grouping(lit(1)), bit_and(lit(2)), @@ -1959,7 +1968,7 @@ fn roundtrip_case_with_null() { let test_expr = Expr::Case(Case::new( Some(Box::new(lit(1.0_f32))), vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - Some(Box::new(Expr::Literal(ScalarValue::Null))), + Some(Box::new(Expr::Literal(ScalarValue::Null, None))), )); let ctx = SessionContext::new(); @@ -1968,7 +1977,7 @@ fn roundtrip_case_with_null() { #[test] fn roundtrip_null_literal() { - let test_expr = Expr::Literal(ScalarValue::Null); + let test_expr = Expr::Literal(ScalarValue::Null, None); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2182,8 +2191,7 @@ fn roundtrip_aggregate_udf() { roundtrip_expr_test(test_expr, ctx); } -#[test] -fn roundtrip_scalar_udf() { +fn dummy_udf() -> ScalarUDF { let scalar_fn = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { panic!("should be array") @@ -2191,13 +2199,18 @@ fn roundtrip_scalar_udf() { Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef)) }); - let udf = create_udf( + create_udf( "dummy", vec![DataType::Utf8], DataType::Utf8, Volatility::Immutable, scalar_fn, - ); + ) +} + +#[test] +fn roundtrip_scalar_udf() { + let udf = dummy_udf(); let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( Arc::new(udf.clone()), @@ -2207,7 +2220,57 @@ fn roundtrip_scalar_udf() { let ctx = SessionContext::new(); ctx.register_udf(udf); - roundtrip_expr_test(test_expr, ctx); + roundtrip_expr_test(test_expr.clone(), ctx); + + // Now test loading the UDF without registering it in the context, but rather creating it in the + // extension codec. 
+ #[derive(Debug)] + struct DummyUDFExtensionCodec; + + impl LogicalExtensionCodec for DummyUDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &SessionContext, + ) -> Result { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { + if name == "dummy" { + Ok(Arc::new(dummy_udf())) + } else { + Err(DataFusionError::Internal(format!("UDF {name} not found"))) + } + } + } + + let ctx = SessionContext::new(); + roundtrip_expr_test_with_codec(test_expr, ctx, &DummyUDFExtensionCodec) } #[test] @@ -2296,7 +2359,7 @@ fn roundtrip_window() { let ctx = SessionContext::new(); // 1. without window_frame - let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2307,7 +2370,7 @@ fn roundtrip_window() { .unwrap(); // 2. with default window_frame - let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2324,7 +2387,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr3 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2341,7 +2404,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr4 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], )) @@ -2391,7 +2454,7 @@ fn roundtrip_window() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr5 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], )) @@ -2453,14 +2516,18 @@ fn roundtrip_window() { make_partition_evaluator() } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - if let Some(return_type) = field_args.get_input_type(0) { - Ok(Field::new(field_args.name(), return_type, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + if let Some(return_field) = field_args.get_input_field(0) { + Ok(return_field + .as_ref() + .clone() + .with_name(field_args.name()) + .into()) } else { plan_err!( "dummy_udwf expects 1 argument, got {}: {:?}", - field_args.input_types().len(), - field_args.input_types() + field_args.input_fields().len(), + field_args.input_fields() ) } } @@ -2472,7 +2539,7 @@ fn roundtrip_window() { let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); - let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr6 = Expr::from(expr::WindowFunction::new( 
WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], )) @@ -2482,7 +2549,7 @@ fn roundtrip_window() { .build() .unwrap(); - let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( + let text_expr7 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], )) @@ -2559,3 +2626,33 @@ async fn roundtrip_union_query() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn roundtrip_custom_listing_tables_schema() -> Result<()> { + let ctx = SessionContext::new(); + // Make sure during round-trip, constraint information is preserved + let file_format = JsonFormat::default(); + let table_partition_cols = vec![("part".to_owned(), DataType::Int64)]; + let data = "../core/tests/data/partitioned_table_json"; + let listing_table_url = ListingTableUrl::parse(data)?; + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_table_partition_cols(table_partition_cols); + + let config = ListingTableConfig::new(listing_table_url) + .with_listing_options(listing_options) + .infer_schema(&ctx.state()) + .await?; + + ctx.register_table("hive_style", Arc::new(ListingTable::try_new(config)?))?; + + let plan = ctx + .sql("SELECT part, value FROM hive_style LIMIT 1") + .await? + .logical_plan() + .clone(); + + let bytes = logical_plan_to_bytes(&plan)?; + let new_plan = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(plan, new_plan); + Ok(()) +} diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index be90497a6e21..3329e531aea7 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -66,6 +66,7 @@ use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion::physical_plan::analyze::AnalyzeExec; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ binary, cast, col, in_list, like, lit, BinaryExpr, Column, NotExpr, PhysicalSortExpr, @@ -504,7 +505,7 @@ fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> { vec![col("b", &schema)?, lit(0.5)], ) .schema(Arc::clone(&schema)) - .alias("APPROX_PERCENTILE_CONT(b, 0.5)") + .alias("APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY b)") .build() .map(Arc::new)?]; @@ -594,7 +595,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { Signature::exact(vec![DataType::Int64], Volatility::Immutable), return_type, accumulator, - vec![Field::new("value", DataType::Int64, true)], + vec![Field::new("value", DataType::Int64, true).into()], )); let ctx = SessionContext::new(); @@ -709,7 +710,7 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { } #[test] -fn roundtrip_coalesce_with_fetch() -> Result<()> { +fn roundtrip_coalesce_batches_with_fetch() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); @@ -725,6 +726,22 @@ fn roundtrip_coalesce_with_fetch() -> Result<()> { )) } +#[test] +fn roundtrip_coalesce_partitions_with_fetch() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + roundtrip_test(Arc::new(CoalescePartitionsExec::new(Arc::new( + 
EmptyExec::new(schema.clone()), + ))))?; + + roundtrip_test(Arc::new( + CoalescePartitionsExec::new(Arc::new(EmptyExec::new(schema))) + .with_fetch(Some(10)), + )) +} + #[test] fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { let file_schema = @@ -739,9 +756,7 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { let mut options = TableParquetOptions::new(); options.global.pushdown_filters = true; - let file_source = Arc::new( - ParquetSource::new(options).with_predicate(Arc::clone(&file_schema), predicate), - ); + let file_source = Arc::new(ParquetSource::new(options).with_predicate(Arc::clone(&file_schema), predicate)); let scan_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), @@ -800,10 +815,8 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { inner: Arc::new(Column::new("col", 1)), }); - let file_source = Arc::new( - ParquetSource::default() - .with_predicate(Arc::clone(&file_schema), custom_predicate_expr), - ); + let file_source = + Arc::new(ParquetSource::default().with_predicate(Arc::clone(&file_schema), custom_predicate_expr)); let scan_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), @@ -968,7 +981,7 @@ fn roundtrip_scalar_udf() -> Result<()> { "dummy", fun_def, vec![col("a", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true).into(), ); let project = @@ -1096,7 +1109,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true).into(), )); let filter = Arc::new(FilterExec::try_new( @@ -1198,7 +1211,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true).into(), )); let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new( diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index d15e62909f7e..c9ef4377d43b 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -83,7 +83,7 @@ fn udf_roundtrip_with_registry() { #[test] #[should_panic( - expected = "No function registry provided to deserialize, so can not deserialize User Defined Function 'dummy'" + expected = "LogicalExtensionCodec is not provided for scalar function dummy" )] fn udf_roundtrip_without_registry() { let ctx = context_with_udf(); @@ -256,7 +256,7 @@ fn test_expression_serialization_roundtrip() { use datafusion_proto::logical_plan::from_proto::parse_expr; let ctx = SessionContext::new(); - let lit = Expr::Literal(ScalarValue::Utf8(None)); + let lit = Expr::Literal(ScalarValue::Utf8(None), None); for function in string::functions() { // default to 4 args (though some exprs like substr have error checking) let num_args = 4; diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml new file mode 100644 index 000000000000..1ded8c40aa4b --- /dev/null +++ b/datafusion/spark/Cargo.toml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-spark" +description = "DataFusion expressions that emulate Apache Spark's behavior" +version = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +edition = { workspace = true } + +[package.metadata.docs.rs] +all-features = true + +[lints] +workspace = true + +[lib] +name = "datafusion_spark" + +[dependencies] +arrow = { workspace = true } +datafusion-catalog = { workspace = true } +datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } +datafusion-macros = { workspace = true } +log = { workspace = true } diff --git a/datafusion/spark/LICENSE.txt b/datafusion/spark/LICENSE.txt new file mode 120000 index 000000000000..1ef648f64b34 --- /dev/null +++ b/datafusion/spark/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/spark/NOTICE.txt b/datafusion/spark/NOTICE.txt new file mode 120000 index 000000000000..fb051c92b10b --- /dev/null +++ b/datafusion/spark/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/spark/README.md b/datafusion/spark/README.md new file mode 100644 index 000000000000..c92ada0ab477 --- /dev/null +++ b/datafusion/spark/README.md @@ -0,0 +1,40 @@ + + +# datafusion-spark: Spark-compatible Expressions + +This crate provides Apache Spark-compatible expressions for use with DataFusion. + +## Testing Guide + +When testing functions by directly invoking them (e.g., `test_scalar_function!()`), input coercion (from the `signature` +or `coerce_types`) is not applied. + +Therefore, direct invocation tests should only be used to verify that the function is correctly implemented. + +Please be sure to add additional tests beyond direct invocation. +For more detailed testing guidelines, refer to +the [Spark SQLLogicTest README](../sqllogictest/test_files/spark/README.md). + +## Implementation References + +When implementing Spark-compatible functions, you can check if there are existing implementations in +the [Sail](https://github.com/lakehq/sail) or [Comet](https://github.com/apache/datafusion-comet) projects first. +If you do port functionality from these sources, make sure to port over the corresponding tests too, to ensure +correctness and compatibility. diff --git a/datafusion/spark/src/function/aggregate/mod.rs b/datafusion/spark/src/function/aggregate/mod.rs new file mode 100644 index 000000000000..0856e2872d4f --- /dev/null +++ b/datafusion/spark/src/function/aggregate/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::AggregateUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/array/mod.rs b/datafusion/spark/src/function/array/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/array/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/bitwise/mod.rs b/datafusion/spark/src/function/bitwise/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/bitwise/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/collection/mod.rs b/datafusion/spark/src/function/collection/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/collection/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/conditional/mod.rs b/datafusion/spark/src/function/conditional/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/conditional/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/conversion/mod.rs b/datafusion/spark/src/function/conversion/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/conversion/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/csv/mod.rs b/datafusion/spark/src/function/csv/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/csv/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/datetime/mod.rs b/datafusion/spark/src/function/datetime/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/datetime/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/error_utils.rs b/datafusion/spark/src/function/error_utils.rs new file mode 100644 index 000000000000..b972d64ed3e9 --- /dev/null +++ b/datafusion/spark/src/function/error_utils.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
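[Editor's note, not part of the patch] The `error_utils` module added below centralizes Spark-style error construction for the new crate. As a rough illustration of the intended message shape (this snippet is an assumption: it presumes the crate builds as `datafusion_spark` and that these helpers remain `pub`), the argument-count helper would be used like this:

```rust
use datafusion_spark::function::error_utils::invalid_arg_count_exec_err;

fn main() {
    // Hypothetical call: `sha2` declares a (2, 2) argument range but received 3 arguments.
    let err = invalid_arg_count_exec_err("sha2", (2, 2), 3);
    // DataFusionError::Execution displays with an "Execution error:" prefix,
    // so this should print roughly:
    //   Execution error: Spark `sha2` function requires 2 arguments, got 3
    println!("{err}");
}
```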
+ +// TODO: https://github.com/apache/spark/tree/master/common/utils/src/main/resources/error + +use arrow::datatypes::DataType; +use datafusion_common::{exec_datafusion_err, internal_datafusion_err, DataFusionError}; + +pub fn invalid_arg_count_exec_err( + function_name: &str, + required_range: (i32, i32), + provided: usize, +) -> DataFusionError { + let (min_required, max_required) = required_range; + let required = if min_required == max_required { + format!( + "{min_required} argument{}", + if min_required == 1 { "" } else { "s" } + ) + } else { + format!("{min_required} to {max_required} arguments") + }; + exec_datafusion_err!( + "Spark `{function_name}` function requires {required}, got {provided}" + ) +} + +pub fn unsupported_data_type_exec_err( + function_name: &str, + required: &str, + provided: &DataType, +) -> DataFusionError { + exec_datafusion_err!("Unsupported Data Type: Spark `{function_name}` function expects {required}, got {provided}") +} + +pub fn unsupported_data_types_exec_err( + function_name: &str, + required: &str, + provided: &[DataType], +) -> DataFusionError { + exec_datafusion_err!( + "Unsupported Data Type: Spark `{function_name}` function expects {required}, got {}", + provided + .iter() + .map(|dt| format!("{dt}")) + .collect::>() + .join(", ") + ) +} + +pub fn generic_exec_err(function_name: &str, message: &str) -> DataFusionError { + exec_datafusion_err!("Spark `{function_name}` function: {message}") +} + +pub fn generic_internal_err(function_name: &str, message: &str) -> DataFusionError { + internal_datafusion_err!("Spark `{function_name}` function: {message}") +} diff --git a/datafusion/spark/src/function/generator/mod.rs b/datafusion/spark/src/function/generator/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/generator/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/hash/mod.rs b/datafusion/spark/src/function/hash/mod.rs new file mode 100644 index 000000000000..f31918e6a46b --- /dev/null +++ b/datafusion/spark/src/function/hash/mod.rs @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+pub mod sha2;
+
+use datafusion_expr::ScalarUDF;
+use datafusion_functions::make_udf_function;
+use std::sync::Arc;
+
+make_udf_function!(sha2::SparkSha2, sha2);
+
+pub mod expr_fn {
+    use datafusion_functions::export_functions;
+    export_functions!((sha2, "sha2(expr, bitLength) - Returns a checksum of SHA-2 family as a hex string of expr. SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.", arg1 arg2));
+}
+
+pub fn functions() -> Vec<Arc<ScalarUDF>> {
+    vec![sha2()]
+}
diff --git a/datafusion/spark/src/function/hash/sha2.rs b/datafusion/spark/src/function/hash/sha2.rs
new file mode 100644
index 000000000000..b4b29ef33478
--- /dev/null
+++ b/datafusion/spark/src/function/hash/sha2.rs
@@ -0,0 +1,217 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
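[Editor's note, not part of the patch] For orientation, here is a hedged end-to-end sketch of how the `sha2` UDF implemented in the file below could be exercised once registered with a `SessionContext`. The crate name `datafusion_spark`, the module path, the `tokio` runtime, and the manual `register_udf` wiring are assumptions; the crate's own registration helper would normally handle this:

```rust
use datafusion::prelude::*;
use datafusion_expr::ScalarUDF;
use datafusion_spark::function::hash::sha2::SparkSha2;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // Wrap the ScalarUDFImpl from this file and register it by hand.
    ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::new()));

    // Per the expr_fn docs above, a bit length of 0 is treated as 256.
    let df = ctx
        .sql("SELECT sha2('Apache DataFusion', 256) AS digest")
        .await?;
    df.show().await?;
    Ok(())
}
```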
+ +extern crate datafusion_functions; + +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; +use crate::function::math::hex::spark_hex; +use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::datatypes::{DataType, UInt32Type}; +use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; +use datafusion_expr::Signature; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; +pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512}; +use std::any::Any; +use std::sync::Arc; + +/// +#[derive(Debug)] +pub struct SparkSha2 { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkSha2 { + fn default() -> Self { + Self::new() + } +} + +impl SparkSha2 { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkSha2 { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "sha2" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types[1].is_null() { + return Ok(DataType::Null); + } + Ok(match arg_types[0] { + DataType::Utf8View + | DataType::LargeUtf8 + | DataType::Utf8 + | DataType::Binary + | DataType::BinaryView + | DataType::LargeBinary => DataType::Utf8, + DataType::Null => DataType::Null, + _ => { + return exec_err!( + "{} function can only accept strings or binary arrays.", + self.name() + ) + } + }) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| { + internal_datafusion_err!("Expected 2 arguments for function sha2") + })?; + + sha2(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return Err(invalid_arg_count_exec_err( + self.name(), + (2, 2), + arg_types.len(), + )); + } + let expr_type = match &arg_types[0] { + DataType::Utf8View + | DataType::LargeUtf8 + | DataType::Utf8 + | DataType::Binary + | DataType::BinaryView + | DataType::LargeBinary + | DataType::Null => Ok(arg_types[0].clone()), + _ => Err(unsupported_data_type_exec_err( + self.name(), + "String, Binary", + &arg_types[0], + )), + }?; + let bit_length_type = if arg_types[1].is_numeric() { + Ok(DataType::UInt32) + } else if arg_types[1].is_null() { + Ok(DataType::Null) + } else { + Err(unsupported_data_type_exec_err( + self.name(), + "Numeric Type", + &arg_types[1], + )) + }?; + + Ok(vec![expr_type, bit_length_type]) + } +} + +pub fn sha2(args: [ColumnarValue; 2]) -> Result { + match args { + [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Scalar(ScalarValue::UInt32(Some(bit_length_arg)))] => { + match bit_length_arg { + 0 | 256 => sha256(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]), + 224 => sha224(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]), + 384 => sha384(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]), + 512 => sha512(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]), + _ => exec_err!( + "sha2 function only supports 224, 256, 384, and 512 bit lengths." 
+ ), + } + .map(|hashed| spark_hex(&[hashed]).unwrap()) + } + [ColumnarValue::Array(expr_arg), ColumnarValue::Scalar(ScalarValue::UInt32(Some(bit_length_arg)))] => { + match bit_length_arg { + 0 | 256 => sha256(&[ColumnarValue::from(expr_arg)]), + 224 => sha224(&[ColumnarValue::from(expr_arg)]), + 384 => sha384(&[ColumnarValue::from(expr_arg)]), + 512 => sha512(&[ColumnarValue::from(expr_arg)]), + _ => exec_err!( + "sha2 function only supports 224, 256, 384, and 512 bit lengths." + ), + } + .map(|hashed| spark_hex(&[hashed]).unwrap()) + } + [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Array(bit_length_arg)] => + { + let arr: StringArray = bit_length_arg + .as_primitive::() + .iter() + .map(|bit_length| { + match sha2([ + ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())), + ColumnarValue::Scalar(ScalarValue::UInt32(bit_length)), + ]) + .unwrap() + { + ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, + ColumnarValue::Array(arr) => arr + .as_string::() + .iter() + .map(|str| str.unwrap().to_string()) + .next(), // first element + _ => unreachable!(), + } + }) + .collect(); + Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) + } + [ColumnarValue::Array(expr_arg), ColumnarValue::Array(bit_length_arg)] => { + let expr_iter = expr_arg.as_string::().iter(); + let bit_length_iter = bit_length_arg.as_primitive::().iter(); + let arr: StringArray = expr_iter + .zip(bit_length_iter) + .map(|(expr, bit_length)| { + match sha2([ + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + expr.unwrap().to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::UInt32(bit_length)), + ]) + .unwrap() + { + ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, + ColumnarValue::Array(arr) => arr + .as_string::() + .iter() + .map(|str| str.unwrap().to_string()) + .next(), // first element + _ => unreachable!(), + } + }) + .collect(); + Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) + } + _ => exec_err!("Unsupported argument types for sha2 function"), + } +} diff --git a/datafusion/spark/src/function/json/mod.rs b/datafusion/spark/src/function/json/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/json/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/lambda/mod.rs b/datafusion/spark/src/function/lambda/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/lambda/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/map/mod.rs b/datafusion/spark/src/function/map/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/map/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/math/expm1.rs b/datafusion/spark/src/function/math/expm1.rs new file mode 100644 index 000000000000..3a3a0c3835d3 --- /dev/null +++ b/datafusion/spark/src/function/math/expm1.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
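[Editor's note, not part of the patch] Background for the `expm1` implementation that follows: Spark's `expm1(expr)` returns `exp(expr) - 1`, and the code below maps it onto Rust's `f64::exp_m1`, which stays accurate for arguments near zero where the naive formula cancels. A self-contained sketch in plain Rust, with no DataFusion dependency:

```rust
fn main() {
    let x = 1e-10_f64;
    // exp_m1 keeps the small result accurate to within a few ulps...
    let accurate = x.exp_m1();
    // ...while exp(x) - 1.0 loses roughly half of the significant digits.
    let naive = x.exp() - 1.0;
    println!("exp_m1 = {accurate:e}, naive = {naive:e}");
    assert!((accurate - 1e-10).abs() < 1e-20);
}
```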
+
+use crate::function::error_utils::{
+    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
+};
+use arrow::array::{ArrayRef, AsArray};
+use arrow::datatypes::{DataType, Float64Type};
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+///
+#[derive(Debug)]
+pub struct SparkExpm1 {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for SparkExpm1 {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkExpm1 {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec![],
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkExpm1 {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "expm1"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        if args.args.len() != 1 {
+            return Err(invalid_arg_count_exec_err("expm1", (1, 1), args.args.len()));
+        }
+        match &args.args[0] {
+            ColumnarValue::Scalar(ScalarValue::Float64(value)) => Ok(
+                ColumnarValue::Scalar(ScalarValue::Float64(value.map(|x| x.exp_m1()))),
+            ),
+            ColumnarValue::Array(array) => match array.data_type() {
+                DataType::Float64 => Ok(ColumnarValue::Array(Arc::new(
+                    array
+                        .as_primitive::<Float64Type>()
+                        .unary::<_, Float64Type>(|x| x.exp_m1()),
+                ) as ArrayRef)),
+                other => Err(unsupported_data_type_exec_err(
+                    "expm1",
+                    format!("{}", DataType::Float64).as_str(),
+                    other,
+                )),
+            },
+            other => Err(unsupported_data_type_exec_err(
+                "expm1",
+                format!("{}", DataType::Float64).as_str(),
+                &other.data_type(),
+            )),
+        }
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        if arg_types.len() != 1 {
+            return Err(invalid_arg_count_exec_err("expm1", (1, 1), arg_types.len()));
+        }
+        if arg_types[0].is_numeric() {
+            Ok(vec![DataType::Float64])
+        } else {
+            Err(unsupported_data_type_exec_err(
+                "expm1",
+                "Numeric Type",
+                &arg_types[0],
+            ))
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::function::math::expm1::SparkExpm1;
+    use crate::function::utils::test::test_scalar_function;
+    use arrow::array::{Array, Float64Array};
+    use arrow::datatypes::DataType::Float64;
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+
+    macro_rules! test_expm1_float64_invoke {
+        ($INPUT:expr, $EXPECTED:expr) => {
+            test_scalar_function!(
+                SparkExpm1::new(),
+                vec![ColumnarValue::Scalar(ScalarValue::Float64($INPUT))],
+                $EXPECTED,
+                f64,
+                Float64,
+                Float64Array
+            );
+        };
+    }
+
+    #[test]
+    fn test_expm1_invoke() -> Result<()> {
+        test_expm1_float64_invoke!(Some(0f64), Ok(Some(0.0f64)));
+        Ok(())
+    }
+}
diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs
new file mode 100644
index 000000000000..74ec7641b38f
--- /dev/null
+++ b/datafusion/spark/src/function/math/hex.rs
@@ -0,0 +1,404 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; +use arrow::array::{Array, StringArray}; +use arrow::datatypes::DataType; +use arrow::{ + array::{as_dictionary_array, as_largestring_array, as_string_array}, + datatypes::Int32Type, +}; +use datafusion_common::{ + cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, + exec_err, DataFusionError, +}; +use datafusion_expr::Signature; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; +use std::fmt::Write; + +/// +#[derive(Debug)] +pub struct SparkHex { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkHex { + fn default() -> Self { + Self::new() + } +} + +impl SparkHex { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkHex { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "hex" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type( + &self, + _arg_types: &[DataType], + ) -> datafusion_common::Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + spark_hex(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types( + &self, + arg_types: &[DataType], + ) -> datafusion_common::Result> { + if arg_types.len() != 1 { + return Err(invalid_arg_count_exec_err("hex", (1, 1), arg_types.len())); + } + match &arg_types[0] { + DataType::Int64 + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]), + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Int64 + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]), + other => { + if other.is_numeric() { + Ok(vec![DataType::Dictionary( + key_type.clone(), + Box::new(DataType::Int64), + )]) + } else { + Err(unsupported_data_type_exec_err( + "hex", + "Numeric, String, or Binary", + &arg_types[0], + )) + } + } + }, + other => { + if other.is_numeric() { + Ok(vec![DataType::Int64]) + } else { + Err(unsupported_data_type_exec_err( + "hex", + "Numeric, String, or Binary", + &arg_types[0], + )) + } + } + } + } +} + +fn hex_int64(num: i64) -> String { + format!("{num:X}") +} + +#[inline(always)] +fn hex_encode>(data: T, lower_case: bool) -> String { + let mut s = String::with_capacity(data.as_ref().len() * 2); + if lower_case { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02x}").unwrap(); + } + } else { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. 
+ write!(&mut s, "{b:02X}").unwrap(); + } + } + s +} + +#[inline(always)] +fn hex_bytes>(bytes: T) -> Result { + let hex_string = hex_encode(bytes, false); + Ok(hex_string) +} + +/// Spark-compatible `hex` function +pub fn spark_hex(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return Err(DataFusionError::Internal( + "hex expects exactly one argument".to_string(), + )); + } + + let input = match &args[0] { + ColumnarValue::Scalar(value) => ColumnarValue::Array(value.to_array()?), + ColumnarValue::Array(_) => args[0].clone(), + }; + + match &input { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 => { + let array = as_int64_array(array)?; + + let hexed_array: StringArray = + array.iter().map(|v| v.map(hex_int64)).collect(); + + Ok(ColumnarValue::Array(Arc::new(hexed_array))) + } + DataType::Utf8 => { + let array = as_string_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::LargeUtf8 => { + let array = as_largestring_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Binary => { + let array = as_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::FixedSizeBinary(_) => { + let array = as_fixed_size_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Dictionary(_, value_type) => { + let dict = as_dictionary_array::(&array); + + let values = match **value_type { + DataType::Int64 => as_int64_array(dict.values())? + .iter() + .map(|v| v.map(hex_int64)) + .collect::>(), + DataType::Utf8 => as_string_array(dict.values()) + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?, + DataType::Binary => as_binary_array(dict.values())? 
+ .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?, + _ => exec_err!( + "hex got an unexpected argument type: {:?}", + array.data_type() + )?, + }; + + let new_values: Vec> = dict + .keys() + .iter() + .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None)) + .collect(); + + let string_array_values = StringArray::from(new_values); + + Ok(ColumnarValue::Array(Arc::new(string_array_values))) + } + _ => exec_err!( + "hex got an unexpected argument type: {:?}", + array.data_type() + ), + }, + _ => exec_err!("native hex does not support scalar values at this time"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{Int64Array, StringArray}; + use arrow::{ + array::{ + as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, + StringBuilder, StringDictionaryBuilder, + }, + datatypes::{Int32Type, Int64Type}, + }; + use datafusion_expr::ColumnarValue; + + #[test] + fn test_dictionary_hex_utf8() { + let mut input_builder = StringDictionaryBuilder::::new(); + input_builder.append_value("hi"); + input_builder.append_value("bye"); + input_builder.append_null(); + input_builder.append_value("rust"); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("6869"); + string_builder.append_value("627965"); + string_builder.append_null(); + string_builder.append_value("72757374"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_int64() { + let mut input_builder = PrimitiveDictionaryBuilder::::new(); + input_builder.append_value(1); + input_builder.append_value(2); + input_builder.append_null(); + input_builder.append_value(3); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("1"); + string_builder.append_value("2"); + string_builder.append_null(); + string_builder.append_value("3"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_binary() { + let mut input_builder = BinaryDictionaryBuilder::::new(); + input_builder.append_value("1"); + input_builder.append_value("j"); + input_builder.append_null(); + input_builder.append_value("3"); + let input = input_builder.finish(); + + let mut expected_builder = StringBuilder::new(); + expected_builder.append_value("31"); + expected_builder.append_value("6A"); + expected_builder.append_null(); + expected_builder.append_value("33"); + let expected = expected_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_hex_int64() { + let num = 1234; + let hexed = super::hex_int64(num); + 
assert_eq!(hexed, "4D2".to_string()); + + let num = -1; + let hexed = super::hex_int64(num); + assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string()); + } + + #[test] + fn test_spark_hex_int64() { + let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]); + let columnar_value = ColumnarValue::Array(Arc::new(int_array)); + + let result = super::spark_hex(&[columnar_value]).unwrap(); + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let string_array = as_string_array(&result); + let expected_array = StringArray::from(vec![ + Some("1".to_string()), + Some("2".to_string()), + None, + Some("3".to_string()), + ]); + + assert_eq!(string_array, &expected_array); + } +} diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs new file mode 100644 index 000000000000..80bcdc39a41d --- /dev/null +++ b/datafusion/spark/src/function/math/mod.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod expm1; +pub mod hex; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(expm1::SparkExpm1, expm1); +make_udf_function!(hex::SparkHex, hex); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1)); + export_functions!((hex, "Computes hex value of the given column.", arg1)); +} + +pub fn functions() -> Vec> { + vec![expm1(), hex()] +} diff --git a/datafusion/spark/src/function/misc/mod.rs b/datafusion/spark/src/function/misc/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/misc/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
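[Editor's note, not part of the patch] A note on the `hex_int64` tests above: Spark's `hex` of an integer prints the two's-complement bit pattern in uppercase, which is exactly what Rust's `{:X}` formatting of a signed integer produces, so `hex(-1)` becomes `FFFFFFFFFFFFFFFF`. A standalone check of that equivalence:

```rust
fn main() {
    // Matches the expectations in the tests above.
    assert_eq!(format!("{:X}", 1234_i64), "4D2");
    assert_eq!(format!("{:X}", -1_i64), "FFFFFFFFFFFFFFFF");
    println!("{{:X}} on i64 reproduces Spark's hex() for integers");
}
```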
+ +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/mod.rs b/datafusion/spark/src/function/mod.rs new file mode 100644 index 000000000000..dfdd94a040a9 --- /dev/null +++ b/datafusion/spark/src/function/mod.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod aggregate; +pub mod array; +pub mod bitwise; +pub mod collection; +pub mod conditional; +pub mod conversion; +pub mod csv; +pub mod datetime; +pub mod error_utils; +pub mod generator; +pub mod hash; +pub mod json; +pub mod lambda; +pub mod map; +pub mod math; +pub mod misc; +pub mod predicate; +pub mod string; +pub mod r#struct; +pub mod table; +pub mod url; +pub mod utils; +pub mod window; +pub mod xml; diff --git a/datafusion/spark/src/function/predicate/mod.rs b/datafusion/spark/src/function/predicate/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/predicate/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/string/ascii.rs b/datafusion/spark/src/function/string/ascii.rs new file mode 100644 index 000000000000..c05aa214ccc0 --- /dev/null +++ b/datafusion/spark/src/function/string/ascii.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array};
+use arrow::datatypes::DataType;
+use arrow::error::ArrowError;
+use datafusion_common::{internal_err, plan_err, Result};
+use datafusion_expr::ColumnarValue;
+use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
+use datafusion_functions::utils::make_scalar_function;
+use std::any::Any;
+use std::sync::Arc;
+
+///
+#[derive(Debug)]
+pub struct SparkAscii {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for SparkAscii {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkAscii {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec![],
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkAscii {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "ascii"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Int32)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        make_scalar_function(ascii, vec![])(&args.args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        if arg_types.len() != 1 {
+            return plan_err!(
+                "The {} function requires 1 argument, but got {}.",
+                self.name(),
+                arg_types.len()
+            );
+        }
+        Ok(vec![DataType::Utf8])
+    }
+}
+
+fn calculate_ascii<'a, V>(array: V) -> Result<ArrayRef, ArrowError>
+where
+    V: ArrayAccessor<Item = &'a str>,
+{
+    let iter = ArrayIter::new(array);
+    let result = iter
+        .map(|string| {
+            string.map(|s| {
+                let mut chars = s.chars();
+                chars.next().map_or(0, |v| v as i32)
+            })
+        })
+        .collect::<Int32Array>();
+
+    Ok(Arc::new(result) as ArrayRef)
+}
+
+/// Returns the numeric code of the first character of the argument.
+pub fn ascii(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args[0].data_type() {
+        DataType::Utf8 => {
+            let string_array = args[0].as_string::<i32>();
+            Ok(calculate_ascii(string_array)?)
+        }
+        DataType::LargeUtf8 => {
+            let string_array = args[0].as_string::<i64>();
+            Ok(calculate_ascii(string_array)?)
+        }
+        DataType::Utf8View => {
+            let string_array = args[0].as_string_view();
+            Ok(calculate_ascii(string_array)?)
+        }
+        _ => internal_err!("Unsupported data type"),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::function::string::ascii::SparkAscii;
+    use crate::function::utils::test::test_scalar_function;
+    use arrow::array::{Array, Int32Array};
+    use arrow::datatypes::DataType::Int32;
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+
+    macro_rules!
test_ascii_string_invoke { + ($INPUT:expr, $EXPECTED:expr) => { + test_scalar_function!( + SparkAscii::new(), + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_scalar_function!( + SparkAscii::new(), + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_scalar_function!( + SparkAscii::new(), + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + }; + } + + #[test] + fn test_ascii_invoke() -> Result<()> { + test_ascii_string_invoke!(Some(String::from("x")), Ok(Some(120))); + test_ascii_string_invoke!(Some(String::from("a")), Ok(Some(97))); + test_ascii_string_invoke!(Some(String::from("")), Ok(Some(0))); + test_ascii_string_invoke!(Some(String::from("\n")), Ok(Some(10))); + test_ascii_string_invoke!(Some(String::from("\t")), Ok(Some(9))); + test_ascii_string_invoke!(None, Ok(None)); + + Ok(()) + } +} diff --git a/datafusion/spark/src/function/string/char.rs b/datafusion/spark/src/function/string/char.rs new file mode 100644 index 000000000000..dd6cdc83b30d --- /dev/null +++ b/datafusion/spark/src/function/string/char.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::{ArrayRef, StringArray}, + datatypes::{ + DataType, + DataType::{Int64, Utf8}, + }, +}; + +use datafusion_common::{cast::as_int64_array, exec_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +/// Spark-compatible `char` expression +/// +#[derive(Debug)] +pub struct SparkChar { + signature: Signature, +} + +impl Default for SparkChar { + fn default() -> Self { + Self::new() + } +} + +impl SparkChar { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkChar { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "char" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_chr(&args.args) + } +} + +/// Returns the ASCII character having the binary equivalent to the input expression. +/// E.g., chr(65) = 'A'. 
+/// Compatible with Apache Spark's Chr function +fn spark_chr(args: &[ColumnarValue]) -> Result { + let array = args[0].clone(); + match array { + ColumnarValue::Array(array) => { + let array = chr(&[array])?; + Ok(ColumnarValue::Array(array)) + } + ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => { + if value < 0 { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "".to_string(), + )))) + } else { + match core::char::from_u32((value % 256) as u32) { + Some(ch) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + ch.to_string(), + )))), + None => { + exec_err!("requested character was incompatible for encoding.") + } + } + } + } + _ => exec_err!("The argument must be an Int64 array or scalar."), + } +} + +fn chr(args: &[ArrayRef]) -> Result { + let integer_array = as_int64_array(&args[0])?; + + // first map is the iterator, second is for the `Option<_>` + let result = integer_array + .iter() + .map(|integer: Option| { + integer + .map(|integer| { + if integer < 0 { + return Ok("".to_string()); // Return empty string for negative integers + } + match core::char::from_u32((integer % 256) as u32) { + Some(ch) => Ok(ch.to_string()), + None => { + exec_err!("requested character not compatible for encoding.") + } + } + }) + .transpose() + }) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs new file mode 100644 index 000000000000..9d5fabe832e9 --- /dev/null +++ b/datafusion/spark/src/function/string/mod.rs @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod ascii; +pub mod char; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(ascii::SparkAscii, ascii); +make_udf_function!(char::SparkChar, char); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!(( + ascii, + "Returns the ASCII code point of the first character of string.", + arg1 + )); + export_functions!(( + char, + "Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).", + arg1 + )); +} + +pub fn functions() -> Vec> { + vec![ascii(), char()] +} diff --git a/datafusion/spark/src/function/struct/mod.rs b/datafusion/spark/src/function/struct/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/struct/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/table/mod.rs b/datafusion/spark/src/function/table/mod.rs new file mode 100644 index 000000000000..aba7b7ceb78e --- /dev/null +++ b/datafusion/spark/src/function/table/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_catalog::TableFunction; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/url/mod.rs b/datafusion/spark/src/function/url/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/url/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs new file mode 100644 index 000000000000..85af4bb927ca --- /dev/null +++ b/datafusion/spark/src/function/utils.rs @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +pub mod test { + /// $FUNC ScalarUDFImpl to test + /// $ARGS arguments (vec) to pass to function + /// $EXPECTED a Result + /// $EXPECTED_TYPE is the expected value type + /// $EXPECTED_DATA_TYPE is the expected result type + /// $ARRAY_TYPE is the column type after function applied + macro_rules! test_scalar_function { + ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { + let expected: datafusion_common::Result> = $EXPECTED; + let func = $FUNC; + + let arg_fields: Vec = $ARGS + .iter() + .enumerate() + .map(|(idx, arg)| { + + let nullable = match arg { + datafusion_expr::ColumnarValue::Scalar(scalar) => scalar.is_null(), + datafusion_expr::ColumnarValue::Array(a) => a.null_count() > 0, + }; + + std::sync::Arc::new(arrow::datatypes::Field::new(format!("arg_{idx}"), arg.data_type(), nullable)) + }) + .collect::>(); + + let cardinality = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + datafusion_expr::ColumnarValue::Scalar(_) => acc, + datafusion_expr::ColumnarValue::Array(a) => Some(a.len()), + }) + .unwrap_or(1); + + let scalar_arguments = $ARGS.iter().map(|arg| match arg { + datafusion_expr::ColumnarValue::Scalar(scalar) => Some(scalar.clone()), + datafusion_expr::ColumnarValue::Array(_) => None, + }).collect::>(); + let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::>(); + + + let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments_refs + }); + + match expected { + Ok(expected) => { + let return_field = return_field.unwrap(); + assert_eq!(return_field.data_type(), &$EXPECTED_DATA_TYPE); + + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{ + args: $ARGS, + number_rows: cardinality, + return_field, + arg_fields: arg_fields.clone(), + }); + assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); + + let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); + let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); + assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE); + + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(expected_error) => { + if return_field.is_err() { + match return_field { + Ok(_) => assert!(false, "expected error"), + Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } + } + } + else { + let return_field = return_field.unwrap(); + + // invoke is expected error - cannot use .expect_err() due to Debug not being implemented + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{ + args: $ARGS, + number_rows: cardinality, + return_field, + arg_fields, + }) { + Ok(_) => assert!(false, "expected 
error"), + Err(error) => { + assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); + } + } + } + } + }; + }; + } + + pub(crate) use test_scalar_function; +} diff --git a/datafusion/spark/src/function/window/mod.rs b/datafusion/spark/src/function/window/mod.rs new file mode 100644 index 000000000000..97ab4a9e3542 --- /dev/null +++ b/datafusion/spark/src/function/window/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::WindowUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/xml/mod.rs b/datafusion/spark/src/function/xml/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/xml/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs new file mode 100644 index 000000000000..1fe5b6ecac8f --- /dev/null +++ b/datafusion/spark/src/lib.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
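[Editor's note, not part of the patch] Ahead of the crate-level docs that follow, a hedged sketch of the fluent `expr_fn` style they describe. The crate name `datafusion_spark` and the exact generated signatures are assumptions; `export_functions!` is expected to emit `Expr`-taking wrappers such as these:

```rust
use datafusion_expr::{col, lit, Expr};
// Hypothetical import path based on the modules added in this PR.
use datafusion_spark::expr_fn::{ascii, expm1, sha2};

fn spark_style_exprs() -> Vec<Expr> {
    vec![
        // sha2(expr, bitLength): SHA-2 digest of column "name", 256-bit variant.
        sha2(col("name"), lit(256)),
        // expm1(expr): exp(expr) - 1 as Float64.
        expm1(col("x")),
        // ascii(expr): code point of the first character of a string column.
        ascii(col("name")),
    ]
}

fn main() {
    for e in spark_style_exprs() {
        println!("{e}");
    }
}
```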
+ +#![doc( + html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", + html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" +)] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + +//! Spark Expression packages for [DataFusion]. +//! +//! This crate contains a collection of various Spark expression packages for DataFusion, +//! implemented using the extension API. +//! +//! [DataFusion]: https://crates.io/crates/datafusion +//! +//! # Available Packages +//! See the list of [modules](#modules) in this crate for available packages. +//! +//! # Using A Package +//! You can register all functions in all packages using the [`register_all`] function. +//! +//! Each package also exports an `expr_fn` submodule to help create [`Expr`]s that invoke +//! functions using a fluent style. For example: +//! +//![`Expr`]: datafusion_expr::Expr + +pub mod function; + +use datafusion_catalog::TableFunction; +use datafusion_common::Result; +use datafusion_execution::FunctionRegistry; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use log::debug; +use std::sync::Arc; + +/// Fluent-style API for creating `Expr`s +#[allow(unused)] +pub mod expr_fn { + pub use super::function::aggregate::expr_fn::*; + pub use super::function::array::expr_fn::*; + pub use super::function::bitwise::expr_fn::*; + pub use super::function::collection::expr_fn::*; + pub use super::function::conditional::expr_fn::*; + pub use super::function::conversion::expr_fn::*; + pub use super::function::csv::expr_fn::*; + pub use super::function::datetime::expr_fn::*; + pub use super::function::generator::expr_fn::*; + pub use super::function::hash::expr_fn::*; + pub use super::function::json::expr_fn::*; + pub use super::function::lambda::expr_fn::*; + pub use super::function::map::expr_fn::*; + pub use super::function::math::expr_fn::*; + pub use super::function::misc::expr_fn::*; + pub use super::function::predicate::expr_fn::*; + pub use super::function::r#struct::expr_fn::*; + pub use super::function::string::expr_fn::*; + pub use super::function::table::expr_fn::*; + pub use super::function::url::expr_fn::*; + pub use super::function::window::expr_fn::*; + pub use super::function::xml::expr_fn::*; +} + +/// Returns all default scalar functions +pub fn all_default_scalar_functions() -> Vec> { + function::array::functions() + .into_iter() + .chain(function::bitwise::functions()) + .chain(function::collection::functions()) + .chain(function::conditional::functions()) + .chain(function::conversion::functions()) + .chain(function::csv::functions()) + .chain(function::datetime::functions()) + .chain(function::generator::functions()) + .chain(function::hash::functions()) + .chain(function::json::functions()) + .chain(function::lambda::functions()) + .chain(function::map::functions()) + .chain(function::math::functions()) + .chain(function::misc::functions()) + .chain(function::predicate::functions()) + .chain(function::string::functions()) + .chain(function::r#struct::functions()) + .chain(function::url::functions()) + .chain(function::xml::functions()) + .collect::>() +} + +/// Returns all default aggregate functions +pub fn all_default_aggregate_functions() -> Vec> { + function::aggregate::functions() +} + +/// Returns 
all default window functions +pub fn all_default_window_functions() -> Vec> { + function::window::functions() +} + +/// Returns all default table functions +pub fn all_default_table_functions() -> Vec> { + function::table::functions() +} + +/// Registers all enabled packages with a [`FunctionRegistry`] +pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { + let scalar_functions: Vec> = all_default_scalar_functions(); + scalar_functions.into_iter().try_for_each(|udf| { + let existing_udf = registry.register_udf(udf)?; + if let Some(existing_udf) = existing_udf { + debug!("Overwrite existing UDF: {}", existing_udf.name()); + } + Ok(()) as Result<()> + })?; + + let aggregate_functions: Vec> = all_default_aggregate_functions(); + aggregate_functions.into_iter().try_for_each(|udf| { + let existing_udaf = registry.register_udaf(udf)?; + if let Some(existing_udaf) = existing_udaf { + debug!("Overwrite existing UDAF: {}", existing_udaf.name()); + } + Ok(()) as Result<()> + })?; + + let window_functions: Vec> = all_default_window_functions(); + window_functions.into_iter().try_for_each(|udf| { + let existing_udwf = registry.register_udwf(udf)?; + if let Some(existing_udwf) = existing_udwf { + debug!("Overwrite existing UDWF: {}", existing_udwf.name()); + } + Ok(()) as Result<()> + })?; + + Ok(()) +} diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 436f4388d8a3..071e95940035 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -74,7 +74,7 @@ fn find_closest_match(candidates: Vec, target: &str) -> Option { }) } -/// Arguments to for a function call extracted from the SQL AST +/// Arguments for a function call extracted from the SQL AST #[derive(Debug)] struct FunctionArgs { /// Function name @@ -91,6 +91,8 @@ struct FunctionArgs { null_treatment: Option, /// DISTINCT distinct: bool, + /// WITHIN GROUP clause, if any + within_group: Vec, } impl FunctionArgs { @@ -115,6 +117,7 @@ impl FunctionArgs { filter, null_treatment, distinct: false, + within_group, }); }; @@ -144,6 +147,9 @@ impl FunctionArgs { } FunctionArgumentClause::OrderBy(oby) => { if order_by.is_some() { + if !within_group.is_empty() { + return plan_err!("ORDER BY clause is only permitted in WITHIN GROUP clause when a WITHIN GROUP is used"); + } return not_impl_err!("Calling {name}: Duplicated ORDER BY clause in function arguments"); } order_by = Some(oby); @@ -176,8 +182,10 @@ impl FunctionArgs { } } - if !within_group.is_empty() { - return not_impl_err!("WITHIN GROUP is not supported yet: {within_group:?}"); + if within_group.len() > 1 { + return not_impl_err!( + "Only a single ordering expression is permitted in a WITHIN GROUP clause" + ); } let order_by = order_by.unwrap_or_default(); @@ -190,6 +198,7 @@ impl FunctionArgs { filter, null_treatment, distinct, + within_group, }) } } @@ -210,8 +219,14 @@ impl SqlToRel<'_, S> { filter, null_treatment, distinct, + within_group, } = function_args; + if over.is_some() && !within_group.is_empty() { + return plan_err!("OVER and WITHIN GROUP clause are can not be used together. \ + OVER is for window function, whereas WITHIN GROUP is for ordered set aggregate function"); + } + // If function is a window function (it has an OVER clause), // it shouldn't have ordering requirement as function argument // required ordering should be defined in OVER clause. 
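The WITHIN GROUP support threaded through `FunctionArgs` above is consumed a little further down, where each sort key's target expression from WITHIN GROUP (ORDER BY ...) is prepended to the ordered-set aggregate's argument list before planning. A minimal standalone sketch of that rewrite, with plain strings standing in for `Sort`/`Expr` values (the function and variable names here are illustrative only, not part of the patch):

// Sketch only: mirrors the within_group-to-arguments rewrite in the hunk that
// follows, using strings in place of sort expressions and call arguments.
fn prepend_within_group_args(
    within_group_keys: &[String], // target columns from WITHIN GROUP (ORDER BY ...)
    args: Vec<String>,            // arguments written inside the call, e.g. a percentile
) -> Vec<String> {
    within_group_keys.iter().cloned().chain(args).collect()
}

fn main() {
    // e.g. percentile_cont(0.5) WITHIN GROUP (ORDER BY price)
    let rewritten =
        prepend_within_group_args(&["price".to_string()], vec!["0.5".to_string()]);
    assert_eq!(rewritten, vec!["price", "0.5"]);
    println!("{rewritten:?}");
}

Prepending (rather than appending) presumably keeps the ordering target in the leading argument position that ordered-set aggregate implementations expect.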
@@ -346,7 +361,7 @@ impl SqlToRel<'_, S> { null_treatment, } = window_expr; - return Expr::WindowFunction(expr::WindowFunction::new(func_def, args)) + return Expr::from(expr::WindowFunction::new(func_def, args)) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) @@ -356,15 +371,49 @@ impl SqlToRel<'_, S> { } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { - let order_by = self.order_by_to_sort_expr( - order_by, - schema, - planner_context, - true, - None, - )?; - let order_by = (!order_by.is_empty()).then_some(order_by); - let args = self.function_args_to_expr(args, schema, planner_context)?; + if fm.is_ordered_set_aggregate() && within_group.is_empty() { + return plan_err!("WITHIN GROUP clause is required when calling ordered set aggregate function({})", fm.name()); + } + + if null_treatment.is_some() && !fm.supports_null_handling_clause() { + return plan_err!( + "[IGNORE | RESPECT] NULLS are not permitted for {}", + fm.name() + ); + } + + let mut args = + self.function_args_to_expr(args, schema, planner_context)?; + + let order_by = if fm.is_ordered_set_aggregate() { + let within_group = self.order_by_to_sort_expr( + within_group, + schema, + planner_context, + false, + None, + )?; + + // add target column expression in within group clause to function arguments + if !within_group.is_empty() { + args = within_group + .iter() + .map(|sort| sort.expr.clone()) + .chain(args) + .collect::>(); + } + (!within_group.is_empty()).then_some(within_group) + } else { + let order_by = self.order_by_to_sort_expr( + order_by, + schema, + planner_context, + true, + None, + )?; + (!order_by.is_empty()).then_some(order_by) + }; + let filter: Option> = filter .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) .transpose()? @@ -408,17 +457,12 @@ impl SqlToRel<'_, S> { if let Some(suggested_func_name) = suggest_valid_function(&name, is_function_window, self.context_provider) { - plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?") - .map_err(|e| { - let span = Span::try_from_sqlparser_span(sql_parser_span); - let mut diagnostic = - Diagnostic::new_error(format!("Invalid function '{name}'"), span); - diagnostic.add_note( - format!("Possible function '{}'", suggested_func_name), - None, - ); - e.with_diagnostic(diagnostic) - }) + let span = Span::try_from_sqlparser_span(sql_parser_span); + let mut diagnostic = + Diagnostic::new_error(format!("Invalid function '{name}'"), span); + diagnostic + .add_note(format!("Possible function '{suggested_func_name}'"), None); + plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?"; diagnostic=diagnostic) } else { internal_err!("No functions registered with this context.") } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index d29ccdc6a7e9..e92869873731 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -215,7 +215,7 @@ impl SqlToRel<'_, S> { } SQLExpr::Extract { field, expr, .. 
} => { let mut extract_args = vec![ - Expr::Literal(ScalarValue::from(format!("{field}"))), + Expr::Literal(ScalarValue::from(format!("{field}")), None), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ]; @@ -644,7 +644,9 @@ impl SqlToRel<'_, S> { values: Vec, ) -> Result { match values.first() { - Some(SQLExpr::Identifier(_)) | Some(SQLExpr::Value(_)) => { + Some(SQLExpr::Identifier(_)) + | Some(SQLExpr::Value(_)) + | Some(SQLExpr::CompoundIdentifier(_)) => { self.parse_struct(schema, planner_context, values, vec![]) } None => not_impl_err!("Empty tuple not supported yet"), diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index cce3f3004809..d357c3753e13 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -41,13 +41,13 @@ impl SqlToRel<'_, S> { /// If false, interpret numeric literals as constant values. pub(crate) fn order_by_to_sort_expr( &self, - exprs: Vec, + order_by_exprs: Vec, input_schema: &DFSchema, planner_context: &mut PlannerContext, literal_to_column: bool, additional_schema: Option<&DFSchema>, ) -> Result> { - if exprs.is_empty() { + if order_by_exprs.is_empty() { return Ok(vec![]); } @@ -61,13 +61,23 @@ impl SqlToRel<'_, S> { None => input_schema, }; - let mut expr_vec = vec![]; - for e in exprs { + let mut sort_expr_vec = Vec::with_capacity(order_by_exprs.len()); + + let make_sort_expr = + |expr: Expr, asc: Option, nulls_first: Option| { + let asc = asc.unwrap_or(true); + // When asc is true, by default nulls last to be consistent with postgres + // postgres rule: https://www.postgresql.org/docs/current/queries-order.html + let nulls_first = nulls_first.unwrap_or(!asc); + Sort::new(expr, asc, nulls_first) + }; + + for order_by_expr in order_by_exprs { let OrderByExpr { expr, options: OrderByOptions { asc, nulls_first }, with_fill, - } = e; + } = order_by_expr; if let Some(with_fill) = with_fill { return not_impl_err!("ORDER BY WITH FILL is not supported: {with_fill}"); @@ -102,15 +112,9 @@ impl SqlToRel<'_, S> { self.sql_expr_to_logical_expr(e, order_by_schema, planner_context)? 
} }; - let asc = asc.unwrap_or(true); - expr_vec.push(Sort::new( - expr, - asc, - // When asc is true, by default nulls last to be consistent with postgres - // postgres rule: https://www.postgresql.org/docs/current/queries-order.html - nulls_first.unwrap_or(!asc), - )) + sort_expr_vec.push(make_sort_expr(expr, asc, nulls_first)); } - Ok(expr_vec) + + Ok(sort_expr_vec) } } diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index 225c5d74c2ab..602d39233d58 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -138,15 +138,9 @@ impl SqlToRel<'_, S> { if sub_plan.schema().fields().len() > 1 { let sub_schema = sub_plan.schema(); let field_names = sub_schema.field_names(); - - plan_err!("{}: {}", error_message, field_names.join(", ")).map_err(|err| { - let diagnostic = self.build_multi_column_diagnostic( - spans, - error_message, - help_message, - ); - err.with_diagnostic(diagnostic) - }) + let diagnostic = + self.build_multi_column_diagnostic(spans, error_message, help_message); + plan_err!("{}: {}", error_message, field_names.join(", "); diagnostic=diagnostic) } else { Ok(()) } diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index 59c78bc713cc..8f6e77e035c1 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -51,7 +51,7 @@ impl SqlToRel<'_, S> { (None, Some(for_expr)) => { let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - let from_logic = Expr::Literal(ScalarValue::Int64(Some(1))); + let from_logic = Expr::Literal(ScalarValue::Int64(Some(1)), None); let for_logic = self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; vec![arg, from_logic, for_logic] diff --git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index 626b79d6c3b6..e0c94543f601 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -45,16 +45,18 @@ impl SqlToRel<'_, S> { { Ok(operand) } else { - plan_err!("Unary operator '+' only supports numeric, interval and timestamp types").map_err(|e| { - let span = operand.spans().and_then(|s| s.first()); - let mut diagnostic = Diagnostic::new_error( - format!("+ cannot be used with {data_type}"), - span - ); - diagnostic.add_note("+ can only be used with numbers, intervals, and timestamps", None); - diagnostic.add_help(format!("perhaps you need to cast {operand}"), None); - e.with_diagnostic(diagnostic) - }) + let span = operand.spans().and_then(|s| s.first()); + let mut diagnostic = Diagnostic::new_error( + format!("+ cannot be used with {data_type}"), + span, + ); + diagnostic.add_note( + "+ can only be used with numbers, intervals, and timestamps", + None, + ); + diagnostic + .add_help(format!("perhaps you need to cast {operand}"), None); + plan_err!("Unary operator '+' only supports numeric, interval and timestamp types"; diagnostic=diagnostic) } } UnaryOperator::Minus => { diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index be4a45a25750..7075a1afd9dd 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -50,7 +50,7 @@ impl SqlToRel<'_, S> { match value { Value::Number(n, _) => self.parse_sql_number(&n, false), Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(lit(s)), - Value::Null => Ok(Expr::Literal(ScalarValue::Null)), + Value::Null => Ok(Expr::Literal(ScalarValue::Null, None)), Value::Boolean(n) => Ok(lit(n)), 
Value::Placeholder(param) => { Self::create_placeholder_expr(param, param_data_types) @@ -131,10 +131,7 @@ impl SqlToRel<'_, S> { // Check if the placeholder is in the parameter list let param_type = param_data_types.get(idx); // Data type of the parameter - debug!( - "type of param {} param_data_types[idx]: {:?}", - param, param_type - ); + debug!("type of param {param} param_data_types[idx]: {param_type:?}"); Ok(Expr::Placeholder(Placeholder::new( param, @@ -383,11 +380,10 @@ fn parse_decimal(unsigned_number: &str, negative: bool) -> Result { int_val ) })?; - Ok(Expr::Literal(ScalarValue::Decimal128( - Some(val), - precision as u8, - scale as i8, - ))) + Ok(Expr::Literal( + ScalarValue::Decimal128(Some(val), precision as u8, scale as i8), + None, + )) } else if precision <= DECIMAL256_MAX_PRECISION as u64 { let val = bigint_to_i256(&int_val).ok_or_else(|| { // Failures are unexpected here as we have already checked the precision @@ -396,11 +392,10 @@ fn parse_decimal(unsigned_number: &str, negative: bool) -> Result { int_val ) })?; - Ok(Expr::Literal(ScalarValue::Decimal256( - Some(val), - precision as u8, - scale as i8, - ))) + Ok(Expr::Literal( + ScalarValue::Decimal256(Some(val), precision as u8, scale as i8), + None, + )) } else { not_impl_err!( "Decimal precision {} exceeds the maximum supported precision: {}", @@ -486,10 +481,13 @@ mod tests { ]; for (input, expect) in cases { let output = parse_decimal(input, true).unwrap(); - assert_eq!(output, Expr::Literal(expect.arithmetic_negate().unwrap())); + assert_eq!( + output, + Expr::Literal(expect.arithmetic_negate().unwrap(), None) + ); let output = parse_decimal(input, false).unwrap(); - assert_eq!(output, Expr::Literal(expect)); + assert_eq!(output, Expr::Literal(expect, None)); } // scale < i8::MIN diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 822b651eae86..9731eebad167 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -20,9 +20,9 @@ //! This parser implements DataFusion specific statements such as //! `CREATE EXTERNAL TABLE` -use std::collections::VecDeque; -use std::fmt; - +use datafusion_common::config::SqlParserOptions; +use datafusion_common::DataFusionError; +use datafusion_common::{sql_err, Diagnostic, Span}; use sqlparser::ast::{ExprWithAlias, OrderByOptions}; use sqlparser::tokenizer::TokenWithSpan; use sqlparser::{ @@ -34,15 +34,22 @@ use sqlparser::{ parser::{Parser, ParserError}, tokenizer::{Token, Tokenizer, Word}, }; +use std::collections::VecDeque; +use std::fmt; // Use `Parser::expected` instead, if possible macro_rules! parser_err { - ($MSG:expr) => { - Err(ParserError::ParserError($MSG.to_string())) - }; + ($MSG:expr $(; diagnostic = $DIAG:expr)?) => {{ + + let err = DataFusionError::from(ParserError::ParserError($MSG.to_string())); + $( + let err = err.with_diagnostic($DIAG); + )? 
+ Err(err) + }}; } -fn parse_file_type(s: &str) -> Result { +fn parse_file_type(s: &str) -> Result { Ok(s.to_uppercase()) } @@ -140,7 +147,7 @@ impl fmt::Display for CopyToStatement { write!(f, "COPY {source} TO {target}")?; if let Some(file_type) = stored_as { - write!(f, " STORED AS {}", file_type)?; + write!(f, " STORED AS {file_type}")?; } if !partitioned_by.is_empty() { write!(f, " PARTITIONED BY ({})", partitioned_by.join(", "))?; @@ -266,11 +273,9 @@ impl fmt::Display for Statement { } } -fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { +fn ensure_not_set(field: &Option, name: &str) -> Result<(), DataFusionError> { if field.is_some() { - return Err(ParserError::ParserError(format!( - "{name} specified more than once", - ))); + parser_err!(format!("{name} specified more than once",))? } Ok(()) } @@ -285,6 +290,7 @@ fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { /// [`Statement`] for a list of this special syntax pub struct DFParser<'a> { pub parser: Parser<'a>, + options: SqlParserOptions, } /// Same as `sqlparser` @@ -356,21 +362,28 @@ impl<'a> DFParserBuilder<'a> { self } - pub fn build(self) -> Result, ParserError> { + pub fn build(self) -> Result, DataFusionError> { let mut tokenizer = Tokenizer::new(self.dialect, self.sql); - let tokens = tokenizer.tokenize_with_location()?; + // Convert TokenizerError -> ParserError + let tokens = tokenizer + .tokenize_with_location() + .map_err(ParserError::from)?; Ok(DFParser { parser: Parser::new(self.dialect) .with_tokens_with_locations(tokens) .with_recursion_limit(self.recursion_limit), + options: SqlParserOptions { + recursion_limit: self.recursion_limit, + ..Default::default() + }, }) } } impl<'a> DFParser<'a> { #[deprecated(since = "46.0.0", note = "DFParserBuilder")] - pub fn new(sql: &'a str) -> Result { + pub fn new(sql: &'a str) -> Result { DFParserBuilder::new(sql).build() } @@ -378,13 +391,13 @@ impl<'a> DFParser<'a> { pub fn new_with_dialect( sql: &'a str, dialect: &'a dyn Dialect, - ) -> Result { + ) -> Result { DFParserBuilder::new(sql).with_dialect(dialect).build() } /// Parse a sql string into one or [`Statement`]s using the /// [`GenericDialect`]. 
- pub fn parse_sql(sql: &'a str) -> Result, ParserError> { + pub fn parse_sql(sql: &'a str) -> Result, DataFusionError> { let mut parser = DFParserBuilder::new(sql).build()?; parser.parse_statements() @@ -395,7 +408,7 @@ impl<'a> DFParser<'a> { pub fn parse_sql_with_dialect( sql: &str, dialect: &dyn Dialect, - ) -> Result, ParserError> { + ) -> Result, DataFusionError> { let mut parser = DFParserBuilder::new(sql).with_dialect(dialect).build()?; parser.parse_statements() } @@ -403,14 +416,14 @@ impl<'a> DFParser<'a> { pub fn parse_sql_into_expr_with_dialect( sql: &str, dialect: &dyn Dialect, - ) -> Result { + ) -> Result { let mut parser = DFParserBuilder::new(sql).with_dialect(dialect).build()?; parser.parse_expr() } /// Parse a sql string into one or [`Statement`]s - pub fn parse_statements(&mut self) -> Result, ParserError> { + pub fn parse_statements(&mut self) -> Result, DataFusionError> { let mut stmts = VecDeque::new(); let mut expecting_statement_delimiter = false; loop { @@ -438,12 +451,22 @@ impl<'a> DFParser<'a> { &self, expected: &str, found: TokenWithSpan, - ) -> Result { - parser_err!(format!("Expected {expected}, found: {found}")) + ) -> Result { + let sql_parser_span = found.span; + let span = Span::try_from_sqlparser_span(sql_parser_span); + let diagnostic = Diagnostic::new_error( + format!("Expected: {expected}, found: {found}{}", found.span.start), + span, + ); + parser_err!( + format!("Expected: {expected}, found: {found}{}", found.span.start); + diagnostic= + diagnostic + ) } /// Parse a new expression - pub fn parse_statement(&mut self) -> Result { + pub fn parse_statement(&mut self) -> Result { match self.parser.peek_token().token { Token::Word(w) => { match w.keyword { @@ -455,9 +478,7 @@ impl<'a> DFParser<'a> { if let Token::Word(w) = self.parser.peek_nth_token(1).token { // use native parser for COPY INTO if w.keyword == Keyword::INTO { - return Ok(Statement::Statement(Box::from( - self.parser.parse_statement()?, - ))); + return self.parse_and_handle_statement(); } } self.parser.next_token(); // COPY @@ -469,36 +490,49 @@ impl<'a> DFParser<'a> { } _ => { // use sqlparser-rs parser - Ok(Statement::Statement(Box::from( - self.parser.parse_statement()?, - ))) + self.parse_and_handle_statement() } } } _ => { // use the native parser - Ok(Statement::Statement(Box::from( - self.parser.parse_statement()?, - ))) + self.parse_and_handle_statement() } } } - pub fn parse_expr(&mut self) -> Result { + pub fn parse_expr(&mut self) -> Result { if let Token::Word(w) = self.parser.peek_token().token { match w.keyword { Keyword::CREATE | Keyword::COPY | Keyword::EXPLAIN => { - return parser_err!("Unsupported command in expression"); + return parser_err!("Unsupported command in expression")?; } _ => {} } } - self.parser.parse_expr_with_alias() + Ok(self.parser.parse_expr_with_alias()?) 
+ } + + /// Helper method to parse a statement and handle errors consistently, especially for recursion limits + fn parse_and_handle_statement(&mut self) -> Result { + self.parser + .parse_statement() + .map(|stmt| Statement::Statement(Box::from(stmt))) + .map_err(|e| match e { + ParserError::RecursionLimitExceeded => DataFusionError::SQL( + ParserError::RecursionLimitExceeded, + Some(format!( + " (current limit: {})", + self.options.recursion_limit + )), + ), + other => DataFusionError::SQL(other, None), + }) } /// Parse a SQL `COPY TO` statement - pub fn parse_copy(&mut self) -> Result { + pub fn parse_copy(&mut self) -> Result { // parse as a query let source = if self.parser.consume_token(&Token::LParen) { let query = self.parser.parse_query()?; @@ -541,7 +575,7 @@ impl<'a> DFParser<'a> { Keyword::WITH => { self.parser.expect_keyword(Keyword::HEADER)?; self.parser.expect_keyword(Keyword::ROW)?; - return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS ('format.has_header' 'true')"); + return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS ('format.has_header' 'true')")?; } Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; @@ -561,17 +595,13 @@ impl<'a> DFParser<'a> { if token == Token::EOF || token == Token::SemiColon { break; } else { - return Err(ParserError::ParserError(format!( - "Unexpected token {token}" - ))); + return self.expected("end of statement or ;", token)?; } } } let Some(target) = builder.target else { - return Err(ParserError::ParserError( - "Missing TO clause in COPY statement".into(), - )); + return parser_err!("Missing TO clause in COPY statement")?; }; Ok(Statement::CopyTo(CopyToStatement { @@ -589,7 +619,7 @@ impl<'a> DFParser<'a> { /// because it allows keywords as well as other non words /// /// [`parse_literal_string`]: sqlparser::parser::Parser::parse_literal_string - pub fn parse_option_key(&mut self) -> Result { + pub fn parse_option_key(&mut self) -> Result { let next_token = self.parser.next_token(); match next_token.token { Token::Word(Word { value, .. }) => { @@ -602,7 +632,7 @@ impl<'a> DFParser<'a> { // Unquoted namespaced keys have to conform to the syntax // "[\.]*". If we have a key that breaks this // pattern, error out: - return self.parser.expected("key name", next_token); + return self.expected("key name", next_token); } } Ok(parts.join(".")) @@ -610,7 +640,7 @@ impl<'a> DFParser<'a> { Token::SingleQuotedString(s) => Ok(s), Token::DoubleQuotedString(s) => Ok(s), Token::EscapedStringLiteral(s) => Ok(s), - _ => self.parser.expected("key name", next_token), + _ => self.expected("key name", next_token), } } @@ -620,7 +650,7 @@ impl<'a> DFParser<'a> { /// word or keyword in this location. /// /// [`parse_value`]: sqlparser::parser::Parser::parse_value - pub fn parse_option_value(&mut self) -> Result { + pub fn parse_option_value(&mut self) -> Result { let next_token = self.parser.next_token(); match next_token.token { // e.g. 
things like "snappy" or "gzip" that may be keywords @@ -629,12 +659,12 @@ impl<'a> DFParser<'a> { Token::DoubleQuotedString(s) => Ok(Value::DoubleQuotedString(s)), Token::EscapedStringLiteral(s) => Ok(Value::EscapedStringLiteral(s)), Token::Number(n, l) => Ok(Value::Number(n, l)), - _ => self.parser.expected("string or numeric value", next_token), + _ => self.expected("string or numeric value", next_token), } } /// Parse a SQL `EXPLAIN` - pub fn parse_explain(&mut self) -> Result { + pub fn parse_explain(&mut self) -> Result { let analyze = self.parser.parse_keyword(Keyword::ANALYZE); let verbose = self.parser.parse_keyword(Keyword::VERBOSE); let format = self.parse_explain_format()?; @@ -649,7 +679,7 @@ impl<'a> DFParser<'a> { })) } - pub fn parse_explain_format(&mut self) -> Result, ParserError> { + pub fn parse_explain_format(&mut self) -> Result, DataFusionError> { if !self.parser.parse_keyword(Keyword::FORMAT) { return Ok(None); } @@ -659,15 +689,13 @@ impl<'a> DFParser<'a> { Token::Word(w) => Ok(w.value), Token::SingleQuotedString(w) => Ok(w), Token::DoubleQuotedString(w) => Ok(w), - _ => self - .parser - .expected("an explain format such as TREE", next_token), + _ => self.expected("an explain format such as TREE", next_token), }?; Ok(Some(format)) } /// Parse a SQL `CREATE` statement handling `CREATE EXTERNAL TABLE` - pub fn parse_create(&mut self) -> Result { + pub fn parse_create(&mut self) -> Result { if self.parser.parse_keyword(Keyword::EXTERNAL) { self.parse_create_external_table(false) } else if self.parser.parse_keyword(Keyword::UNBOUNDED) { @@ -678,7 +706,7 @@ impl<'a> DFParser<'a> { } } - fn parse_partitions(&mut self) -> Result, ParserError> { + fn parse_partitions(&mut self) -> Result, DataFusionError> { let mut partitions: Vec = vec![]; if !self.parser.consume_token(&Token::LParen) || self.parser.consume_token(&Token::RParen) @@ -708,7 +736,7 @@ impl<'a> DFParser<'a> { } /// Parse the ordering clause of a `CREATE EXTERNAL TABLE` SQL statement - pub fn parse_order_by_exprs(&mut self) -> Result, ParserError> { + pub fn parse_order_by_exprs(&mut self) -> Result, DataFusionError> { let mut values = vec![]; self.parser.expect_token(&Token::LParen)?; loop { @@ -721,7 +749,7 @@ impl<'a> DFParser<'a> { } /// Parse an ORDER BY sub-expression optionally followed by ASC or DESC. - pub fn parse_order_by_expr(&mut self) -> Result { + pub fn parse_order_by_expr(&mut self) -> Result { let expr = self.parser.parse_expr()?; let asc = if self.parser.parse_keyword(Keyword::ASC) { @@ -753,7 +781,7 @@ impl<'a> DFParser<'a> { // This is a copy of the equivalent implementation in sqlparser. 
fn parse_columns( &mut self, - ) -> Result<(Vec, Vec), ParserError> { + ) -> Result<(Vec, Vec), DataFusionError> { let mut columns = vec![]; let mut constraints = vec![]; if !self.parser.consume_token(&Token::LParen) @@ -789,7 +817,7 @@ impl<'a> DFParser<'a> { Ok((columns, constraints)) } - fn parse_column_def(&mut self) -> Result { + fn parse_column_def(&mut self) -> Result { let name = self.parser.parse_identifier()?; let data_type = self.parser.parse_data_type()?; let mut options = vec![]; @@ -820,7 +848,7 @@ impl<'a> DFParser<'a> { fn parse_create_external_table( &mut self, unbounded: bool, - ) -> Result { + ) -> Result { let temporary = self .parser .parse_one_of_keywords(&[Keyword::TEMP, Keyword::TEMPORARY]) @@ -868,15 +896,15 @@ impl<'a> DFParser<'a> { } else { self.parser.expect_keyword(Keyword::HEADER)?; self.parser.expect_keyword(Keyword::ROW)?; - return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS (format.has_header true)"); + return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS (format.has_header true)")?; } } Keyword::DELIMITER => { - return parser_err!("DELIMITER clause is no longer in use. Please use the OPTIONS clause with 'format.delimiter' set appropriately, e.g., OPTIONS (format.delimiter ',')"); + return parser_err!("DELIMITER clause is no longer in use. Please use the OPTIONS clause with 'format.delimiter' set appropriately, e.g., OPTIONS (format.delimiter ',')")?; } Keyword::COMPRESSION => { self.parser.expect_keyword(Keyword::TYPE)?; - return parser_err!("COMPRESSION TYPE clause is no longer in use. Please use the OPTIONS clause with 'format.compression' set appropriately, e.g., OPTIONS (format.compression gzip)"); + return parser_err!("COMPRESSION TYPE clause is no longer in use. Please use the OPTIONS clause with 'format.compression' set appropriately, e.g., OPTIONS (format.compression gzip)")?; } Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; @@ -899,7 +927,7 @@ impl<'a> DFParser<'a> { columns.extend(cols); if !cons.is_empty() { - return Err(ParserError::ParserError( + return sql_err!(ParserError::ParserError( "Constraints on Partition Columns are not supported" .to_string(), )); @@ -919,21 +947,19 @@ impl<'a> DFParser<'a> { if token == Token::EOF || token == Token::SemiColon { break; } else { - return Err(ParserError::ParserError(format!( - "Unexpected token {token}" - ))); + return self.expected("end of statement or ;", token)?; } } } // Validations: location and file_type are required if builder.file_type.is_none() { - return Err(ParserError::ParserError( + return sql_err!(ParserError::ParserError( "Missing STORED AS clause in CREATE EXTERNAL TABLE statement".into(), )); } if builder.location.is_none() { - return Err(ParserError::ParserError( + return sql_err!(ParserError::ParserError( "Missing LOCATION clause in CREATE EXTERNAL TABLE statement".into(), )); } @@ -955,7 +981,7 @@ impl<'a> DFParser<'a> { } /// Parses the set of valid formats - fn parse_file_format(&mut self) -> Result { + fn parse_file_format(&mut self) -> Result { let token = self.parser.next_token(); match &token.token { Token::Word(w) => parse_file_type(&w.value), @@ -967,7 +993,7 @@ impl<'a> DFParser<'a> { /// /// This method supports keywords as key names as well as multiple /// value types such as Numbers as well as Strings. 
- fn parse_value_options(&mut self) -> Result, ParserError> { + fn parse_value_options(&mut self) -> Result, DataFusionError> { let mut options = vec![]; self.parser.expect_token(&Token::LParen)?; @@ -999,7 +1025,7 @@ mod tests { use sqlparser::dialect::SnowflakeDialect; use sqlparser::tokenizer::Span; - fn expect_parse_ok(sql: &str, expected: Statement) -> Result<(), ParserError> { + fn expect_parse_ok(sql: &str, expected: Statement) -> Result<(), DataFusionError> { let statements = DFParser::parse_sql(sql)?; assert_eq!( statements.len(), @@ -1041,7 +1067,7 @@ mod tests { } #[test] - fn create_external_table() -> Result<(), ParserError> { + fn create_external_table() -> Result<(), DataFusionError> { // positive case let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv'"; let display = None; @@ -1262,13 +1288,13 @@ mod tests { "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int, c1) LOCATION 'foo.csv'"; expect_parse_error( sql, - "sql parser error: Expected: a data type name, found: )", + "SQL error: ParserError(\"Expected: a data type name, found: ) at Line: 1, Column: 73\")", ); // negative case: mixed column defs and column names in `PARTITIONED BY` clause let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (c1, p1 int) LOCATION 'foo.csv'"; - expect_parse_error(sql, "sql parser error: Expected ',' or ')' after partition definition, found: int"); + expect_parse_error(sql, "SQL error: ParserError(\"Expected: ',' or ')' after partition definition, found: int at Line: 1, Column: 70\")"); // positive case: additional options (one entry) can be specified let sql = @@ -1514,7 +1540,7 @@ mod tests { } #[test] - fn copy_to_table_to_table() -> Result<(), ParserError> { + fn copy_to_table_to_table() -> Result<(), DataFusionError> { // positive case let sql = "COPY foo TO bar STORED AS CSV"; let expected = Statement::CopyTo(CopyToStatement { @@ -1530,7 +1556,7 @@ mod tests { } #[test] - fn skip_copy_into_snowflake() -> Result<(), ParserError> { + fn skip_copy_into_snowflake() -> Result<(), DataFusionError> { let sql = "COPY INTO foo FROM @~/staged FILE_FORMAT = (FORMAT_NAME = 'mycsv');"; let dialect = Box::new(SnowflakeDialect); let statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?; @@ -1547,7 +1573,7 @@ mod tests { } #[test] - fn explain_copy_to_table_to_table() -> Result<(), ParserError> { + fn explain_copy_to_table_to_table() -> Result<(), DataFusionError> { let cases = vec![ ("EXPLAIN COPY foo TO bar STORED AS PARQUET", false, false), ( @@ -1588,7 +1614,7 @@ mod tests { } #[test] - fn copy_to_query_to_table() -> Result<(), ParserError> { + fn copy_to_query_to_table() -> Result<(), DataFusionError> { let statement = verified_stmt("SELECT 1"); // unwrap the various layers @@ -1621,7 +1647,7 @@ mod tests { } #[test] - fn copy_to_options() -> Result<(), ParserError> { + fn copy_to_options() -> Result<(), DataFusionError> { let sql = "COPY foo TO bar STORED AS CSV OPTIONS ('row_group_size' '55')"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), @@ -1638,7 +1664,7 @@ mod tests { } #[test] - fn copy_to_partitioned_by() -> Result<(), ParserError> { + fn copy_to_partitioned_by() -> Result<(), DataFusionError> { let sql = "COPY foo TO bar STORED AS CSV PARTITIONED BY (a) OPTIONS ('row_group_size' '55')"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), @@ -1655,7 +1681,7 @@ mod tests { } #[test] - fn copy_to_multi_options() -> Result<(), ParserError> { + fn 
copy_to_multi_options() -> Result<(), DataFusionError> { // order of options is preserved let sql = "COPY foo TO bar STORED AS parquet OPTIONS ('format.row_group_size' 55, 'format.compression' snappy, 'execution.keep_partition_by_columns' true)"; @@ -1754,7 +1780,7 @@ mod tests { assert_contains!( err.to_string(), - "sql parser error: recursion limit exceeded" + "SQL error: RecursionLimitExceeded (current limit: 1)" ); } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 3325c98aa74b..5a1f3cdf69c3 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -72,7 +72,7 @@ impl ParserOptions { parse_float_as_decimal: false, enable_ident_normalization: true, support_varchar_with_length: true, - map_varchar_to_utf8view: false, + map_varchar_to_utf8view: true, enable_options_value_normalization: false, collect_spans: false, } @@ -816,7 +816,7 @@ impl std::fmt::Display for IdentTaker { if !first { write!(f, ".")?; } - write!(f, "{}", ident)?; + write!(f, "{ident}")?; first = false; } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index ea641320c01b..f42a3ad138c4 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -22,14 +22,15 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::stack::StackGuard; use datafusion_common::{not_impl_err, Constraints, DFSchema, Result}; use datafusion_expr::expr::Sort; -use datafusion_expr::select_expr::SelectExpr; + use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Distinct, LogicalPlan, LogicalPlanBuilder, + CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ - Expr as SQLExpr, Offset as SQLOffset, OrderBy, OrderByExpr, OrderByKind, Query, - SelectInto, SetExpr, + Expr as SQLExpr, Ident, Offset as SQLOffset, OrderBy, OrderByExpr, OrderByKind, + Query, SelectInto, SetExpr, }; +use sqlparser::tokenizer::Span; impl SqlToRel<'_, S> { /// Generate a logical plan from an SQL query/subquery @@ -158,7 +159,7 @@ fn to_order_by_exprs(order_by: Option) -> Result> { /// Returns the order by expressions from the query with the select expressions. pub(crate) fn to_order_by_exprs_with_select( order_by: Option, - _select_exprs: Option<&Vec>, // TODO: ORDER BY ALL + select_exprs: Option<&Vec>, ) -> Result> { let Some(OrderBy { kind, interpolate }) = order_by else { // If no order by, return an empty array. 
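The `select_exprs` parameter threaded in above feeds the ORDER BY ALL expansion in the next hunk: every plain column in the select list becomes its own sort key carrying the options written after ALL, while non-column select expressions still fall back to `not_impl_err!`. A minimal sketch of that expansion over column names (using a simple options struct as a stand-in for the real `OrderByOptions`; the names here are illustrative only):

// Sketch only: expands ORDER BY ALL into one (column, options) pair per
// selected column, mirroring the OrderByKind::All branch below.
#[derive(Clone, Debug)]
struct SortOptions {
    asc: Option<bool>,
    nulls_first: Option<bool>,
}

fn expand_order_by_all(
    select_columns: &[&str],
    options: &SortOptions,
) -> Vec<(String, SortOptions)> {
    select_columns
        .iter()
        .map(|col| (col.to_string(), options.clone()))
        .collect()
}

fn main() {
    // SELECT a, b FROM t ORDER BY ALL DESC
    let keys = expand_order_by_all(
        &["a", "b"],
        &SortOptions { asc: Some(false), nulls_first: None },
    );
    assert_eq!(keys.len(), 2);
    assert_eq!(keys[0].0, "a");
    println!("{keys:?}");
}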
@@ -168,7 +169,30 @@ pub(crate) fn to_order_by_exprs_with_select( return not_impl_err!("ORDER BY INTERPOLATE is not supported"); } match kind { - OrderByKind::All(_) => not_impl_err!("ORDER BY ALL is not supported"), + OrderByKind::All(order_by_options) => { + let Some(exprs) = select_exprs else { + return Ok(vec![]); + }; + let order_by_exprs = exprs + .iter() + .map(|select_expr| match select_expr { + Expr::Column(column) => Ok(OrderByExpr { + expr: SQLExpr::Identifier(Ident { + value: column.name.clone(), + quote_style: None, + span: Span::empty(), + }), + options: order_by_options.clone(), + with_fill: None, + }), + // TODO: Support other types of expressions + _ => not_impl_err!( + "ORDER BY ALL is not supported for non-column expressions" + ), + }) + .collect::>>()?; + Ok(order_by_exprs) + } OrderByKind::Expressions(order_by_exprs) => Ok(order_by_exprs), } } diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index b080d211b413..1c5a8ff4d252 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -92,7 +92,7 @@ impl SqlToRel<'_, S> { .build(), (None, Err(e)) => { let e = e.with_diagnostic(Diagnostic::new_error( - format!("table '{}' not found", table_ref), + format!("table '{table_ref}' not found"), Span::try_from_sqlparser_span(relation_span), )); Err(e) diff --git a/datafusion/sql/src/resolve.rs b/datafusion/sql/src/resolve.rs index 96012a92c09a..9e909f66fa97 100644 --- a/datafusion/sql/src/resolve.rs +++ b/datafusion/sql/src/resolve.rs @@ -78,7 +78,7 @@ impl Visitor for RelationVisitor { if !with.recursive { // This is a bit hackish as the CTE will be visited again as part of visiting `q`, // but thankfully `insert_relation` is idempotent. - cte.visit(self); + let _ = cte.visit(self); } self.ctes_in_scope .push(ObjectName::from(vec![cte.alias.name.clone()])); @@ -143,7 +143,7 @@ fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) { visitor.insert_relation(table_name); } CopyToSource::Query(query) => { - query.visit(visitor); + let _ = query.visit(visitor); } }, DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor), diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 33994b60b735..9fad274b51c0 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -94,13 +94,13 @@ impl SqlToRel<'_, S> { planner_context, )?; - let order_by = - to_order_by_exprs_with_select(query_order_by, Some(&select_exprs))?; - // Having and group by clause may reference aliases defined in select projection let projected_plan = self.project(base_plan.clone(), select_exprs)?; let select_exprs = projected_plan.expressions(); + let order_by = + to_order_by_exprs_with_select(query_order_by, Some(&select_exprs))?; + // Place the fields of the base plan at the front so that when there are references // with the same name, the fields of the base plan will be searched first. 
// See https://github.com/apache/datafusion/issues/9162 @@ -885,7 +885,7 @@ impl SqlToRel<'_, S> { | SelectItem::UnnamedExpr(expr) = proj { let mut err = None; - visit_expressions_mut(expr, |expr| { + let _ = visit_expressions_mut(expr, |expr| { if let SQLExpr::Function(f) = expr { if let Some(WindowType::NamedWindow(ident)) = &f.over { let normalized_ident = diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index 272d6f874b4d..5b65e1c045bd 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -95,26 +95,22 @@ impl SqlToRel<'_, S> { if left_plan.schema().fields().len() == right_plan.schema().fields().len() { return Ok(()); } - - plan_err!("{} queries have different number of columns", op).map_err(|err| { - err.with_diagnostic( - Diagnostic::new_error( - format!("{} queries have different number of columns", op), - set_expr_span, - ) - .with_note( - format!("this side has {} fields", left_plan.schema().fields().len()), - left_span, - ) - .with_note( - format!( - "this side has {} fields", - right_plan.schema().fields().len() - ), - right_span, - ), - ) - }) + let diagnostic = Diagnostic::new_error( + format!("{op} queries have different number of columns"), + set_expr_span, + ) + .with_note( + format!("this side has {} fields", left_plan.schema().fields().len()), + left_span, + ) + .with_note( + format!( + "this side has {} fields", + right_plan.schema().fields().len() + ), + right_span, + ); + plan_err!("{} queries have different number of columns", op; diagnostic =diagnostic) } pub(super) fn set_operation_to_plan( diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 1f1c235fee6f..dafb0346485e 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -696,7 +696,7 @@ impl SqlToRel<'_, S> { statement, } => { // Convert parser data types to DataFusion data types - let data_types: Vec = data_types + let mut data_types: Vec = data_types .into_iter() .map(|t| self.convert_data_type(&t)) .collect::>()?; @@ -710,6 +710,19 @@ impl SqlToRel<'_, S> { *statement, &mut planner_context, )?; + + if data_types.is_empty() { + let map_types = plan.get_parameter_types()?; + let param_types: Vec<_> = (1..=map_types.len()) + .filter_map(|i| { + let key = format!("${i}"); + map_types.get(&key).and_then(|opt| opt.clone()) + }) + .collect(); + data_types.extend(param_types.iter().cloned()); + planner_context.with_prepare_param_data_types(param_types); + } + Ok(LogicalPlan::Statement(PlanStatement::Prepare(Prepare { name: ident_to_string(&name), data_types, @@ -1609,7 +1622,7 @@ impl SqlToRel<'_, S> { // If config does not belong to any namespace, assume it is // a format option and apply the format prefix for backwards // compatibility. - let renamed_key = format!("format.{}", key); + let renamed_key = format!("format.{key}"); options_map.insert(renamed_key.to_lowercase(), value_string); } else { options_map.insert(key.to_lowercase(), value_string); @@ -1794,7 +1807,10 @@ impl SqlToRel<'_, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; let table_source = self.context_provider.get_table_source(table_ref.clone())?; - let schema = table_source.schema().to_dfschema_ref()?; + let schema = DFSchema::try_from_qualified_schema( + table_ref.clone(), + &table_source.schema(), + )?; let scan = LogicalPlanBuilder::scan(table_ref.clone(), Arc::clone(&table_source), None)? 
.build()?; @@ -2049,7 +2065,7 @@ impl SqlToRel<'_, S> { .cloned() .unwrap_or_else(|| { // If there is no default for the column, then the default is NULL - Expr::Literal(ScalarValue::Null) + Expr::Literal(ScalarValue::Null, None) }) .cast_to(target_field.data_type(), &DFSchema::empty())?, }; diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index a9a5b325d37e..03926da5a4f6 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -219,7 +219,7 @@ impl SelectBuilder { value: &ast::Expr, ) -> &mut Self { if let Some(selection) = &mut self.selection { - visit_expressions_mut(selection, |expr| { + let _ = visit_expressions_mut(selection, |expr| { if expr == existing_expr { *expr = value.clone(); } @@ -409,6 +409,7 @@ pub struct RelationBuilder { #[allow(dead_code)] #[derive(Clone)] +#[allow(clippy::large_enum_variant)] enum TableFactorBuilder { Table(TableRelationBuilder), Derived(DerivedRelationBuilder), @@ -735,9 +736,9 @@ impl fmt::Display for BuilderError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::UninitializedField(ref field) => { - write!(f, "`{}` must be initialized", field) + write!(f, "`{field}` must be initialized") } - Self::ValidationError(ref error) => write!(f, "{}", error), + Self::ValidationError(ref error) => write!(f, "{error}"), } } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 89e9e237488a..cce14894acaf 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -187,19 +187,20 @@ impl Unparser<'_> { Expr::Cast(Cast { expr, data_type }) => { Ok(self.cast_to_sql(expr, data_type)?) } - Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), + Expr::Literal(value, _) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - .. - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + .. + }, + } = window_fun.as_ref(); let func_name = fun.name(); let args = self.function_args_to_sql(args)?; @@ -312,6 +313,7 @@ impl Unparser<'_> { distinct, args, filter, + order_by, .. } = &agg.params; @@ -320,6 +322,16 @@ impl Unparser<'_> { Some(filter) => Some(Box::new(self.expr_to_sql_inner(filter)?)), None => None, }; + let within_group = if agg.func.is_ordered_set_aggregate() { + order_by + .as_ref() + .unwrap_or(&Vec::new()) + .iter() + .map(|sort_expr| self.sort_to_sql(sort_expr)) + .collect::>>()? 
+ } else { + Vec::new() + }; Ok(ast::Expr::Function(Function { name: ObjectName::from(vec![Ident { value: func_name.to_string(), @@ -335,7 +347,7 @@ impl Unparser<'_> { filter, null_treatment: None, over: None, - within_group: vec![], + within_group, parameters: ast::FunctionArguments::None, uses_odbc_syntax: false, })) @@ -590,7 +602,7 @@ impl Unparser<'_> { .chunks_exact(2) .map(|chunk| { let key = match &chunk[0] { - Expr::Literal(ScalarValue::Utf8(Some(s))) => self.new_ident_quoted_if_needs(s.to_string()), + Expr::Literal(ScalarValue::Utf8(Some(s)), _) => self.new_ident_quoted_if_needs(s.to_string()), _ => return internal_err!("named_struct expects even arguments to be strings, but received: {:?}", &chunk[0]) }; @@ -619,7 +631,7 @@ impl Unparser<'_> { }; let field = match &args[1] { - Expr::Literal(lit) => self.new_ident_quoted_if_needs(lit.to_string()), + Expr::Literal(lit, _) => self.new_ident_quoted_if_needs(lit.to_string()), _ => { return internal_err!( "get_field expects second argument to be a string, but received: {:?}", @@ -1122,16 +1134,16 @@ impl Unparser<'_> { ScalarValue::Float16(None) => Ok(ast::Expr::value(ast::Value::Null)), ScalarValue::Float32(Some(f)) => { let f_val = match f.fract() { - 0.0 => format!("{:.1}", f), - _ => format!("{}", f), + 0.0 => format!("{f:.1}"), + _ => format!("{f}"), }; Ok(ast::Expr::value(ast::Value::Number(f_val, false))) } ScalarValue::Float32(None) => Ok(ast::Expr::value(ast::Value::Null)), ScalarValue::Float64(Some(f)) => { let f_val = match f.fract() { - 0.0 => format!("{:.1}", f), - _ => format!("{}", f), + 0.0 => format!("{f:.1}"), + _ => format!("{f}"), }; Ok(ast::Expr::value(ast::Value::Number(f_val, false))) } @@ -1899,87 +1911,87 @@ mod tests { r#"a LIKE 'foo' ESCAPE 'o'"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(0))), + Expr::Literal(ScalarValue::Date64(Some(0)), None), r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(10000))), + Expr::Literal(ScalarValue::Date64(Some(10000)), None), r#"CAST('1970-01-01 00:00:10' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(-10000))), + Expr::Literal(ScalarValue::Date64(Some(-10000)), None), r#"CAST('1969-12-31 23:59:50' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(0))), + Expr::Literal(ScalarValue::Date32(Some(0)), None), r#"CAST('1970-01-01' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(10))), + Expr::Literal(ScalarValue::Date32(Some(10)), None), r#"CAST('1970-01-11' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(-1))), + Expr::Literal(ScalarValue::Date32(Some(-1)), None), r#"CAST('1969-12-31' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None), None), r#"CAST('1970-01-01 02:46:41' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampSecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 10:46:41 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:10.001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampMillisecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:10.001 +08:00' AS TIMESTAMP)"#, 
), ( - Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:00.010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMicrosecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampMicrosecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:00.010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:00.000010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampNanosecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:00.000010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::Time32Second(Some(10001))), + Expr::Literal(ScalarValue::Time32Second(Some(10001)), None), r#"CAST('02:46:41' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time32Millisecond(Some(10001))), + Expr::Literal(ScalarValue::Time32Millisecond(Some(10001)), None), r#"CAST('00:00:10.001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Microsecond(Some(10001))), + Expr::Literal(ScalarValue::Time64Microsecond(Some(10001)), None), r#"CAST('00:00:00.010001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001))), + Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001)), None), r#"CAST('00:00:00.000010001' AS TIME)"#, ), (sum(col("a")), r#"sum(a)"#), @@ -2008,7 +2020,7 @@ mod tests { "count(*) FILTER (WHERE true)", ), ( - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), params: WindowFunctionParams { args: vec![col("col")], @@ -2022,7 +2034,7 @@ mod tests { ), ( #[expect(deprecated)] - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), params: WindowFunctionParams { args: vec![Expr::Wildcard { @@ -2124,19 +2136,17 @@ mod tests { (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#), // See test_interval_scalar_to_expr for interval literals ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal128( - Some(100123), - 28, - 3, - ))), + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal128(Some(100123), 28, 3), + None, + )), r#"((a + b) > 100.123)"#, ), ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256( - Some(100123.into()), - 28, - 3, - ))), + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal256(Some(100123.into()), 28, 3), + None, + )), r#"((a + b) > 100.123)"#, ), ( @@ -2172,28 +2182,39 @@ mod tests { "MAP {'a': 1, 'b': 2}", ), ( - Expr::Literal(ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(Some("foo".into()))), - )), + Expr::Literal( + ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Utf8(Some("foo".into()))), + ), + None, + ), "'foo'", ), ( - Expr::Literal(ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ + Expr::Literal( + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::< + Int32Type, + _, + _, + >(vec![Some(vec![ Some(1), Some(2), Some(3), - ])]), - ))), + ])]))), + None, + ), "[1, 2, 3]", ), ( - Expr::Literal(ScalarValue::LargeList(Arc::new( - LargeListArray::from_iter_primitive::(vec![Some( - 
vec![Some(1), Some(2), Some(3)], - )]), - ))), + Expr::Literal( + ScalarValue::LargeList(Arc::new( + LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + ]), + )), + None, + ), "[1, 2, 3]", ), ( @@ -2217,7 +2238,7 @@ mod tests { for (expr, expected) in tests { let ast = expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); assert_eq!(actual, expected); } @@ -2235,7 +2256,7 @@ mod tests { let expr = col("a").gt(lit(4)); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"('a' > 4)"#; assert_eq!(actual, expected); @@ -2251,7 +2272,7 @@ mod tests { let expr = col("a").gt(lit(4)); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"(a > 4)"#; assert_eq!(actual, expected); @@ -2275,7 +2296,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2300,7 +2321,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2322,7 +2343,7 @@ mod tests { let unparser = Unparser::new(&dialect); let ast = unparser.sort_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); assert_eq!(actual, expected); } @@ -2498,11 +2519,17 @@ mod tests { #[test] fn test_float_scalar_to_expr() { let tests = [ - (Expr::Literal(ScalarValue::Float64(Some(3f64))), "3.0"), - (Expr::Literal(ScalarValue::Float64(Some(3.1f64))), "3.1"), - (Expr::Literal(ScalarValue::Float32(Some(-2f32))), "-2.0"), + (Expr::Literal(ScalarValue::Float64(Some(3f64)), None), "3.0"), + ( + Expr::Literal(ScalarValue::Float64(Some(3.1f64)), None), + "3.1", + ), + ( + Expr::Literal(ScalarValue::Float32(Some(-2f32)), None), + "-2.0", + ), ( - Expr::Literal(ScalarValue::Float32(Some(-2.989f32))), + Expr::Literal(ScalarValue::Float32(Some(-2.989f32)), None), "-2.989", ), ]; @@ -2522,18 +2549,20 @@ mod tests { let tests = [ ( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "blah".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("blah".to_string())), + None, + )), data_type: DataType::Binary, }), "'blah'", ), ( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "blah".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("blah".to_string())), + None, + )), data_type: DataType::BinaryView, }), "'blah'", @@ -2572,7 +2601,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2625,10 +2654,13 @@ mod tests { let expr = ScalarUDF::new_from_impl( datafusion_functions::datetime::date_part::DatePartFunc::new(), ) - .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]); + .call(vec![ + Expr::Literal(ScalarValue::new_utf8(unit), None), + col("x"), + ]); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); assert_eq!(actual, expected); } @@ -2655,7 +2687,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let 
expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2683,7 +2715,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2722,7 +2754,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2745,13 +2777,13 @@ mod tests { (&mysql_dialect, "DATETIME"), ] { let unparser = Unparser::new(dialect); - let expr = Expr::Literal(ScalarValue::TimestampMillisecond( - Some(1738285549123), + let expr = Expr::Literal( + ScalarValue::TimestampMillisecond(Some(1738285549123), None), None, - )); + ); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST('2025-01-31 01:05:49.123' AS {identifier})"#); assert_eq!(actual, expected); @@ -2778,7 +2810,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2804,7 +2836,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = expected.to_string(); assert_eq!(actual, expected); @@ -2816,9 +2848,10 @@ mod tests { fn test_cast_value_to_dict_expr() { let tests = [( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "variation".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("variation".to_string())), + None, + )), data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), }), "'variation'", @@ -2856,12 +2889,12 @@ mod tests { expr: Box::new(col("a")), data_type: DataType::Float64, }), - Expr::Literal(ScalarValue::Int64(Some(2))), + Expr::Literal(ScalarValue::Int64(Some(2)), None), ], }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"round(CAST("a" AS {identifier}), 2)"#); assert_eq!(actual, expected); @@ -2891,7 +2924,7 @@ mod tests { let func = WindowFunctionDefinition::WindowUDF(rank_udwf()); let mut window_func = WindowFunction::new(func, vec![]); window_func.params.order_by = vec![Sort::new(col("a"), true, true)]; - let expr = Expr::WindowFunction(window_func); + let expr = Expr::from(window_func); let ast = unparser.expr_to_sql(&expr)?; let actual = ast.to_string(); @@ -2996,7 +3029,7 @@ mod tests { datafusion_functions::datetime::date_trunc::DateTruncFunc::new(), )), args: vec![ - Expr::Literal(ScalarValue::Utf8(Some(precision.to_string()))), + Expr::Literal(ScalarValue::Utf8(Some(precision.to_string())), None), col("date_col"), ], }); @@ -3041,7 +3074,7 @@ mod tests { let expr = cast(col("a"), DataType::Utf8View); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"CAST(a AS CHAR)"#.to_string(); assert_eq!(actual, expected); @@ -3049,7 +3082,7 @@ mod tests { let expr = col("a").eq(lit(ScalarValue::Utf8View(Some("hello".to_string())))); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"(a = 'hello')"#.to_string(); assert_eq!(actual, expected); @@ -3057,7 +3090,7 @@ mod tests { let expr 
= col("a").is_not_null(); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"a IS NOT NULL"#.to_string(); assert_eq!(actual, expected); @@ -3065,7 +3098,7 @@ mod tests { let expr = col("a").is_null(); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"a IS NULL"#.to_string(); assert_eq!(actual, expected); diff --git a/datafusion/sql/src/unparser/extension_unparser.rs b/datafusion/sql/src/unparser/extension_unparser.rs index f7deabe7c902..b778130ca5a2 100644 --- a/datafusion/sql/src/unparser/extension_unparser.rs +++ b/datafusion/sql/src/unparser/extension_unparser.rs @@ -64,6 +64,7 @@ pub enum UnparseWithinStatementResult { } /// The result of unparsing a custom logical node to a statement. +#[allow(clippy::large_enum_variant)] pub enum UnparseToStatementResult { /// If the custom logical node was successfully unparsed to a statement. Modified(Statement), diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index ebbf32f39523..ecb2691b4539 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -50,7 +50,7 @@ use datafusion_expr::{ UserDefinedLogicalNode, }; use sqlparser::ast::{self, Ident, OrderByKind, SetExpr, TableAliasColumnDef}; -use std::sync::Arc; +use std::{sync::Arc, vec}; /// Convert a DataFusion [`LogicalPlan`] to [`ast::Statement`] /// @@ -369,12 +369,13 @@ impl Unparser<'_> { plan: &LogicalPlan, relation: &mut RelationBuilder, lateral: bool, + columns: Vec<Ident>, ) -> Result<()> { - if self.dialect.requires_derived_table_alias() { + if self.dialect.requires_derived_table_alias() || !columns.is_empty() { self.derive( plan, relation, - Some(self.new_table_alias(alias.to_string(), vec![])), + Some(self.new_table_alias(alias.to_string(), columns)), lateral, ) } else { @@ -452,6 +453,18 @@ impl Unparser<'_> { } } + // If it's a unnest projection, we should provide the table column alias + // to provide a column name for the unnest relation. + let columns = if unnest_input_type.is_some() { + p.expr + .iter() + .map(|e| { + self.new_ident_quoted_if_needs(e.schema_name().to_string()) + }) + .collect() + } else { + vec![] + }; // Projection can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -461,6 +474,7 @@ impl Unparser<'_> { unnest_input_type .filter(|t| matches!(t, UnnestInputType::OuterReference)) .is_some(), + columns, ); } self.reconstruct_select_statement(plan, p, select)?; @@ -494,6 +508,7 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } if let Some(fetch) = &limit.fetch { @@ -532,6 +547,7 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } let Some(query_ref) = query else { @@ -603,6 +619,7 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } @@ -893,6 +910,7 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } @@ -1010,6 +1028,7 @@ impl Unparser<'_> { subquery.subquery.as_ref(), relation, true, + vec![], ) } } @@ -1032,8 +1051,7 @@ impl Unparser<'_> { if let Expr::Alias(Alias { expr, .. }) = expr { if let Expr::Column(Column { name, ..
}) = expr.as_ref() { if let Some(prefix) = name.strip_prefix(UNNEST_PLACEHOLDER) { - if prefix.starts_with(&format!("({}(", OUTER_REFERENCE_COLUMN_PREFIX)) - { + if prefix.starts_with(&format!("({OUTER_REFERENCE_COLUMN_PREFIX}(")) { return Some(UnnestInputType::OuterReference); } return Some(UnnestInputType::Scalar); @@ -1120,6 +1138,7 @@ impl Unparser<'_> { if project_vec.is_empty() { builder = builder.project(vec![Expr::Literal( ScalarValue::Int64(Some(1)), + None, )])?; } else { let project_columns = project_vec diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 7bc2bcd38a66..89fa392c183f 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -314,7 +314,6 @@ pub(crate) fn unproject_sort_expr( ))); } } - return Ok(Transformed::no(Expr::Column(col))); } Ok(Transformed::no(Expr::Column(col))) @@ -423,7 +422,7 @@ pub(crate) fn date_part_to_sql( match (style, date_part_args.len()) { (DateFieldExtractStyle::Extract, 2) => { let date_expr = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() { "year" => ast::DateTimeField::Year, "month" => ast::DateTimeField::Month, @@ -444,7 +443,7 @@ pub(crate) fn date_part_to_sql( (DateFieldExtractStyle::Strftime, 2) => { let column = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() { "year" => "%Y", "month" => "%m", @@ -532,7 +531,7 @@ pub(crate) fn sqlite_from_unixtime_to_sql( "datetime", &[ from_unixtime_args[0].clone(), - Expr::Literal(ScalarValue::Utf8(Some("unixepoch".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("unixepoch".to_string())), None), ], )?)) } @@ -555,7 +554,7 @@ pub(crate) fn sqlite_date_trunc_to_sql( ); } - if let Expr::Literal(ScalarValue::Utf8(Some(unit))) = &date_trunc_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(unit)), _) = &date_trunc_args[0] { let format = match unit.to_lowercase().as_str() { "year" => "%Y", "month" => "%Y-%m", @@ -569,7 +568,7 @@ pub(crate) fn sqlite_date_trunc_to_sql( return Ok(Some(unparser.scalar_function_to_sql( "strftime", &[ - Expr::Literal(ScalarValue::Utf8(Some(format.to_string()))), + Expr::Literal(ScalarValue::Utf8(Some(format.to_string())), None), date_trunc_args[1].clone(), ], )?)); diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index bc2a94cd44ff..52832e1324be 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -158,20 +158,19 @@ fn check_column_satisfies_expr( purpose: CheckColumnsSatisfyExprsPurpose, ) -> Result<()> { if !columns.contains(expr) { + let diagnostic = Diagnostic::new_error( + purpose.diagnostic_message(expr), + expr.spans().and_then(|spans| spans.first()), + ) + .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregare function like ANY_VALUE({expr})"), None); + return plan_err!( "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement", purpose.message_prefix(), expr, - expr_vec_fmt!(columns) - ) - .map_err(|err| { - let diagnostic = Diagnostic::new_error( - 
purpose.diagnostic_message(expr), - expr.spans().and_then(|spans| spans.first()), - ) - .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregare function like ANY_VALUE({expr})"), None); - err.with_diagnostic(diagnostic) - }); + expr_vec_fmt!(columns); + diagnostic=diagnostic + ); } Ok(()) } @@ -199,7 +198,7 @@ pub(crate) fn resolve_positions_to_exprs( match expr { // sql_expr_to_logical_expr maps number to i64 // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887 - Expr::Literal(ScalarValue::Int64(Some(position))) + Expr::Literal(ScalarValue::Int64(Some(position)), _) if position > 0_i64 && position <= select_exprs.len() as i64 => { let index = (position - 1) as usize; @@ -209,7 +208,7 @@ pub(crate) fn resolve_positions_to_exprs( _ => select_expr.clone(), }) } - Expr::Literal(ScalarValue::Int64(Some(position))) => plan_err!( + Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!( "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}", position, select_exprs.len() ), @@ -242,15 +241,21 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr let all_partition_keys = window_exprs .iter() .map(|expr| match expr { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { partition_by, .. }, - .. - }) => Ok(partition_by), - Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { - Expr::WindowFunction(WindowFunction { + Expr::WindowFunction(window_fun) => { + let WindowFunction { params: WindowFunctionParams { partition_by, .. }, .. - }) => Ok(partition_by), + } = window_fun.as_ref(); + Ok(partition_by) + } + Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + params: WindowFunctionParams { partition_by, .. }, + .. + } = window_fun.as_ref(); + Ok(partition_by) + } expr => exec_err!("Impossibly got non-window expr {expr:?}"), }, expr => exec_err!("Impossibly got non-window expr {expr:?}"), @@ -399,9 +404,9 @@ impl RecursiveUnnestRewriter<'_> { // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection // inside unnest execution, each column inside the inner projection // will be transformed into new columns. 
Thus we need to keep track of these placeholding column names - let placeholder_name = format!("{UNNEST_PLACEHOLDER}({})", inner_expr_name); + let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})"); let post_unnest_name = - format!("{UNNEST_PLACEHOLDER}({},depth={})", inner_expr_name, level); + format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})"); // This is due to the fact that unnest transformation should keep the original // column name as is, to comply with group by and order by let placeholder_column = Column::from_name(placeholder_name.clone()); @@ -681,7 +686,7 @@ mod tests { "{}=>[{}]", i.0, vec.iter() - .map(|i| format!("{}", i)) + .map(|i| format!("{i}")) .collect::>() .join(", ") ), diff --git a/datafusion/sql/tests/cases/diagnostic.rs b/datafusion/sql/tests/cases/diagnostic.rs index d08fe787948a..b3fc5dea9eff 100644 --- a/datafusion/sql/tests/cases/diagnostic.rs +++ b/datafusion/sql/tests/cases/diagnostic.rs @@ -20,16 +20,17 @@ use insta::assert_snapshot; use std::{collections::HashMap, sync::Arc}; use datafusion_common::{Diagnostic, Location, Result, Span}; -use datafusion_sql::planner::{ParserOptions, SqlToRel}; +use datafusion_sql::{ + parser::{DFParser, DFParserBuilder}, + planner::{ParserOptions, SqlToRel}, +}; use regex::Regex; -use sqlparser::{dialect::GenericDialect, parser::Parser}; use crate::{MockContextProvider, MockSessionState}; fn do_query(sql: &'static str) -> Diagnostic { - let dialect = GenericDialect {}; - let statement = Parser::new(&dialect) - .try_with_sql(sql) + let statement = DFParserBuilder::new(sql) + .build() .expect("unable to create parser") .parse_statement() .expect("unable to parse query"); @@ -41,7 +42,7 @@ fn do_query(sql: &'static str) -> Diagnostic { .with_scalar_function(Arc::new(string::concat().as_ref().clone())); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new_with_options(&context, options); - match sql_to_rel.sql_statement_to_plan(statement) { + match sql_to_rel.statement_to_plan(statement) { Ok(_) => panic!("expected error"), Err(err) => match err.diagnostic() { Some(diag) => diag.clone(), @@ -366,3 +367,25 @@ fn test_unary_op_plus_with_non_column() -> Result<()> { assert_eq!(diag.span, None); Ok(()) } + +#[test] +fn test_syntax_error() -> Result<()> { + // create a table with a column of type varchar + let query = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (c1, p1 /*int*/int/*int*/) LOCATION 'foo.csv'"; + let spans = get_spans(query); + match DFParser::parse_sql(query) { + Ok(_) => panic!("expected error"), + Err(err) => match err.diagnostic() { + Some(diag) => { + let diag = diag.clone(); + assert_snapshot!(diag.message, @"Expected: ',' or ')' after partition definition, found: int at Line: 1, Column: 77"); + println!("{spans:?}"); + assert_eq!(diag.span, Some(spans["int"])); + Ok(()) + } + None => { + panic!("expected diagnostic") + } + }, + } +} diff --git a/datafusion/sql/tests/cases/mod.rs b/datafusion/sql/tests/cases/mod.rs index b3eedcdc41e3..426d188f633c 100644 --- a/datafusion/sql/tests/cases/mod.rs +++ b/datafusion/sql/tests/cases/mod.rs @@ -17,4 +17,5 @@ mod collection; mod diagnostic; +mod params; mod plan_to_sql; diff --git a/datafusion/sql/tests/cases/params.rs b/datafusion/sql/tests/cases/params.rs new file mode 100644 index 000000000000..b3cc49c31071 --- /dev/null +++ b/datafusion/sql/tests/cases/params.rs @@ -0,0 +1,886 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan; +use arrow::datatypes::DataType; +use datafusion_common::{assert_contains, ParamValues, ScalarValue}; +use datafusion_expr::{LogicalPlan, Prepare, Statement}; +use insta::assert_snapshot; +use std::collections::HashMap; + +pub struct ParameterTest<'a> { + pub sql: &'a str, + pub expected_types: Vec<(&'a str, Option<DataType>)>, + pub param_values: Vec<ScalarValue>, +} + +impl ParameterTest<'_> { + pub fn run(&self) -> String { + let plan = logical_plan(self.sql).unwrap(); + + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types: HashMap<String, Option<DataType>> = self + .expected_types + .iter() + .map(|(k, v)| (k.to_string(), v.clone())) + .collect(); + + assert_eq!(actual_types, expected_types); + + let plan_with_params = plan + .clone() + .with_param_values(self.param_values.clone()) + .unwrap(); + + format!("** Initial Plan:\n{plan}\n** Final Plan:\n{plan_with_params}") + } +} + +fn generate_prepare_stmt_and_data_types(sql: &str) -> (LogicalPlan, String) { + let plan = logical_plan(sql).unwrap(); + let data_types = match &plan { + LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. })) => { + format!("{data_types:?}") + } + _ => panic!("Expected a Prepare statement"), + }; + (plan, data_types) +} + +#[test] +fn test_prepare_statement_to_plan_panic_param_format() { + // param is not number following the $ sign + // panic due to error returned from the parser + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo"; + + assert_snapshot!( + logical_plan(sql).unwrap_err().strip_backtrace(), + @r###" + Error during planning: Invalid placeholder, not a number: $foo + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_panic_param_zero() { + // param is zero following the $ sign + // panic due to error returned from the parser + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $0"; + + assert_snapshot!( + logical_plan(sql).unwrap_err().strip_backtrace(), + @r###" + Error during planning: Invalid placeholder, zero is not a valid index: $0 + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_panic_prepare_wrong_syntax() { + // param is not number following the $ sign + // panic due to error returned from the parser + let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo"; + assert!(logical_plan(sql) + .unwrap_err() + .strip_backtrace() + .contains("Expected: AS, found: SELECT")) +} + +#[test] +fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id + $1"; + + let plan = logical_plan(sql).unwrap_err().strip_backtrace(); + assert_snapshot!( + plan, + @r"Schema error: No field named id."
+ ); +} + +#[test] +fn test_prepare_statement_should_infer_types() { + // only provide 1 data type while using 2 params + let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1 + $2"; + let plan = logical_plan(sql).unwrap(); + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + ("$1".to_string(), Some(DataType::Int32)), + ("$2".to_string(), Some(DataType::Int64)), + ]); + assert_eq!(actual_types, expected_types); +} + +#[test] +fn test_non_prepare_statement_should_infer_types() { + // Non prepared statements (like SELECT) should also have their parameter types inferred + let sql = "SELECT 1 + $1"; + let plan = logical_plan(sql).unwrap(); + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + // constant 1 is inferred to be int64 + ("$1".to_string(), Some(DataType::Int64)), + ]); + assert_eq!(actual_types, expected_types); +} + +#[test] +#[should_panic( + expected = "Expected: [NOT] NULL | TRUE | FALSE | DISTINCT | [form] NORMALIZED FROM after IS, found: $1" +)] +fn test_prepare_statement_to_plan_panic_is_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1"; + logical_plan(sql).unwrap(); +} + +#[test] +fn test_prepare_statement_to_plan_no_param() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + " + ); + + ////////////////////////////////////////// + // no embedded parameter and no declare it + let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [] + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[]"#); + + /////////////////// + // replace params with values + let param_values: Vec<ScalarValue> = vec![]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_to_plan_one_param_no_value_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values: Vec<ScalarValue> = vec![]; + + assert_snapshot!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + @r###" + Error during planning: Expected 1 parameters, got 0 + "###); +} + +#[test] +fn test_prepare_statement_to_plan_one_param_one_value_different_type_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values =
vec![ScalarValue::Float64(Some(20.0))]; + + assert_snapshot!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + @r###" + Error during planning: Expected parameter of type Int32, got Float64 at index 0 + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_no_param_on_value_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values = vec![ScalarValue::Int32(Some(10))]; + + assert_snapshot!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + @r###" + Error during planning: Expected 0 parameters, got 1 + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_params_as_constants() { + let sql = "PREPARE my_plan(INT) AS SELECT $1"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: $1 + EmptyRelation + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: Int32(10) AS $1 + EmptyRelation + " + ); + + /////////////////////////////////////// + let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: Int64(1) + $1 + EmptyRelation + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: Int64(1) + Int32(10) AS Int64(1) + $1 + EmptyRelation + " + ); + + /////////////////////////////////////// + let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Float64] + Projection: Int64(1) + $1 + $2 + EmptyRelation + "# + ); + assert_snapshot!(dt, @r#"[Int32, Float64]"#); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::Float64(Some(10.0)), + ]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: Int64(1) + Int32(10) + Float64(10) AS Int64(1) + $1 + $2 + EmptyRelation + " + ); +} + +#[test] +fn test_infer_types_from_join() { + let test = ParameterTest { + sql: + "SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 + TableScan: person + TableScan: orders + ** Final Plan: + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) + TableScan: person + TableScan: orders + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_from_join() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, order_id FROM person JOIN 
orders ON id = customer_id and age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))] + }; + + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32] + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 + TableScan: person + TableScan: orders + ** Final Plan: + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) + TableScan: person + TableScan: orders + "# + ); +} + +#[test] +fn test_infer_types_from_predicate() { + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = Int32(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_from_predicate() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32] + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = Int32(10) + TableScan: person + "# + ); +} + +#[test] +fn test_infer_types_from_between_predicate() { + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::Int32)), + ], + param_values: vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, person.age + Filter: person.age BETWEEN $1 AND $2 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age BETWEEN Int32(10) AND Int32(30) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_from_between_predicate() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age BETWEEN $1 AND $2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::Int32)), + ], + param_values: vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))], + }; + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32, Int32] + Projection: person.id, person.age + Filter: person.age BETWEEN $1 AND $2 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age BETWEEN Int32(10) AND Int32(30) + TableScan: person + "# + ); +} + +#[test] +fn test_infer_types_subquery() { + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)", + expected_types: vec![("$1", Some(DataType::UInt32))], + param_values: vec![ScalarValue::UInt32(Some(10))] + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = $1 + TableScan: person + TableScan: 
person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = UInt32(10) + TableScan: person + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_subquery() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)", + expected_types: vec![("$1", Some(DataType::UInt32))], + param_values: vec![ScalarValue::UInt32(Some(10))] + }; + + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [UInt32] + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = $1 + TableScan: person + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = UInt32(10) + TableScan: person + TableScan: person + "# + ); +} + +#[test] +fn test_update_infer() { + let test = ParameterTest { + sql: "update person set age=$1 where id=$2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::UInt32)), + ], + param_values: vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = $2 + TableScan: person + ** Final Plan: + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = UInt32(1) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_update_infer() { + let test = ParameterTest { + sql: "PREPARE my_plan AS update person set age=$1 where id=$2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::UInt32)), + ], + param_values: vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))], + }; + + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32, UInt32] + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = $2 + TableScan: person + ** Final Plan: + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = UInt32(1) + TableScan: person + "# + ); +} + +#[test] +fn test_insert_infer() { + let test = ParameterTest { + sql: "insert into person (id, first_name, last_name) values ($1, $2, $3)", + expected_types: vec![ + ("$1", Some(DataType::UInt32)), + ("$2", Some(DataType::Utf8)), + ("$3", Some(DataType::Utf8)), + ], + param_values: vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ], + }; + + 
assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: ($1, $2, $3) + ** Final Plan: + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) + "# + ); +} + +#[test] +fn test_prepare_statement_insert_infer() { + let test = ParameterTest { + sql: "PREPARE my_plan AS insert into person (id, first_name, last_name) values ($1, $2, $3)", + expected_types: vec![ + ("$1", Some(DataType::UInt32)), + ("$2", Some(DataType::Utf8)), + ("$3", Some(DataType::Utf8)), + ], + param_values: vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ] + }; + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [UInt32, Utf8, Utf8] + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: ($1, $2, $3) + ** Final Plan: + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) + "# + ); +} + +#[test] +fn test_prepare_statement_to_plan_one_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: person.id, person.age + Filter: person.age = Int32(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_to_plan_data_type() { + let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age FROM person WHERE age = $1"; + + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + // age is defined as Int32 but prepare statement declares it as DOUBLE/Float64 + // Prepare statement and its logical plan should be created successfully + @r#" + Prepare: "my_plan" [Float64] + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Float64]"#); + + /////////////////// + // replace params with values still succeed and use Float64 + let param_values = vec![ScalarValue::Float64(Some(10.0))]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: 
person.id, person.age + Filter: person.age = Float64(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_to_plan_multi_params() { + let sql = "PREPARE my_plan(INT, STRING, DOUBLE, INT, DOUBLE, STRING) AS + SELECT id, age, $6 + FROM person + WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Utf8, Float64, Int32, Float64, Utf8] + Projection: person.id, person.age, $6 + Filter: person.age IN ([$1, $4]) AND person.salary > $3 AND person.salary < $5 OR person.first_name < $2 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32, Utf8, Float64, Int32, Float64, Utf8]"#); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::from("abc"), + ScalarValue::Float64(Some(100.0)), + ScalarValue::Int32(Some(20)), + ScalarValue::Float64(Some(200.0)), + ScalarValue::from("xyz"), + ]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r#" + Projection: person.id, person.age, Utf8("xyz") AS $6 + Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8("abc") + TableScan: person + "# + ); +} + +#[test] +fn test_prepare_statement_to_plan_having() { + let sql = "PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS + SELECT id, sum(age) + FROM person \ + WHERE salary > $2 + GROUP BY id + HAVING sum(age) < $1 AND sum(age) > 10 OR sum(age) in ($3, $4)\ + "; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Float64, Float64, Float64] + Projection: person.id, sum(person.age) + Filter: sum(person.age) < $1 AND sum(person.age) > Int64(10) OR sum(person.age) IN ([$3, $4]) + Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] + Filter: person.salary > $2 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32, Float64, Float64, Float64]"#); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::Float64(Some(100.0)), + ScalarValue::Float64(Some(200.0)), + ScalarValue::Float64(Some(300.0)), + ]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r#" + Projection: person.id, sum(person.age) + Filter: sum(person.age) < Int32(10) AND sum(person.age) > Int64(10) OR sum(person.age) IN ([Float64(200), Float64(300)]) + Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] + Filter: person.salary > Float64(100) + TableScan: person + "# + ); +} + +#[test] +fn test_prepare_statement_to_plan_limit() { + let sql = "PREPARE my_plan(BIGINT, BIGINT) AS + SELECT id FROM person \ + OFFSET $1 LIMIT $2"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int64, Int64] + Limit: skip=$1, fetch=$2 + Projection: person.id + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int64, Int64]"#); + + // replace params with values + let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r#" + Limit: skip=10, fetch=200 + Projection: person.id + TableScan: person + "# + ); +} + +#[test] +fn 
test_prepare_statement_unknown_list_param() { + let sql = "SELECT id from person where id = $2"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::List(vec![]); + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!( + err.to_string(), + "Error during planning: No value found for placeholder with id $2" + ); +} + +#[test] +fn test_prepare_statement_unknown_hash_param() { + let sql = "SELECT id from person where id = $bar"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::Map(HashMap::new()); + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!( + err.to_string(), + "Error during planning: No value found for placeholder with name $bar" + ); +} + +#[test] +fn test_prepare_statement_bad_list_idx() { + let sql = "SELECT id from person where id = $foo"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::List(vec![]); + + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!(err.to_string(), "Error during planning: Failed to parse placeholder id: invalid digit found in string"); +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index f7f264d1c557..7f74c8a557f7 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -120,19 +120,17 @@ fn roundtrip_statement() -> Result<()> { "select ta.j1_id from j1 ta where ta.j1_id > 1;", "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id);", "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);", - // Commented queries are failing since DF46 upgrade on atempt to convert unparsed ast back to plan. Initial plan to ast (sql) conversion is successful/correct. 
- // `Result::unwrap()` on an `Err` value: Collection([Internal("Not a compound identifier: [Ident { value: \"id\", quote_style: None, span: Span(Location(0,0)..Location(0,0)) }]"), Internal("Not a compound identifier: [Ident { value: \"first_name\", quote_style: None, span: Span(Location(0,0)..Location(0,0)) }]")]) - // "select * from (select id, first_name from person)", - // "select * from (select id, first_name from (select * from person))", - // "select id, count(*) as cnt from (select id from person) group by id", + "select * from (select id, first_name from person)", + "select * from (select id, first_name from (select * from person))", + "select id, count(*) as cnt from (select id from person) group by id", "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from (select (id-1) as id from person) group by id", "select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id))", r#"select "First Name" from person_quoted_cols"#, "select DISTINCT id FROM person", "select DISTINCT on (id) id, first_name from person", "select DISTINCT on (id) id, first_name from person order by id", - // r#"select id, count("First Name") as cnt from (select id, "First Name" from person_quoted_cols) group by id"#, - // "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", + r#"select id, count("First Name") as cnt from (select id, "First Name" from person_quoted_cols) group by id"#, + "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", "select id, count(*), first_name from person group by first_name, id", "select id, sum(age), first_name from person group by first_name, id", "select id, count(*), first_name @@ -147,16 +145,16 @@ fn roundtrip_statement() -> Result<()> { group by "Last Name", id having count_first_name>5 and count_first_name<10 order by count_first_name, "Last Name""#, - // r#"select p.id, count("First Name") as count_first_name, - // "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) - // from (select id, "First Name", "Last Name" from person_quoted_cols) qp - // inner join (select * from person) p - // on p.id = qp.id - // where p.id!=3 and "First Name"=='test' and qp.id in - // (select id from (select id, count(*) from person group by id having count(*) > 0)) - // group by "Last Name", p.id - // having count_first_name>5 and count_first_name<10 - // order by count_first_name, "Last Name""#, + r#"select p.id, count("First Name") as count_first_name, + "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) + from (select id, "First Name", "Last Name" from person_quoted_cols) qp + inner join (select * from person) p + on p.id = qp.id + where p.id!=3 and "First Name"=='test' and qp.id in + (select id from (select id, count(*) from person group by id having count(*) > 0)) + group by "Last Name", p.id + having count_first_name>5 and count_first_name<10 + order by count_first_name, "Last Name""#, r#"SELECT j1_string as string FROM j1 UNION ALL SELECT j2_string as string FROM j2"#, @@ -215,7 +213,7 @@ fn roundtrip_statement() -> Result<()> { "SELECT [1, 2, 3][1]", "SELECT left[1] FROM array", "SELECT {a:1, b:2}", - // SELECT s.a FROM (SELECT {a:1, b:2} AS s) + "SELECT s.a FROM (SELECT {a:1, b:2} AS s)", "SELECT MAP {'a': 1, 'b': 2}" ]; @@ -771,7 +769,7 @@ fn 
roundtrip_statement_with_dialect_27() -> Result<(), DataFusionError> { sql: "SELECT * FROM UNNEST([1,2,3])", parser_dialect: GenericDialect {}, unparser_dialect: UnparserDefaultDialect {}, - expected: @r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))" FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#, + expected: @r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))" FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS derived_projection ("UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#, ); Ok(()) } @@ -793,7 +791,7 @@ fn roundtrip_statement_with_dialect_29() -> Result<(), DataFusionError> { sql: "SELECT * FROM UNNEST([1,2,3]), j1", parser_dialect: GenericDialect {}, unparser_dialect: UnparserDefaultDialect {}, - expected: @r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))", j1.j1_id, j1.j1_string FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#, + expected: @r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))", j1.j1_id, j1.j1_string FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS derived_projection ("UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#, ); Ok(()) } @@ -1938,7 +1936,7 @@ fn test_complex_order_by_with_grouping() -> Result<()> { }, { assert_snapshot!( sql, - @r#"SELECT j1_id, j1_string, lochierarchy FROM (SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy, grouping(j1.j1_string), grouping(j1.j1_id) FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string) ORDER BY (grouping(j1.j1_id) + grouping(j1.j1_string)) DESC NULLS FIRST, CASE WHEN ((grouping(j1.j1_id) + grouping(j1.j1_string)) = 0) THEN j1.j1_id END ASC NULLS LAST) LIMIT 100"# + @r#"SELECT j1.j1_id, j1.j1_string, lochierarchy FROM (SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy, grouping(j1.j1_string), grouping(j1.j1_id) FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string) ORDER BY (grouping(j1.j1_id) + grouping(j1.j1_string)) DESC NULLS FIRST, CASE WHEN ((grouping(j1.j1_id) + grouping(j1.j1_string)) = 0) THEN j1.j1_id END ASC NULLS LAST) LIMIT 100"# ); }); diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 2804a1de0606..4be7953aefc0 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -17,21 +17,16 @@ use std::any::Any; #[cfg(test)] -use std::collections::HashMap; use std::sync::Arc; use std::vec; use arrow::datatypes::{TimeUnit::Nanosecond, *}; use common::MockContextProvider; -use datafusion_common::{ - assert_contains, DataFusionError, ParamValues, Result, ScalarValue, -}; +use datafusion_common::{assert_contains, DataFusionError, Result}; use datafusion_expr::{ - col, - logical_plan::{LogicalPlan, Prepare}, - test::function_stub::sum_udaf, - ColumnarValue, CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Statement, Volatility, + col, logical_plan::LogicalPlan, test::function_stub::sum_udaf, ColumnarValue, + CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_functions::{string, unicode}; use datafusion_sql::{ @@ -761,7 +756,7 @@ fn plan_delete() { plan, @r#" Dml: op=[Delete] table=[person] - Filter: id = Int64(1) + Filter: person.id = Int64(1) TableScan: person "# ); @@ -776,7 +771,7 @@ fn plan_delete_quoted_identifier_case_sensitive() { plan, 
@r#" Dml: op=[Delete] table=[SomeCatalog.SomeSchema.UPPERCASE_test] - Filter: Id = Int64(1) + Filter: SomeCatalog.SomeSchema.UPPERCASE_test.Id = Int64(1) TableScan: SomeCatalog.SomeSchema.UPPERCASE_test "# ); @@ -3360,7 +3355,7 @@ fn parse_decimals_parser_options() -> ParserOptions { parse_float_as_decimal: true, enable_ident_normalization: false, support_varchar_with_length: false, - map_varchar_to_utf8view: false, + map_varchar_to_utf8view: true, enable_options_value_normalization: false, collect_spans: false, } @@ -3371,7 +3366,7 @@ fn ident_normalization_parser_options_no_ident_normalization() -> ParserOptions parse_float_as_decimal: true, enable_ident_normalization: false, support_varchar_with_length: false, - map_varchar_to_utf8view: false, + map_varchar_to_utf8view: true, enable_options_value_normalization: false, collect_spans: false, } @@ -3382,23 +3377,12 @@ fn ident_normalization_parser_options_ident_normalization() -> ParserOptions { parse_float_as_decimal: true, enable_ident_normalization: true, support_varchar_with_length: false, - map_varchar_to_utf8view: false, + map_varchar_to_utf8view: true, enable_options_value_normalization: false, collect_spans: false, } } -fn generate_prepare_stmt_and_data_types(sql: &str) -> (LogicalPlan, String) { - let plan = logical_plan(sql).unwrap(); - let data_types = match &plan { - LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. })) => { - format!("{data_types:?}") - } - _ => panic!("Expected a Prepare statement"), - }; - (plan, data_types) -} - #[test] fn select_partially_qualified_column() { let sql = "SELECT person.first_name FROM public.person"; @@ -4330,712 +4314,6 @@ Projection: p1.id, p1.age, p2.id ); } -#[test] -fn test_prepare_statement_to_plan_panic_param_format() { - // param is not number following the $ sign - // panic due to error returned from the parser - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo"; - - assert_snapshot!( - logical_plan(sql).unwrap_err().strip_backtrace(), - @r###" - Error during planning: Invalid placeholder, not a number: $foo - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_panic_param_zero() { - // param is zero following the $ sign - // panic due to error returned from the parser - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $0"; - - assert_snapshot!( - logical_plan(sql).unwrap_err().strip_backtrace(), - @r###" - Error during planning: Invalid placeholder, zero is not a valid index: $0 - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_panic_prepare_wrong_syntax() { - // param is not number following the $ sign - // panic due to error returned from the parser - let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo"; - assert!(logical_plan(sql) - .unwrap_err() - .strip_backtrace() - .contains("Expected: AS, found: SELECT")) -} - -#[test] -fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() { - let sql = "PREPARE my_plan(INT) AS SELECT id + $1"; - - let plan = logical_plan(sql).unwrap_err().strip_backtrace(); - assert_snapshot!( - plan, - @r"Schema error: No field named id." 
- ); -} - -#[test] -fn test_prepare_statement_should_infer_types() { - // only provide 1 data type while using 2 params - let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1 + $2"; - let plan = logical_plan(sql).unwrap(); - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int64)), - ]); - assert_eq!(actual_types, expected_types); -} - -#[test] -fn test_non_prepare_statement_should_infer_types() { - // Non prepared statements (like SELECT) should also have their parameter types inferred - let sql = "SELECT 1 + $1"; - let plan = logical_plan(sql).unwrap(); - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - // constant 1 is inferred to be int64 - ("$1".to_string(), Some(DataType::Int64)), - ]); - assert_eq!(actual_types, expected_types); -} - -#[test] -#[should_panic( - expected = "Expected: [NOT] NULL | TRUE | FALSE | DISTINCT | [form] NORMALIZED FROM after IS, found: $1" -)] -fn test_prepare_statement_to_plan_panic_is_param() { - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1"; - logical_plan(sql).unwrap(); -} - -#[test] -fn test_prepare_statement_to_plan_no_param() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32] - Projection: person.id, person.age - Filter: person.age = Int64(10) - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Int32]"#); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int64(10) - TableScan: person - " - ); - - ////////////////////////////////////////// - // no embedded parameter and no declare it - let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [] - Projection: person.id, person.age - Filter: person.age = Int64(10) - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[]"#); - - /////////////////// - // replace params with values - let param_values: Vec = vec![]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int64(10) - TableScan: person - " - ); -} - -#[test] -fn test_prepare_statement_to_plan_one_param_no_value_panic() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; - let plan = logical_plan(sql).unwrap(); - // declare 1 param but provide 0 - let param_values: Vec = vec![]; - - assert_snapshot!( - plan.with_param_values(param_values) - .unwrap_err() - .strip_backtrace(), - @r###" - Error during planning: Expected 1 parameters, got 0 - "###); -} - -#[test] -fn test_prepare_statement_to_plan_one_param_one_value_different_type_panic() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; - let plan = logical_plan(sql).unwrap(); - // declare 1 param but provide 0 - let param_values = 
vec![ScalarValue::Float64(Some(20.0))]; - - assert_snapshot!( - plan.with_param_values(param_values) - .unwrap_err() - .strip_backtrace(), - @r###" - Error during planning: Expected parameter of type Int32, got Float64 at index 0 - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_no_param_on_value_panic() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; - let plan = logical_plan(sql).unwrap(); - // declare 1 param but provide 0 - let param_values = vec![ScalarValue::Int32(Some(10))]; - - assert_snapshot!( - plan.with_param_values(param_values) - .unwrap_err() - .strip_backtrace(), - @r###" - Error during planning: Expected 0 parameters, got 1 - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_params_as_constants() { - let sql = "PREPARE my_plan(INT) AS SELECT $1"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32] - Projection: $1 - EmptyRelation - "# - ); - assert_snapshot!(dt, @r#"[Int32]"#); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: Int32(10) AS $1 - EmptyRelation - " - ); - - /////////////////////////////////////// - let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32] - Projection: Int64(1) + $1 - EmptyRelation - "# - ); - assert_snapshot!(dt, @r#"[Int32]"#); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: Int64(1) + Int32(10) AS Int64(1) + $1 - EmptyRelation - " - ); - - /////////////////////////////////////// - let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32, Float64] - Projection: Int64(1) + $1 + $2 - EmptyRelation - "# - ); - assert_snapshot!(dt, @r#"[Int32, Float64]"#); - - /////////////////// - // replace params with values - let param_values = vec![ - ScalarValue::Int32(Some(10)), - ScalarValue::Float64(Some(10.0)), - ]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: Int64(1) + Int32(10) + Float64(10) AS Int64(1) + $1 + $2 - EmptyRelation - " - ); -} - -#[test] -fn test_infer_types_from_join() { - let sql = - "SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1"; - - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Projection: person.id, orders.order_id - Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 - TableScan: person - TableScan: orders - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, orders.order_id - Inner Join: 
Filter: person.id = orders.customer_id AND person.age = Int32(10) - TableScan: person - TableScan: orders - " - ); -} - -#[test] -fn test_infer_types_from_predicate() { - let sql = "SELECT id, age FROM person WHERE age = $1"; - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Projection: person.id, person.age - Filter: person.age = $1 - TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int32(10) - TableScan: person - " - ); -} - -#[test] -fn test_infer_types_from_between_predicate() { - let sql = "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2"; - - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Projection: person.id, person.age - Filter: person.age BETWEEN $1 AND $2 - TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age BETWEEN Int32(10) AND Int32(30) - TableScan: person - " - ); -} - -#[test] -fn test_infer_types_subquery() { - let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)"; - - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Projection: person.id, person.age - Filter: person.age = () - Subquery: - Projection: max(person.age) - Aggregate: groupBy=[[]], aggr=[[max(person.age)]] - Filter: person.id = $1 - TableScan: person - TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = () - Subquery: - Projection: max(person.age) - Aggregate: groupBy=[[]], aggr=[[max(person.age)]] - Filter: person.id = UInt32(10) - TableScan: person - TableScan: person - " - ); -} - -#[test] -fn test_update_infer() { - let sql = "update person set age=$1 where id=$2"; - - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Dml: op=[Update] table=[person] - Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 - Filter: person.id = $2 - TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::UInt32)), - ]); - 
assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Dml: op=[Update] table=[person] - Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 - Filter: person.id = UInt32(1) - TableScan: person - " - ); -} - -#[test] -fn test_insert_infer() { - let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Dml: op=[Insert Into] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 - Values: ($1, $2, $3) - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::UInt32)), - ("$2".to_string(), Some(DataType::Utf8)), - ("$3".to_string(), Some(DataType::Utf8)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ - ScalarValue::UInt32(Some(1)), - ScalarValue::from("Alan"), - ScalarValue::from("Turing"), - ]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" - Dml: op=[Insert Into] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 - Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) - "# - ); -} - -#[test] -fn test_prepare_statement_to_plan_one_param() { - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32] - Projection: person.id, person.age - Filter: person.age = $1 - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Int32]"#); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int32(10) - TableScan: person - " - ); -} - -#[test] -fn test_prepare_statement_to_plan_data_type() { - let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age FROM person WHERE age = $1"; - - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - // age is defined as Int32 but prepare statement declares it as DOUBLE/Float64 - // Prepare statement and its logical plan should be created successfully - @r#" - Prepare: "my_plan" [Float64] - Projection: person.id, person.age - Filter: person.age = $1 - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Float64]"#); - - /////////////////// - // replace params with values still succeed and use Float64 - let param_values = vec![ScalarValue::Float64(Some(10.0))]; - - let plan_with_params = plan.with_param_values(param_values).unwrap(); - 
assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Float64(10) - TableScan: person - " - ); -} - -#[test] -fn test_prepare_statement_to_plan_multi_params() { - let sql = "PREPARE my_plan(INT, STRING, DOUBLE, INT, DOUBLE, STRING) AS - SELECT id, age, $6 - FROM person - WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32, Utf8, Float64, Int32, Float64, Utf8] - Projection: person.id, person.age, $6 - Filter: person.age IN ([$1, $4]) AND person.salary > $3 AND person.salary < $5 OR person.first_name < $2 - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Int32, Utf8, Float64, Int32, Float64, Utf8]"#); - - /////////////////// - // replace params with values - let param_values = vec![ - ScalarValue::Int32(Some(10)), - ScalarValue::from("abc"), - ScalarValue::Float64(Some(100.0)), - ScalarValue::Int32(Some(20)), - ScalarValue::Float64(Some(200.0)), - ScalarValue::from("xyz"), - ]; - - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" - Projection: person.id, person.age, Utf8("xyz") AS $6 - Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8("abc") - TableScan: person - "# - ); -} - -#[test] -fn test_prepare_statement_to_plan_having() { - let sql = "PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS - SELECT id, sum(age) - FROM person \ - WHERE salary > $2 - GROUP BY id - HAVING sum(age) < $1 AND sum(age) > 10 OR sum(age) in ($3, $4)\ - "; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32, Float64, Float64, Float64] - Projection: person.id, sum(person.age) - Filter: sum(person.age) < $1 AND sum(person.age) > Int64(10) OR sum(person.age) IN ([$3, $4]) - Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] - Filter: person.salary > $2 - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Int32, Float64, Float64, Float64]"#); - - /////////////////// - // replace params with values - let param_values = vec![ - ScalarValue::Int32(Some(10)), - ScalarValue::Float64(Some(100.0)), - ScalarValue::Float64(Some(200.0)), - ScalarValue::Float64(Some(300.0)), - ]; - - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" - Projection: person.id, sum(person.age) - Filter: sum(person.age) < Int32(10) AND sum(person.age) > Int64(10) OR sum(person.age) IN ([Float64(200), Float64(300)]) - Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] - Filter: person.salary > Float64(100) - TableScan: person - "# - ); -} - -#[test] -fn test_prepare_statement_to_plan_limit() { - let sql = "PREPARE my_plan(BIGINT, BIGINT) AS - SELECT id FROM person \ - OFFSET $1 LIMIT $2"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int64, Int64] - Limit: skip=$1, fetch=$2 - Projection: person.id - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Int64, Int64]"#); - - // replace params with values - let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" - Limit: skip=10, fetch=200 - Projection: person.id - 
TableScan: person - "# - ); -} - -#[test] -fn test_prepare_statement_unknown_list_param() { - let sql = "SELECT id from person where id = $2"; - let plan = logical_plan(sql).unwrap(); - let param_values = ParamValues::List(vec![]); - let err = plan.replace_params_with_values(¶m_values).unwrap_err(); - assert_contains!( - err.to_string(), - "Error during planning: No value found for placeholder with id $2" - ); -} - -#[test] -fn test_prepare_statement_unknown_hash_param() { - let sql = "SELECT id from person where id = $bar"; - let plan = logical_plan(sql).unwrap(); - let param_values = ParamValues::Map(HashMap::new()); - let err = plan.replace_params_with_values(¶m_values).unwrap_err(); - assert_contains!( - err.to_string(), - "Error during planning: No value found for placeholder with name $bar" - ); -} - -#[test] -fn test_prepare_statement_bad_list_idx() { - let sql = "SELECT id from person where id = $foo"; - let plan = logical_plan(sql).unwrap(); - let param_values = ParamValues::List(vec![]); - - let err = plan.replace_params_with_values(¶m_values).unwrap_err(); - assert_contains!(err.to_string(), "Error during planning: Failed to parse placeholder id: invalid digit found in string"); -} - #[test] fn test_inner_join_with_cast_key() { let sql = "SELECT person.id, person.age diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 16cd3d5b3aa4..c90e8daefe49 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -42,8 +42,10 @@ async-trait = { workspace = true } bigdecimal = { workspace = true } bytes = { workspace = true, optional = true } chrono = { workspace = true, optional = true } -clap = { version = "4.5.35", features = ["derive", "env"] } +clap = { version = "4.5.39", features = ["derive", "env"] } datafusion = { workspace = true, default-features = true, features = ["avro"] } +datafusion-spark = { workspace = true, default-features = true } +datafusion-substrait = { path = "../substrait" } futures = { workspace = true } half = { workspace = true, default-features = true } indicatif = "0.17" @@ -55,11 +57,11 @@ postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"], rust_decimal = { version = "1.37.1", features = ["tokio-pg"] } # When updating the following dependency verify that sqlite test file regeneration works correctly # by running the regenerate_sqlite_files.sh script. -sqllogictest = "0.28.0" +sqllogictest = "0.28.2" sqlparser = { workspace = true } tempfile = { workspace = true } -testcontainers = { version = "0.23", features = ["default"], optional = true } -testcontainers-modules = { version = "0.11", features = ["postgres"], optional = true } +testcontainers = { version = "0.24", features = ["default"], optional = true } +testcontainers-modules = { version = "0.12", features = ["postgres"], optional = true } thiserror = "2.0.12" tokio = { workspace = true } tokio-postgres = { version = "0.7.12", optional = true } diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md index 77162f4001ae..3fdb29c9d5cd 100644 --- a/datafusion/sqllogictest/README.md +++ b/datafusion/sqllogictest/README.md @@ -156,6 +156,14 @@ sqllogictests also supports `cargo test` style substring matches on file names t cargo test --test sqllogictests -- information ``` +Additionally, executing specific tests within a file is also supported. 
Tests are identified by line number within +the .slt file; for example, the following command will run the test at line `709` of file `information.slt`, along +with any other preparatory statements: + +```shell +cargo test --test sqllogictests -- information:709 +``` + ## Running tests: Postgres compatibility Test files that start with prefix `pg_compat_` verify compatibility @@ -283,6 +291,27 @@ Tests that need to write temporary files should write (only) to this directory to ensure they do not interfere with others concurrently running tests. +## Running tests: Substrait round-trip mode + +This mode runs all the .slt test files in validation mode, adding a Substrait conversion round-trip for each +generated DataFusion logical plan (SQL statement → DF logical → Substrait → DF logical → DF physical → execute). + +Not all statements are round-tripped: statements such as CREATE, INSERT, SET or EXPLAIN are +issued as-is, while any other statement is round-tripped to/from Substrait. + +_WARNING_: because there are still many failures in this mode (https://github.com/apache/datafusion/issues/16248), +it is not enforced in CI; instead, it must be run manually with the following command: + +```shell +cargo test --test sqllogictests -- --substrait-round-trip +``` + +To focus on one specific failing test, a file:line filter can be used: + +```shell +cargo test --test sqllogictests -- --substrait-round-trip binary.slt:23 +``` + ## `.slt` file format [`sqllogictest`] was originally written for SQLite to verify the diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 21dfe2ee08f4..d5fce1a7cdb2 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -20,8 +20,9 @@ use datafusion::common::instant::Instant; use datafusion::common::utils::get_available_parallelism; use datafusion::common::{exec_err, DataFusionError, Result}; use datafusion_sqllogictest::{ - df_value_validator, read_dir_recursive, setup_scratch_dir, value_normalizer, - DataFusion, TestContext, + df_value_validator, read_dir_recursive, setup_scratch_dir, should_skip_file, + should_skip_record, value_normalizer, DataFusion, DataFusionSubstraitRoundTrip, + Filter, TestContext, }; use futures::stream::StreamExt; use indicatif::{ @@ -31,8 +32,8 @@ use itertools::Itertools; use log::Level::Info; use log::{info, log_enabled}; use sqllogictest::{ - parse_file, strict_column_validator, AsyncDB, Condition, Normalizer, Record, - Validator, + parse_file, strict_column_validator, AsyncDB, Condition, MakeConnection, Normalizer, + Record, Validator, }; #[cfg(feature = "postgres")] @@ -50,6 +51,7 @@ const TEST_DIRECTORY: &str = "test_files/"; const DATAFUSION_TESTING_TEST_DIRECTORY: &str = "../../datafusion-testing/data/"; const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; const SQLITE_PREFIX: &str = "sqlite"; +const ERRS_PER_FILE_LIMIT: usize = 10; pub fn main() -> Result<()> { tokio::runtime::Builder::new_multi_thread() @@ -101,6 +103,7 @@ async fn run_tests() -> Result<()> { // to stdout and return OK so they can continue listing other tests.
return Ok(()); } + options.warn_on_ignored(); #[cfg(feature = "postgres")] @@ -134,27 +137,49 @@ async fn run_tests() -> Result<()> { let m_clone = m.clone(); let m_style_clone = m_style.clone(); + let filters = options.filters.clone(); SpawnedTask::spawn(async move { - match (options.postgres_runner, options.complete) { - (false, false) => { - run_test_file(test_file, validator, m_clone, m_style_clone) - .await? + match ( + options.postgres_runner, + options.complete, + options.substrait_round_trip, + ) { + (_, _, true) => { + run_test_file_substrait_round_trip( + test_file, + validator, + m_clone, + m_style_clone, + filters.as_ref(), + ) + .await? } - (false, true) => { + (false, false, _) => { + run_test_file( + test_file, + validator, + m_clone, + m_style_clone, + filters.as_ref(), + ) + .await? + } + (false, true, _) => { run_complete_file(test_file, validator, m_clone, m_style_clone) .await? } - (true, false) => { + (true, false, _) => { run_test_file_with_postgres( test_file, validator, m_clone, m_style_clone, + filters.as_ref(), ) .await? } - (true, true) => { + (true, true, _) => { run_complete_file_with_postgres( test_file, validator, @@ -201,11 +226,51 @@ async fn run_tests() -> Result<()> { } } +async fn run_test_file_substrait_round_trip( + test_file: TestFile, + validator: Validator, + mp: MultiProgress, + mp_style: ProgressStyle, + filters: &[Filter], +) -> Result<()> { + let TestFile { + path, + relative_path, + } = test_file; + let Some(test_ctx) = TestContext::try_new_for_test_file(&relative_path).await else { + info!("Skipping: {}", path.display()); + return Ok(()); + }; + setup_scratch_dir(&relative_path)?; + + let count: u64 = get_record_count(&path, "DatafusionSubstraitRoundTrip".to_string()); + let pb = mp.add(ProgressBar::new(count)); + + pb.set_style(mp_style); + pb.set_message(format!("{:?}", &relative_path)); + + let mut runner = sqllogictest::Runner::new(|| async { + Ok(DataFusionSubstraitRoundTrip::new( + test_ctx.session_ctx().clone(), + relative_path.clone(), + pb.clone(), + )) + }); + runner.add_label("DatafusionSubstraitRoundTrip"); + runner.with_column_validator(strict_column_validator); + runner.with_normalizer(value_normalizer); + runner.with_validator(validator); + let res = run_file_in_runner(path, runner, filters).await; + pb.finish_and_clear(); + res +} + async fn run_test_file( test_file: TestFile, validator: Validator, mp: MultiProgress, mp_style: ProgressStyle, + filters: &[Filter], ) -> Result<()> { let TestFile { path, @@ -234,15 +299,49 @@ async fn run_test_file( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); + let result = run_file_in_runner(path, runner, filters).await; + pb.finish_and_clear(); + result +} - let res = runner - .run_file_async(path) - .await - .map_err(|e| DataFusionError::External(Box::new(e))); +async fn run_file_in_runner>( + path: PathBuf, + mut runner: sqllogictest::Runner, + filters: &[Filter], +) -> Result<()> { + let path = path.canonicalize()?; + let records = + parse_file(&path).map_err(|e| DataFusionError::External(Box::new(e)))?; + let mut errs = vec![]; + for record in records.into_iter() { + if let Record::Halt { .. 
} = record { + break; + } + if should_skip_record::(&record, filters) { + continue; + } + if let Err(err) = runner.run_async(record).await { + errs.push(format!("{err}")); + } + } - pb.finish_and_clear(); + if !errs.is_empty() { + let mut msg = format!("{} errors in file {}\n\n", errs.len(), path.display()); + for (i, err) in errs.iter().enumerate() { + if i >= ERRS_PER_FILE_LIMIT { + msg.push_str(&format!( + "... other {} errors in {} not shown ...\n\n", + errs.len() - ERRS_PER_FILE_LIMIT, + path.display() + )); + break; + } + msg.push_str(&format!("{}. {err}\n\n", i + 1)); + } + return Err(DataFusionError::External(msg.into())); + } - res + Ok(()) } fn get_record_count(path: &PathBuf, label: String) -> u64 { @@ -287,6 +386,7 @@ async fn run_test_file_with_postgres( validator: Validator, mp: MultiProgress, mp_style: ProgressStyle, + filters: &[Filter], ) -> Result<()> { use datafusion_sqllogictest::Postgres; let TestFile { @@ -308,14 +408,9 @@ async fn run_test_file_with_postgres( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); - runner - .run_file_async(path) - .await - .map_err(|e| DataFusionError::External(Box::new(e)))?; - + let result = run_file_in_runner(path, runner, filters).await; pb.finish_and_clear(); - - Ok(()) + result } #[cfg(not(feature = "postgres"))] @@ -324,6 +419,7 @@ async fn run_test_file_with_postgres( _validator: Validator, _mp: MultiProgress, _mp_style: ProgressStyle, + _filters: &[Filter], ) -> Result<()> { use datafusion::common::plan_err; plan_err!("Can not run with postgres as postgres feature is not enabled") @@ -537,14 +633,25 @@ struct Options { )] postgres_runner: bool, + #[clap( + long, + conflicts_with = "complete", + conflicts_with = "postgres_runner", + help = "Before executing each query, convert its logical plan to Substrait and from Substrait back to its logical plan" + )] + substrait_round_trip: bool, + #[clap(long, env = "INCLUDE_SQLITE", help = "Include sqlite files")] include_sqlite: bool, #[clap(long, env = "INCLUDE_TPCH", help = "Include tpch files")] include_tpch: bool, - #[clap(action, help = "test filter (substring match on filenames)")] - filters: Vec, + #[clap( + action, + help = "test filter (substring match on filenames with optional :{line_number} suffix)" + )] + filters: Vec, #[clap( long, @@ -597,15 +704,7 @@ impl Options { /// filter and that does a substring match on each input. 
returns /// true f this path should be run fn check_test_file(&self, path: &Path) -> bool { - if self.filters.is_empty() { - return true; - } - - // otherwise check if any filter matches - let path_string = path.to_string_lossy(); - self.filters - .iter() - .any(|filter| path_string.contains(filter)) + !should_skip_file(path, &self.filters) } /// Postgres runner executes only tests in files with specific names or in diff --git a/datafusion/sqllogictest/src/engines/conversion.rs b/datafusion/sqllogictest/src/engines/conversion.rs index 516ec69e0b07..92ab64059bbd 100644 --- a/datafusion/sqllogictest/src/engines/conversion.rs +++ b/datafusion/sqllogictest/src/engines/conversion.rs @@ -49,7 +49,7 @@ pub(crate) fn f16_to_str(value: f16) -> String { } else if value == f16::NEG_INFINITY { "-Infinity".to_string() } else { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) } } @@ -63,7 +63,7 @@ pub(crate) fn f32_to_str(value: f32) -> String { } else if value == f32::NEG_INFINITY { "-Infinity".to_string() } else { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) } } @@ -77,7 +77,21 @@ pub(crate) fn f64_to_str(value: f64) -> String { } else if value == f64::NEG_INFINITY { "-Infinity".to_string() } else { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) + } +} + +pub(crate) fn spark_f64_to_str(value: f64) -> String { + if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. + "NaN".to_string() + } else if value == f64::INFINITY { + "Infinity".to_string() + } else if value == f64::NEG_INFINITY { + "-Infinity".to_string() + } else { + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), Some(15)) } } @@ -86,6 +100,7 @@ pub(crate) fn decimal_128_to_str(value: i128, scale: i8) -> String { big_decimal_to_str( BigDecimal::from_str(&Decimal128Type::format_decimal(value, precision, scale)) .unwrap(), + None, ) } @@ -94,17 +109,21 @@ pub(crate) fn decimal_256_to_str(value: i256, scale: i8) -> String { big_decimal_to_str( BigDecimal::from_str(&Decimal256Type::format_decimal(value, precision, scale)) .unwrap(), + None, ) } #[cfg(feature = "postgres")] pub(crate) fn decimal_to_str(value: Decimal) -> String { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) } -pub(crate) fn big_decimal_to_str(value: BigDecimal) -> String { +/// Converts a `BigDecimal` to its plain string representation, optionally rounding to a specified number of decimal places. +/// +/// If `round_digits` is `None`, the value is rounded to 12 decimal places by default. +pub(crate) fn big_decimal_to_str(value: BigDecimal, round_digits: Option) -> String { // Round the value to limit the number of decimal places - let value = value.round(12).normalized(); + let value = value.round(round_digits.unwrap_or(12)).normalized(); // Format the value to a string value.to_plain_string() } @@ -115,12 +134,12 @@ mod tests { use bigdecimal::{num_bigint::BigInt, BigDecimal}; macro_rules! 
assert_decimal_str_eq { - ($integer:expr, $scale:expr, $expected:expr) => { + ($integer:expr, $scale:expr, $round_digits:expr, $expected:expr) => { assert_eq!( - big_decimal_to_str(BigDecimal::from_bigint( - BigInt::from($integer), - $scale - )), + big_decimal_to_str( + BigDecimal::from_bigint(BigInt::from($integer), $scale), + $round_digits + ), $expected ); }; @@ -128,44 +147,51 @@ mod tests { #[test] fn test_big_decimal_to_str() { - assert_decimal_str_eq!(110, 3, "0.11"); - assert_decimal_str_eq!(11, 3, "0.011"); - assert_decimal_str_eq!(11, 2, "0.11"); - assert_decimal_str_eq!(11, 1, "1.1"); - assert_decimal_str_eq!(11, 0, "11"); - assert_decimal_str_eq!(11, -1, "110"); - assert_decimal_str_eq!(0, 0, "0"); + assert_decimal_str_eq!(110, 3, None, "0.11"); + assert_decimal_str_eq!(11, 3, None, "0.011"); + assert_decimal_str_eq!(11, 2, None, "0.11"); + assert_decimal_str_eq!(11, 1, None, "1.1"); + assert_decimal_str_eq!(11, 0, None, "11"); + assert_decimal_str_eq!(11, -1, None, "110"); + assert_decimal_str_eq!(0, 0, None, "0"); assert_decimal_str_eq!( 12345678901234567890123456789012345678_i128, 0, + None, "12345678901234567890123456789012345678" ); assert_decimal_str_eq!( 12345678901234567890123456789012345678_i128, 38, + None, "0.123456789012" ); // Negative cases - assert_decimal_str_eq!(-110, 3, "-0.11"); - assert_decimal_str_eq!(-11, 3, "-0.011"); - assert_decimal_str_eq!(-11, 2, "-0.11"); - assert_decimal_str_eq!(-11, 1, "-1.1"); - assert_decimal_str_eq!(-11, 0, "-11"); - assert_decimal_str_eq!(-11, -1, "-110"); + assert_decimal_str_eq!(-110, 3, None, "-0.11"); + assert_decimal_str_eq!(-11, 3, None, "-0.011"); + assert_decimal_str_eq!(-11, 2, None, "-0.11"); + assert_decimal_str_eq!(-11, 1, None, "-1.1"); + assert_decimal_str_eq!(-11, 0, None, "-11"); + assert_decimal_str_eq!(-11, -1, None, "-110"); assert_decimal_str_eq!( -12345678901234567890123456789012345678_i128, 0, + None, "-12345678901234567890123456789012345678" ); assert_decimal_str_eq!( -12345678901234567890123456789012345678_i128, 38, + None, "-0.123456789012" ); // Round to 12 decimal places // 1.0000000000011 -> 1.000000000001 - assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, "1.000000000001"); + assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, None, "1.000000000001"); + assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, Some(12), "1.000000000001"); + + assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, Some(13), "1.0000000000011"); } } diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs index eeb34186ea20..0d832bb3062d 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs @@ -22,13 +22,16 @@ use arrow::array::{Array, AsArray}; use arrow::datatypes::Fields; use arrow::util::display::ArrayFormatter; use arrow::{array, array::ArrayRef, datatypes::DataType, record_batch::RecordBatch}; -use datafusion::common::format::DEFAULT_CLI_FORMAT_OPTIONS; use datafusion::common::DataFusionError; +use datafusion::config::ConfigField; use std::path::PathBuf; use std::sync::LazyLock; /// Converts `batches` to a result as expected by sqllogictest. -pub fn convert_batches(batches: Vec) -> Result>> { +pub fn convert_batches( + batches: Vec, + is_spark_path: bool, +) -> Result>> { if batches.is_empty() { Ok(vec![]) } else { @@ -46,7 +49,16 @@ pub fn convert_batches(batches: Vec) -> Result>> { ))); } - let new_rows = convert_batch(batch)? 
+ // Convert a single batch to a `Vec>` for comparison, flatten expanded rows, and normalize each. + let new_rows = (0..batch.num_rows()) + .map(|row| { + batch + .columns() + .iter() + .map(|col| cell_to_string(col, row, is_spark_path)) + .collect::>>() + }) + .collect::>>>()? .into_iter() .flat_map(expand_row) .map(normalize_paths); @@ -162,19 +174,6 @@ static WORKSPACE_ROOT: LazyLock = LazyLock::new(|| { object_store::path::Path::parse(sanitized_workplace_root).unwrap() }); -/// Convert a single batch to a `Vec>` for comparison -fn convert_batch(batch: RecordBatch) -> Result>> { - (0..batch.num_rows()) - .map(|row| { - batch - .columns() - .iter() - .map(|col| cell_to_string(col, row)) - .collect::>>() - }) - .collect() -} - macro_rules! get_row_value { ($array_type:ty, $column: ident, $row: ident) => {{ let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); @@ -193,7 +192,7 @@ macro_rules! get_row_value { /// /// Floating numbers are rounded to have a consistent representation with the Postgres runner. /// -pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { +pub fn cell_to_string(col: &ArrayRef, row: usize, is_spark_path: bool) -> Result { if !col.is_valid(row) { // represent any null value with the string "NULL" Ok(NULL_STR.to_string()) @@ -210,7 +209,12 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { Ok(f32_to_str(get_row_value!(array::Float32Array, col, row))) } DataType::Float64 => { - Ok(f64_to_str(get_row_value!(array::Float64Array, col, row))) + let result = get_row_value!(array::Float64Array, col, row); + if is_spark_path { + Ok(spark_f64_to_str(result)) + } else { + Ok(f64_to_str(result)) + } } DataType::Decimal128(_, scale) => { let value = get_row_value!(array::Decimal128Array, col, row); @@ -236,12 +240,20 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { DataType::Dictionary(_, _) => { let dict = col.as_any_dictionary(); let key = dict.normalized_keys()[row]; - Ok(cell_to_string(dict.values(), key)?) + Ok(cell_to_string(dict.values(), key, is_spark_path)?) 
} _ => { - let f = - ArrayFormatter::try_new(col.as_ref(), &DEFAULT_CLI_FORMAT_OPTIONS); - Ok(f.unwrap().value(row).to_string()) + let mut datafusion_format_options = + datafusion::config::FormatOptions::default(); + + datafusion_format_options.set("null", "NULL").unwrap(); + + let arrow_format_options: arrow::util::display::FormatOptions = + (&datafusion_format_options).try_into().unwrap(); + + let f = ArrayFormatter::try_new(col.as_ref(), &arrow_format_options)?; + + Ok(f.value(row).to_string()) } } .map_err(DFSqlLogicTestError::Arrow) @@ -280,7 +292,9 @@ pub fn convert_schema_to_types(columns: &Fields) -> Vec { if key_type.is_integer() { // mapping dictionary string types to Text match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => DFColumnType::Text, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + DFColumnType::Text + } _ => DFColumnType::Another, } } else { diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs index a3a29eda2ee9..a01ac7e2f985 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs @@ -31,6 +31,7 @@ use sqllogictest::DBOutput; use tokio::time::Instant; use crate::engines::output::{DFColumnType, DFOutput}; +use crate::is_spark_path; pub struct DataFusion { ctx: SessionContext, @@ -79,7 +80,7 @@ impl sqllogictest::AsyncDB for DataFusion { } let start = Instant::now(); - let result = run_query(&self.ctx, sql).await; + let result = run_query(&self.ctx, is_spark_path(&self.relative_path), sql).await; let duration = start.elapsed(); if duration.gt(&Duration::from_millis(500)) { @@ -115,7 +116,11 @@ impl sqllogictest::AsyncDB for DataFusion { async fn shutdown(&mut self) {} } -async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result { +async fn run_query( + ctx: &SessionContext, + is_spark_path: bool, + sql: impl Into, +) -> Result { let df = ctx.sql(sql.into().as_str()).await?; let task_ctx = Arc::new(df.task_ctx()); let plan = df.create_physical_plan().await?; @@ -123,7 +128,7 @@ async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result = collect(stream).await?; - let rows = normalize::convert_batches(results)?; + let rows = normalize::convert_batches(results, is_spark_path)?; if rows.is_empty() && types.is_empty() { Ok(DBOutput::StatementComplete(0)) diff --git a/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/mod.rs b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/mod.rs new file mode 100644 index 000000000000..9ff077c67d8c --- /dev/null +++ b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +mod runner; + +pub use runner::*; diff --git a/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/runner.rs b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/runner.rs new file mode 100644 index 000000000000..9d3854755352 --- /dev/null +++ b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/runner.rs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::{path::PathBuf, time::Duration}; + +use crate::engines::datafusion_engine::Result; +use crate::engines::output::{DFColumnType, DFOutput}; +use crate::{convert_batches, convert_schema_to_types, DFSqlLogicTestError}; +use arrow::record_batch::RecordBatch; +use async_trait::async_trait; +use datafusion::logical_expr::LogicalPlan; +use datafusion::physical_plan::common::collect; +use datafusion::physical_plan::execute_stream; +use datafusion::prelude::SessionContext; +use datafusion_substrait::logical_plan::consumer::from_substrait_plan; +use datafusion_substrait::logical_plan::producer::to_substrait_plan; +use indicatif::ProgressBar; +use log::Level::{Debug, Info}; +use log::{debug, log_enabled, warn}; +use sqllogictest::DBOutput; +use tokio::time::Instant; + +pub struct DataFusionSubstraitRoundTrip { + ctx: SessionContext, + relative_path: PathBuf, + pb: ProgressBar, +} + +impl DataFusionSubstraitRoundTrip { + pub fn new(ctx: SessionContext, relative_path: PathBuf, pb: ProgressBar) -> Self { + Self { + ctx, + relative_path, + pb, + } + } + + fn update_slow_count(&self) { + let msg = self.pb.message(); + let split: Vec<&str> = msg.split(" ").collect(); + let mut current_count = 0; + + if split.len() > 2 { + // third match will be current slow count + current_count = split[2].parse::().unwrap(); + } + + current_count += 1; + + self.pb + .set_message(format!("{} - {} took > 500 ms", split[0], current_count)); + } +} + +#[async_trait] +impl sqllogictest::AsyncDB for DataFusionSubstraitRoundTrip { + type Error = DFSqlLogicTestError; + type ColumnType = DFColumnType; + + async fn run(&mut self, sql: &str) -> Result { + if log_enabled!(Debug) { + debug!( + "[{}] Running query: \"{}\"", + self.relative_path.display(), + sql + ); + } + + let start = Instant::now(); + let result = run_query_substrait_round_trip(&self.ctx, sql).await; + let duration = start.elapsed(); + + if duration.gt(&Duration::from_millis(500)) { + self.update_slow_count(); + } + + self.pb.inc(1); + + if log_enabled!(Info) && duration.gt(&Duration::from_secs(2)) { + warn!( + "[{}] Running query took more than 2 sec ({duration:?}): \"{sql}\"", + self.relative_path.display() + ); + } + + result + } + + /// Engine 
name of current database. + fn engine_name(&self) -> &str { + "DataFusionSubstraitRoundTrip" + } + + /// `DataFusion` calls this function to perform sleep. + /// + /// The default implementation is `std::thread::sleep`, which is universal to any async runtime + /// but would block the current thread. If you are running in tokio runtime, you should override + /// this by `tokio::time::sleep`. + async fn sleep(dur: Duration) { + tokio::time::sleep(dur).await; + } + + async fn shutdown(&mut self) {} +} + +async fn run_query_substrait_round_trip( + ctx: &SessionContext, + sql: impl Into, +) -> Result { + let df = ctx.sql(sql.into().as_str()).await?; + let task_ctx = Arc::new(df.task_ctx()); + + let state = ctx.state(); + let round_tripped_plan = match df.logical_plan() { + // Substrait does not handle these plans + LogicalPlan::Ddl(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Copy(_) + | LogicalPlan::Statement(_) => df.logical_plan().clone(), + // For any other plan, convert to Substrait + logical_plan => { + let plan = to_substrait_plan(logical_plan, &state)?; + from_substrait_plan(&state, &plan).await? + } + }; + + let physical_plan = state.create_physical_plan(&round_tripped_plan).await?; + let stream = execute_stream(physical_plan, task_ctx)?; + let types = convert_schema_to_types(stream.schema().fields()); + let results: Vec = collect(stream).await?; + let rows = convert_batches(results, false)?; + + if rows.is_empty() && types.is_empty() { + Ok(DBOutput::StatementComplete(0)) + } else { + Ok(DBOutput::Rows { types, rows }) + } +} diff --git a/datafusion/sqllogictest/src/engines/mod.rs b/datafusion/sqllogictest/src/engines/mod.rs index 3569dea70176..ef6335ddbed6 100644 --- a/datafusion/sqllogictest/src/engines/mod.rs +++ b/datafusion/sqllogictest/src/engines/mod.rs @@ -18,12 +18,14 @@ /// Implementation of sqllogictest for datafusion. mod conversion; mod datafusion_engine; +mod datafusion_substrait_roundtrip_engine; mod output; pub use datafusion_engine::convert_batches; pub use datafusion_engine::convert_schema_to_types; pub use datafusion_engine::DFSqlLogicTestError; pub use datafusion_engine::DataFusion; +pub use datafusion_substrait_roundtrip_engine::DataFusionSubstraitRoundTrip; pub use output::DFColumnType; pub use output::DFOutput; diff --git a/datafusion/sqllogictest/src/filters.rs b/datafusion/sqllogictest/src/filters.rs new file mode 100644 index 000000000000..44482236f7c5 --- /dev/null +++ b/datafusion/sqllogictest/src/filters.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
See the License for the +// specific language governing permissions and limitations +// under the License.
+ +use datafusion::sql::parser::{DFParserBuilder, Statement}; +use sqllogictest::{AsyncDB, Record}; +use sqlparser::ast::{SetExpr, Statement as SqlStatement}; +use sqlparser::dialect::dialect_from_str; +use std::path::Path; +use std::str::FromStr; + +/// Filter specification that determines whether a certain sqllogictest record in +/// a certain file should be filtered. In order for a [`Filter`] to match a test case: +/// +/// - The test must belong to a file whose absolute path contains the `file_substring` substring. +/// - If a `line_number` is specified, the test must be declared in that same line number. +/// +/// If a [`Filter`] matches a specific test case, then the record is executed, if there's +/// no match, the record is skipped. +/// +/// Filters can be parsed from strings of the form `:line_number`. For example, +/// `foo.slt:100` matches any test whose name contains `foo.slt` and the test starts on line +/// number 100. +#[derive(Debug, Clone)] +pub struct Filter { + file_substring: String, + line_number: Option, +} + +impl FromStr for Filter { + type Err = String; + + fn from_str(s: &str) -> Result { + let parts: Vec<&str> = s.rsplitn(2, ':').collect(); + if parts.len() == 2 { + match parts[0].parse::() { + Ok(line) => Ok(Filter { + file_substring: parts[1].to_string(), + line_number: Some(line), + }), + Err(_) => Err(format!("Cannot parse line number from '{s}'")), + } + } else { + Ok(Filter { + file_substring: s.to_string(), + line_number: None, + }) + } + } +} + +/// Given a list of [`Filter`]s, determines if the whole file in the provided +/// path can be skipped. +/// +/// - If there's at least 1 filter whose file name is a substring of the provided path, +/// it returns true. +/// - If the provided filter list is empty, it returns false. +pub fn should_skip_file(path: &Path, filters: &[Filter]) -> bool { + if filters.is_empty() { + return false; + } + + let path_string = path.to_string_lossy(); + for filter in filters { + if path_string.contains(&filter.file_substring) { + return false; + } + } + true +} + +/// Determines whether a certain sqllogictest record should be skipped given the provided +/// filters. +/// +/// If there's at least 1 matching filter, or the filter list is empty, it returns false. +/// +/// There are certain records that will never be skipped even if they are not matched +/// by any filters, like CREATE TABLE, INSERT INTO, DROP or SELECT * INTO statements, +/// as they populate tables necessary for other tests to work. +pub fn should_skip_record( + record: &Record, + filters: &[Filter], +) -> bool { + if filters.is_empty() { + return false; + } + + let (sql, loc) = match record { + Record::Statement { sql, loc, .. } => (sql, loc), + Record::Query { sql, loc, .. } => (sql, loc), + _ => return false, + }; + + let statement = if let Some(statement) = parse_or_none(sql, "Postgres") { + statement + } else if let Some(statement) = parse_or_none(sql, "generic") { + statement + } else { + return false; + }; + + if !statement_is_skippable(&statement) { + return false; + } + + for filter in filters { + if !loc.file().contains(&filter.file_substring) { + continue; + } + if let Some(line_num) = filter.line_number { + if loc.line() != line_num { + continue; + } + } + + // This filter matches both file name substring and the exact + // line number (if one was provided), so don't skip it. + return false; + } + + true +} + +fn statement_is_skippable(statement: &Statement) -> bool { + // Only SQL statements can be skipped. 
+ let Statement::Statement(sql_stmt) = statement else { + return false; + }; + + // Cannot skip SELECT INTO statements, as they can also create tables + // that further test cases will use. + if let SqlStatement::Query(v) = sql_stmt.as_ref() { + if let SetExpr::Select(v) = v.body.as_ref() { + if v.into.is_some() { + return false; + } + } + } + + // Only SELECT and EXPLAIN statements can be skipped, as any other + // statement might be populating tables that future test cases will use. + matches!( + sql_stmt.as_ref(), + SqlStatement::Query(_) | SqlStatement::Explain { .. } + ) +} + +fn parse_or_none(sql: &str, dialect: &str) -> Option { + let Ok(Ok(Some(statement))) = DFParserBuilder::new(sql) + .with_dialect(dialect_from_str(dialect).unwrap().as_ref()) + .build() + .map(|mut v| v.parse_statements().map(|mut v| v.pop_front())) + else { + return None; + }; + Some(statement) +} diff --git a/datafusion/sqllogictest/src/lib.rs b/datafusion/sqllogictest/src/lib.rs index 1a208aa3fac2..3c786d6bdaac 100644 --- a/datafusion/sqllogictest/src/lib.rs +++ b/datafusion/sqllogictest/src/lib.rs @@ -34,12 +34,15 @@ pub use engines::DFColumnType; pub use engines::DFOutput; pub use engines::DFSqlLogicTestError; pub use engines::DataFusion; +pub use engines::DataFusionSubstraitRoundTrip; #[cfg(feature = "postgres")] pub use engines::Postgres; +mod filters; mod test_context; mod util; +pub use filters::*; pub use test_context::TestContext; pub use util::*; diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index ce819f186454..143e3ef1a89b 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -40,8 +40,11 @@ use datafusion::{ prelude::{CsvReadOptions, SessionContext}, }; +use crate::is_spark_path; use async_trait::async_trait; use datafusion::common::cast::as_float64_array; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::SessionStateBuilder; use log::info; use tempfile::TempDir; @@ -70,8 +73,20 @@ impl TestContext { let config = SessionConfig::new() // hardcode target partitions so plans are deterministic .with_target_partitions(4); + let runtime = Arc::new(RuntimeEnv::default()); + let mut state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); + + if is_spark_path(relative_path) { + info!("Registering Spark functions"); + datafusion_spark::register_all(&mut state) + .expect("Can not register Spark functions"); + } - let mut test_ctx = TestContext::new(SessionContext::new_with_config(config)); + let mut test_ctx = TestContext::new(SessionContext::new_with_state(state)); let file_name = relative_path.file_name().unwrap().to_str().unwrap(); match file_name { @@ -122,6 +137,7 @@ impl TestContext { info!("Using default SessionContext"); } }; + Some(test_ctx) } @@ -223,14 +239,14 @@ pub async fn register_temp_table(ctx: &SessionContext) { self } - fn table_type(&self) -> TableType { - self.0 - } - fn schema(&self) -> SchemaRef { unimplemented!() } + fn table_type(&self) -> TableType { + self.0 + } + async fn scan( &self, _state: &dyn Session, @@ -410,10 +426,24 @@ fn create_example_udf() -> ScalarUDF { fn register_union_table(ctx: &SessionContext) { let union = UnionArray::try_new( - UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]), - ScalarBuffer::from(vec![3, 3]), + UnionFields::new( + // typeids: 3 for int, 1 for string + vec![3, 1], + vec![ + Field::new("int", DataType::Int32, 
false), + Field::new("string", DataType::Utf8, false), + ], + ), + ScalarBuffer::from(vec![3, 1, 3]), None, - vec![Arc::new(Int32Array::from(vec![1, 2]))], + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![ + Some("foo"), + Some("bar"), + Some("baz"), + ])), + ], ) .unwrap(); diff --git a/datafusion/sqllogictest/src/util.rs b/datafusion/sqllogictest/src/util.rs index 5ae640cc98a9..695fe463fa67 100644 --- a/datafusion/sqllogictest/src/util.rs +++ b/datafusion/sqllogictest/src/util.rs @@ -106,3 +106,7 @@ pub fn df_value_validator( normalized_actual == normalized_expected } + +pub fn is_spark_path(relative_path: &Path) -> bool { + relative_path.starts_with("spark/") +} diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a55ac079aa74..ed77435d6a85 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -132,37 +132,51 @@ statement error DataFusion error: Schema error: Schema contains duplicate unqual SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_weight -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function: coercion from \[Utf8, Int8, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function +SELECT approx_percentile_cont_with_weight(c2, 0.95) WITHIN GROUP (ORDER BY c1) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function: coercion from \[Int16, Utf8, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function +SELECT approx_percentile_cont_with_weight(c1, 0.95) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function: coercion from \[Int16, Int8, Utf8\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function +SELECT approx_percentile_cont_with_weight(c2, c1) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_histogram_bins statement error DataFusion error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\)\. 
-SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont(0.95, -1000) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function: coercion from \[Int16, Float64, Utf8\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function +SELECT approx_percentile_cont(0.95, c1) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function: coercion from \[Int16, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100 +SELECT approx_percentile_cont(0.95, 111.1) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function: coercion from \[Float64, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100 +SELECT approx_percentile_cont(0.95, 111.1) WITHIN GROUP (ORDER BY c12) FROM aggregate_test_100 statement error DataFusion error: This feature is not implemented: Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal -SELECT approx_percentile_cont(c12, c12) FROM aggregate_test_100 +SELECT approx_percentile_cont(c12) WITHIN GROUP (ORDER BY c12) FROM aggregate_test_100 statement error DataFusion error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal -SELECT approx_percentile_cont(c12, 0.95, c5) FROM aggregate_test_100 +SELECT approx_percentile_cont(0.95, c5) WITHIN GROUP (ORDER BY c12) FROM aggregate_test_100 + +statement error DataFusion error: This feature is not implemented: Conflicting ordering requirements in aggregate functions is not supported +SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5), approx_percentile_cont(0.2) WITHIN GROUP (ORDER BY c12) FROM aggregate_test_100 + +statement error DataFusion error: Error during planning: \[IGNORE | RESPECT\] NULLS are not permitted for approx_percentile_cont +SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5) IGNORE NULLS FROM aggregate_test_100 + +statement error DataFusion error: Error during planning: \[IGNORE | RESPECT\] NULLS are not permitted for approx_percentile_cont +SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5) RESPECT NULLS FROM aggregate_test_100 + +statement error DataFusion error: This feature is not implemented: Only a single ordering expression is permitted in a WITHIN GROUP clause +SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5, c12) FROM aggregate_test_100 # Not supported over sliding windows -query error This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented -SELECT approx_percentile_cont(c3, 0.5) OVER (ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) +query error DataFusion error: Error during planning: OVER and WITHIN GROUP clause are can not be used together. 
OVER is for window function, whereas WITHIN GROUP is for ordered set aggregate function +SELECT approx_percentile_cont(0.5) +WITHIN GROUP (ORDER BY c3) +OVER (ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) FROM aggregate_test_100 # array agg can use order by @@ -289,17 +303,19 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES ('b', [1,0]), ('b', [1,0]), ('b', [1,0]), - ('b', [0,1]) + ('b', [0,1]), + (NULL, [0,1]), + ('b', NULL) ; # Apply array_sort to have deterministic result, higher dimension nested array also works but not for array sort, # so they are covered in `datafusion/functions-aggregate/src/array_agg.rs` query ?? select array_sort(c1), array_sort(c2) from ( - select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table + select array_agg(distinct column1) as c1, array_agg(distinct column2) ignore nulls as c2 from array_agg_distinct_list_table ); ---- -[b, w] [[0, 1], [1, 0]] +[NULL, b, w] [[0, 1], [1, 0]] statement ok drop table array_agg_distinct_list_table; @@ -1276,173 +1292,173 @@ SELECT approx_distinct(c9) AS a, approx_distinct(c9) AS b FROM aggregate_test_10 #csv_query_approx_percentile_cont (c2) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.1) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.5) AS DOUBLE) / 3.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 3.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.9) AS DOUBLE) / 5.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 5.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c3) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.1) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.5) AS DOUBLE) / 15.5) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / 15.5) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.9) AS DOUBLE) / 102.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / 102.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c4) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.1) AS DOUBLE) / -22925.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c4) AS DOUBLE) / -22925.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.5) AS DOUBLE) / 4599.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c4) AS DOUBLE) / 4599.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.9) AS DOUBLE) / 25334.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY 
c4) AS DOUBLE) / 25334.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c5) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.1) AS DOUBLE) / -1882606710.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c5) AS DOUBLE) / -1882606710.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.5) AS DOUBLE) / 377164262.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c5) AS DOUBLE) / 377164262.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.9) AS DOUBLE) / 1991374996.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c5) AS DOUBLE) / 1991374996.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c6) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.1) AS DOUBLE) / -7250000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c6) AS DOUBLE) / -7250000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.5) AS DOUBLE) / 1130000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c6) AS DOUBLE) / 1130000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.9) AS DOUBLE) / 7370000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c6) AS DOUBLE) / 7370000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c7) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.1) AS DOUBLE) / 18.9) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c7) AS DOUBLE) / 18.9) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.5) AS DOUBLE) / 134.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c7) AS DOUBLE) / 134.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.9) AS DOUBLE) / 231.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c7) AS DOUBLE) / 231.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c8) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.1) AS DOUBLE) / 2671.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c8) AS DOUBLE) / 2671.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.5) AS DOUBLE) / 30634.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c8) AS DOUBLE) / 30634.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.9) AS DOUBLE) / 57518.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c8) AS DOUBLE) / 57518.0) < 0.05) AS q FROM aggregate_test_100 ---- true # 
csv_query_approx_percentile_cont (c9) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.1) AS DOUBLE) / 472608672.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c9) AS DOUBLE) / 472608672.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.5) AS DOUBLE) / 2365817608.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c9) AS DOUBLE) / 2365817608.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.9) AS DOUBLE) / 3776538487.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c9) AS DOUBLE) / 3776538487.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c10) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.1) AS DOUBLE) / 1830000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c10) AS DOUBLE) / 1830000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.5) AS DOUBLE) / 9300000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c10) AS DOUBLE) / 9300000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.9) AS DOUBLE) / 16100000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c10) AS DOUBLE) / 16100000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c11) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.1) AS DOUBLE) / 0.109) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c11) AS DOUBLE) / 0.109) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.5) AS DOUBLE) / 0.491) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c11) AS DOUBLE) / 0.491) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.9) AS DOUBLE) / 0.834) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c11) AS DOUBLE) / 0.834) < 0.05) AS q FROM aggregate_test_100 ---- true # percentile_cont_with_nulls query I -SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (1), (2), (3), (NULL), (NULL), (NULL)) as t (v); +SELECT APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY v) FROM (VALUES (1), (2), (3), (NULL), (NULL), (NULL)) as t (v); ---- 2 # percentile_cont_with_nulls_only query I -SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (CAST(NULL as INT))) as t (v); +SELECT APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY v) FROM (VALUES (CAST(NULL as INT))) as t (v); ---- NULL @@ -1465,7 +1481,7 @@ NaN # ISSUE: https://github.com/apache/datafusion/issues/11870 query R -select APPROX_PERCENTILE_CONT(v2, 0.8) from tmp_percentile_cont; +select APPROX_PERCENTILE_CONT(0.8) WITHIN GROUP (ORDER BY v2) from tmp_percentile_cont; ---- NaN @@ -1473,10 +1489,10 @@ NaN # Note: `approx_percentile_cont_with_weight()` uses the same implementation as `approx_percentile_cont()` query R SELECT 
APPROX_PERCENTILE_CONT_WITH_WEIGHT( - v2, '+Inf'::Double, 0.9 ) +WITHIN GROUP (ORDER BY v2) FROM tmp_percentile_cont; ---- NaN @@ -1495,7 +1511,7 @@ INSERT INTO t1 VALUES (TRUE); # ISSUE: https://github.com/apache/datafusion/issues/12716 # This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' and returns 'inf' query R -SELECT approx_percentile_cont_with_weight('NaN'::DOUBLE, 0, 0) FROM t1 WHERE t1.v1; +SELECT approx_percentile_cont_with_weight(0, 0) WITHIN GROUP (ORDER BY 'NaN'::DOUBLE) FROM t1 WHERE t1.v1; ---- Infinity @@ -1722,7 +1738,7 @@ b NULL NULL 7732.315789473684 # csv_query_approx_percentile_cont_with_weight query TI -SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 73 b 68 @@ -1730,9 +1746,18 @@ c 122 d 124 e 115 +query TI +SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a -101 +b -114 +c -109 +d -98 +e -93 + # csv_query_approx_percentile_cont_with_weight (2) query TI -SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont_with_weight(1, 0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 73 b 68 @@ -1740,9 +1765,18 @@ c 122 d 124 e 115 +query TI +SELECT c1, approx_percentile_cont_with_weight(1, 0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a -101 +b -114 +c -109 +d -98 +e -93 + # csv_query_approx_percentile_cont_with_histogram_bins query TI -SELECT c1, approx_percentile_cont(c3, 0.95, 200) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont(0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 73 b 68 @@ -1751,7 +1785,7 @@ d 124 e 115 query TI -SELECT c1, approx_percentile_cont_with_weight(c3, c2, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont_with_weight(c2, 0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 74 b 68 @@ -3041,7 +3075,7 @@ SELECT COUNT(DISTINCT c1) FROM test # test_approx_percentile_cont_decimal_support query TI -SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont(cast(0.85 as decimal(10,2))) WITHIN GROUP (ORDER BY c2) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 4 b 5 @@ -3194,6 +3228,33 @@ select array_agg(column1) from t; statement ok drop table t; +# array_agg_ignore_nulls +statement ok +create table t as values (NULL, ''), (1, 'c'), (2, 'a'), (NULL, 'b'), (4, NULL), (NULL, NULL), (5, 'a'); + +query ? +select array_agg(column1) ignore nulls as c1 from t; +---- +[1, 2, 4, 5] + +query II +select count(*), array_length(array_agg(distinct column2) ignore nulls) from t; +---- +7 4 + +query ? +select array_agg(column2 order by column1) ignore nulls from t; +---- +[c, a, a, , b] + +query ? 
+select array_agg(DISTINCT column2 order by column2) ignore nulls from t; +---- +[, a, b, c] + +statement ok +drop table t; + # variance_single_value query RRRR select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq; @@ -4952,23 +5013,114 @@ set datafusion.sql_parser.dialect = 'Generic'; ## Multiple distinct aggregates and dictionaries statement ok -create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)')); +create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)')), (1, arrow_cast('bar', 'Dictionary(Int32, Utf8)')); query IT -select * from dict_test; +select * from dict_test order by column1, column2; ---- +1 bar +1 foo 1 foo 2 bar query II -select count(distinct column1), count(distinct column2) from dict_test group by column1; +select count(distinct column1), count(distinct column2) from dict_test group by column1 order by column1; ---- -1 1 +1 2 1 1 statement ok drop table dict_test; +# avg_duration + +statement ok +create table d as values + (arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1), + (arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1); + +query ???? +SELECT avg(column1), avg(column2), avg(column3), avg(column4) FROM d; +---- +0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs + +query ????I +SELECT avg(column1), avg(column2), avg(column3), avg(column4), column5 FROM d GROUP BY column5; +---- +0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs 1 + +statement ok +drop table d; + +statement ok +create table d as values + (arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1), + (arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1), + (arrow_cast(5, 'Duration(Second)'), arrow_cast(10, 'Duration(Millisecond)'), arrow_cast(15, 'Duration(Microsecond)'), arrow_cast(20, 'Duration(Nanosecond)'), 2), + (arrow_cast(25, 'Duration(Second)'), arrow_cast(50, 'Duration(Millisecond)'), arrow_cast(75, 'Duration(Microsecond)'), arrow_cast(100, 'Duration(Nanosecond)'), 2), + (NULL, NULL, NULL, NULL, 1), + (NULL, NULL, NULL, NULL, 2); + + +query I? rowsort +SELECT column5, avg(column1) FROM d GROUP BY column5; +---- +1 0 days 0 hours 0 mins 6 secs +2 0 days 0 hours 0 mins 15 secs + +query I?? rowsort +SELECT column5, column1, avg(column1) OVER (PARTITION BY column5 ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) as window_avg +FROM d WHERE column1 IS NOT NULL; +---- +1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 1 secs +1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 6 secs +2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 15 secs +2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 5 secs + +# Cumulative average window function +query I?? 
+SELECT column5, column1, avg(column1) OVER (ORDER BY column5, column1 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as cumulative_avg +FROM d WHERE column1 IS NOT NULL; +---- +1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 1 secs +1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 6 secs +2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 5 secs +2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 10 secs + +# Centered average window function +query I?? +SELECT column5, column1, avg(column1) OVER (ORDER BY column5 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as centered_avg +FROM d WHERE column1 IS NOT NULL; +---- +1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 6 secs +1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 5 secs +2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 13 secs +2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 15 secs + +statement ok +drop table d; + +statement ok +create table dn as values + (arrow_cast(10, 'Duration(Second)'), 'a', 1), + (arrow_cast(20, 'Duration(Second)'), 'a', 2), + (NULL, 'b', 1), + (arrow_cast(40, 'Duration(Second)'), 'b', 2), + (arrow_cast(50, 'Duration(Second)'), 'c', 1), + (NULL, 'c', 2); + +query T?I +SELECT column2, avg(column1), column3 FROM dn GROUP BY column2, column3 ORDER BY column2, column3; +---- +a 0 days 0 hours 0 mins 10 secs 1 +a 0 days 0 hours 0 mins 20 secs 2 +b NULL 1 +b 0 days 0 hours 0 mins 40 secs 2 +c 0 days 0 hours 0 mins 50 secs 1 +c NULL 2 + +statement ok +drop table dn; # Prepare the table with dictionary values for testing statement ok @@ -5159,13 +5311,13 @@ physical_plan 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], file_type=csv, has_header=true query I -SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 order by c3 limit 5; ---- -1 --40 -29 --85 --82 +-117 +-111 +-107 +-106 +-101 query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; @@ -5183,13 +5335,13 @@ physical_plan 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true query II -SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; +SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 order by c2, c3 limit 5 offset 4; ---- -5 -82 -4 -111 -3 104 -3 13 -1 38 +1 -56 +1 -25 +1 -24 +1 -8 +1 -5 # The limit should only apply to the aggregations which group by c3 query TT @@ -5218,12 +5370,12 @@ physical_plan 13)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true query I -SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c2, c3 limit 4; +SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c3 order by c3 limit 4; ---- -13 -17 12 +13 14 +17 # An aggregate expression causes the limit to not be pushed to the aggregation query TT @@ -5268,11 +5420,11 @@ physical_plan 11)--------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true query II -SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c2, c3 limit 3 offset 10; +SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c3, c2 order by c3, c2 limit 3 offset 10; ---- -57 1 --54 4 -112 3 +-95 
3 +-94 5 +-90 4 query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; @@ -6758,7 +6910,7 @@ group1 0.0003 # median with all nulls statement ok create table group_median_all_nulls( - a STRING NOT NULL, + a STRING NOT NULL, b INT ) AS VALUES ( 'group0', NULL), @@ -6796,3 +6948,84 @@ select c2, count(*) from test WHERE 1 = 1 group by c2; 5 1 6 1 +# Min/Max struct +query ?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c FROM t) +---- +{a: 1, b: 2} {a: 10, b: 11} + +# Min/Max struct with NULL +query ?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 2 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c FROM t) +---- +{a: 2, b: 3} {a: 10, b: 11} + +# Min/Max struct with two recordbatch +query ?? rowsort +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(1 as 'a', 2 as 'b') AS c UNION SELECT STRUCT(3 as 'a', 4 as 'b') AS c ) +---- +{a: 1, b: 2} {a: 3, b: 4} + +# Min/Max struct empty +query ?? rowsort +SELECT MIN(c), MAX(c) FROM (SELECT * FROM (SELECT STRUCT(1 as 'a', 2 as 'b') AS c) LIMIT 0) +---- +NULL NULL + +# Min/Max group struct +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 2, b: 3} {a: 10, b: 11} +1 {a: 1, b: 2} {a: 9, b: 10} + +# Min/Max group struct with NULL +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 2 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 2, b: 3} {a: 10, b: 11} +1 NULL NULL + +# Min/Max group struct with NULL +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 3 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 6, b: 7} {a: 6, b: 7} +1 {a: 3, b: 4} {a: 9, b: 10} + +# Min/Max struct empty +query ?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c, (c1 % 2) AS key FROM t LIMIT 0) GROUP BY key +---- + +# Min/Max aggregation on struct with a single field +query ?? +WITH t AS (SELECT i as c1 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a') AS c FROM t); +---- +{a: 1} {a: 10} + +# Min/Max aggregation on struct with identical first fields but different last fields +query ?? +SELECT MIN(column1),MAX(column1) FROM ( +VALUES + (STRUCT(1 AS 'a',2 AS 'b', 3 AS 'c')), + (STRUCT(1 AS 'a',2 AS 'b', 4 AS 'c')) +); +---- +{a: 1, b: 2, c: 3} {a: 1, b: 2, c: 4} + +query TI +SELECT column1, COUNT(DISTINCT column2) FROM ( +VALUES + ('x', arrow_cast('NAN','Float64')), + ('x', arrow_cast('NAN','Float64')) +) GROUP BY 1 ORDER BY 1; +---- +x 1 diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index f165d3bf66ba..ac96daed0d44 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1204,7 +1204,7 @@ select array_element([1, 2], NULL); ---- NULL -query I +query ? 
select array_element(NULL, 2); ---- NULL @@ -1435,6 +1435,12 @@ NULL 23 NULL 43 5 NULL +# array_element of empty array +query T +select coalesce(array_element([], 1), array_element(NULL, 1), 'ok'); +---- +ok + ## array_max # array_max scalar function #1 (with positive index) @@ -1448,7 +1454,7 @@ select array_max(make_array(5, 3, 4, NULL, 6, NULL)); ---- 6 -query I +query ? select array_max(make_array(NULL, NULL)); ---- NULL @@ -1512,7 +1518,7 @@ select array_max(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), ar ---- 3 1 -query I +query ? select array_max(make_array()); ---- NULL @@ -2177,7 +2183,7 @@ select array_any_value(1), array_any_value('a'), array_any_value(NULL); # array_any_value scalar function #1 (with null and non-null elements) -query ITII +query IT?I select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(make_array(NULL, 'h', 'e', 'l', 'l', 'o')), array_any_value(make_array(NULL, NULL)), array_any_value(make_array(NULL, NULL, 1, 2, 3)); ---- 1 h NULL 1 @@ -2348,6 +2354,11 @@ NULL [NULL, 51, 52, 54, 55, 56, 57, 58, 59, 60] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] +# test with empty table +query ? +select array_sort(column1, 'DESC', 'NULLS FIRST') from arrays_values where false; +---- + # test with empty array query ? select array_sort([]); @@ -2435,11 +2446,15 @@ select array_append(null, 1); ---- [1] -query error +query ? select array_append(null, [2, 3]); +---- +[[2, 3]] -query error +query ? select array_append(null, [[4]]); +---- +[[[4]]] query ???? select @@ -2716,8 +2731,10 @@ select array_prepend(null, [[1,2,3]]); # DuckDB: [[]] # ClickHouse: [[]] # TODO: We may also return [[]] -query error +query ? select array_prepend([], []); +---- +[[]] query ? select array_prepend(null, null); @@ -3080,22 +3097,25 @@ select array_concat( ---- [1, 2, 3] -# Concatenating Mixed types (doesn't work) -query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: LargeUtf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +# Concatenating Mixed types +query ? select array_concat( [arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], [arrow_cast('3', 'LargeUtf8')] ); +---- +[1, 2, 3] -# Concatenating Mixed types (doesn't work) -query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) -select array_concat( - [arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], - [arrow_cast('3', 'Utf8View')] -); +# Concatenating Mixed types +query ?T +select + array_concat([arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], [arrow_cast('3', 'Utf8View')]), + arrow_typeof(array_concat([arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], [arrow_cast('3', 'Utf8View')])); +---- +[1, 2, 3] List(Field { name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) # array_concat error -query error DataFusion error: Error during planning: The array_concat function can only accept list as the args\. 
+query error DataFusion error: Error during planning: Execution error: Function 'array_concat' user-defined coercion failed with "Error during planning: array_concat does not support type Int64" select array_concat(1, 2); # array_concat scalar function #1 @@ -3406,15 +3426,11 @@ SELECT array_position(arrow_cast([1, 1, 100, 1, 1], 'LargeList(Int32)'), 100) ---- 3 -query I +query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'array_position' function: coercion from SELECT array_position([1, 2, 3], 'foo') ----- -NULL -query I +query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'array_position' function: coercion from SELECT array_position([1, 2, 3], 'foo', 2) ----- -NULL # list_position scalar function #5 (function alias `array_position`) query III @@ -4376,7 +4392,8 @@ select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, statement ok CREATE TABLE arrays_with_repeating_elements_for_union AS VALUES - ([1], [2]), + ([0, 1, 1], []), + ([1, 1], [2]), ([2, 3], [3]), ([3], [3, 4]) ; @@ -4384,6 +4401,7 @@ AS VALUES query ? select array_union(column1, column2) from arrays_with_repeating_elements_for_union; ---- +[0, 1] [1, 2] [2, 3] [3, 4] @@ -4391,6 +4409,7 @@ select array_union(column1, column2) from arrays_with_repeating_elements_for_uni query ? select array_union(arrow_cast(column1, 'LargeList(Int64)'), arrow_cast(column2, 'LargeList(Int64)')) from arrays_with_repeating_elements_for_union; ---- +[0, 1] [1, 2] [2, 3] [3, 4] @@ -4413,12 +4432,10 @@ select array_union(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'LargeList query ? select array_union([[null]], []); ---- -[[NULL]] +[[]] -query ? +query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'array_union' function: select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([], 'LargeList(Int64)')); ----- -[[NULL]] # array_union scalar function #8 query ? 
@@ -5223,6 +5240,19 @@ NULL 10 NULL 10 NULL 10 +# array_length for fixed sized list + +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)')), array_length(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'FixedSizeList(3, List(Int64))')); +---- +5 3 3 + +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 1), array_length(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 1), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'FixedSizeList(3, List(Int64))'), 1); +---- +5 3 3 + + query RRR select array_distance([2], [3]), list_distance([1], [2]), list_distance([1], [-2]); ---- @@ -6002,7 +6032,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "7f4b18de3cfeb9b4ac78c381ee2ad278", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "a", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "b", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "c", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6031,7 +6061,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "7f4b18de3cfeb9b4ac78c381ee2ad278", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "a", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "b", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "c", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6060,7 +6090,7 
@@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "7f4b18de3cfeb9b4ac78c381ee2ad278", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "a", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "b", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "c", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6120,7 +6150,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "7f4b18de3cfeb9b4ac78c381ee2ad278", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "a", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "b", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "c", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6130,7 +6160,7 @@ select count(*) from test WHERE array_has([needle], needle); ---- 100000 -# The optimizer does not currently eliminate the filter; +# The optimizer does not currently eliminate the filter; # Instead, it's rewritten as `IS NULL OR NOT NULL` due to SQL null semantics query TT explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -6141,7 +6171,7 @@ logical_plan 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----SubqueryAlias: test 04)------SubqueryAlias: t -05)--------Projection: +05)--------Projection: 06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IS NOT NULL OR Boolean(NULL) 07)------------TableScan: tmp_table projection=[value] physical_plan @@ 
-6427,12 +6457,12 @@ select array_intersect(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null) query ? select array_intersect(null, [1, 1, 2, 2, 3, 3]); ---- -NULL +[] query ? select array_intersect(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); ---- -NULL +[] query ? select array_intersect([], null); @@ -6457,12 +6487,12 @@ select array_intersect(arrow_cast([], 'LargeList(Int64)'), null); query ? select array_intersect(null, []); ---- -NULL +[] query ? select array_intersect(null, arrow_cast([], 'LargeList(Int64)')); ---- -NULL +[] query ? select array_intersect(null, null); @@ -7285,12 +7315,10 @@ select array_concat(column1, [7]) from arrays_values_v2; # flatten -#TODO: https://github.com/apache/datafusion/issues/7142 -# follow DuckDB -#query ? -#select flatten(NULL); -#---- -#NULL +query ? +select flatten(NULL); +---- +NULL # flatten with scalar values #1 query ??? @@ -7298,21 +7326,21 @@ select flatten(make_array(1, 2, 1, 3, 2)), flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))), flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]])); ---- -[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4] +[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]] query ??? select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')), flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'LargeList(LargeList(Int64))')), flatten(arrow_cast(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]), 'LargeList(LargeList(LargeList(Float64)))')); ---- -[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4] +[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]] query ??? select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'FixedSizeList(5, Int64)')), flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'FixedSizeList(4, List(Int64))')), flatten(arrow_cast(make_array([[1.1], [2.2]], [[3.3], [4.4]]), 'FixedSizeList(2, List(List(Float64)))')); ---- -[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4] +[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]] # flatten with column values query ???? @@ -7322,8 +7350,8 @@ select flatten(column1), flatten(column4) from flatten_table; ---- -[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] -[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] query ???? select flatten(column1), @@ -7332,8 +7360,8 @@ select flatten(column1), flatten(column4) from large_flatten_table; ---- -[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] -[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] query ???? select flatten(column1), @@ -7342,8 +7370,19 @@ select flatten(column1), flatten(column4) from fixed_size_flatten_table; ---- -[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] -[1, 2, 3, 4, 5, 6] [8, 9, 10, 11, 12, 13] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [[8], [9, 10], [11, 12, 13]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + +# flatten with different inner list type +query ?????? 
+select flatten(arrow_cast(make_array([1, 2], [3, 4]), 'List(FixedSizeList(2, Int64))')), + flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'List(FixedSizeList(1, List(Int64)))')), + flatten(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), + flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'LargeList(List(List(Int64)))')), + flatten(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(FixedSizeList(2, Int64))')), + flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'LargeList(FixedSizeList(1, List(Int64)))')) +---- +[1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 2], [3, 4]] ## empty (aliases: `array_empty`, `list_empty`) # empty scalar function #1 diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt index 8fde295e6051..65d4fa495e3b 100644 --- a/datafusion/sqllogictest/test_files/array_query.slt +++ b/datafusion/sqllogictest/test_files/array_query.slt @@ -108,11 +108,15 @@ SELECT * FROM data WHERE column2 is not distinct from null; # Aggregates ########### -query error Internal error: Min/Max accumulator not implemented for type List +query ? SELECT min(column1) FROM data; +---- +[1, 2, 3] -query error Internal error: Min/Max accumulator not implemented for type List +query ? SELECT max(column1) FROM data; +---- +[2, 3] query I SELECT count(column1) FROM data; diff --git a/datafusion/sqllogictest/test_files/avro.slt b/datafusion/sqllogictest/test_files/avro.slt index 1b4150b074cc..4573af1d59b1 100644 --- a/datafusion/sqllogictest/test_files/avro.slt +++ b/datafusion/sqllogictest/test_files/avro.slt @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. +# Currently, Avro does not support the Utf8View type, so we disable map_varchar_to_utf8view +# After https://github.com/apache/arrow-rs/issues/7262 is released, we can remove this setting +statement ok +set datafusion.sql_parser.map_varchar_to_utf8view = false; statement ok CREATE EXTERNAL TABLE alltypes_plain ( @@ -253,3 +257,13 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/alltypes_plain.avro]]}, file_type=avro + +# test column projection order from avro file query ITII +SELECT id, string_col, int_col, bigint_col FROM alltypes_plain ORDER BY id LIMIT 5 +---- +0 0 0 0 +1 1 1 10 +2 0 0 0 +3 1 1 10 +4 0 0 0 diff --git a/datafusion/sqllogictest/test_files/binary.slt b/datafusion/sqllogictest/test_files/binary.slt index 5c5f9d510e55..1077c32e46f3 100644 --- a/datafusion/sqllogictest/test_files/binary.slt +++ b/datafusion/sqllogictest/test_files/binary.slt @@ -147,8 +147,45 @@ query error DataFusion error: Error during planning: Cannot infer common argumen SELECT column1, column1 = arrow_cast(X'0102', 'FixedSizeBinary(2)') FROM t # Comparison to different sized Binary -query error DataFusion error: Error during planning: Cannot infer common argument type for comparison operation FixedSizeBinary\(3\) = Binary +query ?B SELECT column1, column1 = X'0102' FROM t +---- +000102 false +003102 false +NULL NULL +ff0102 false +000102 false + +query ?B +SELECT column1, column1 = X'000102' FROM t +---- +000102 true +003102 false +NULL NULL +ff0102 false +000102 true + +query ?B +SELECT arrow_cast(column1, 'FixedSizeBinary(3)'), arrow_cast(column1, 'FixedSizeBinary(3)') = arrow_cast(arrow_cast(X'000102',
'FixedSizeBinary(3)'), 'BinaryView') FROM t; +---- +000102 true +003102 false +NULL NULL +ff0102 false +000102 true + +# Plan should not have a cast of the column (should have cast the literal +# to FixedSizeBinary as that is much faster) + +query TT +explain SELECT column1, column1 = X'000102' FROM t +---- +logical_plan +01)Projection: t.column1, t.column1 = FixedSizeBinary(3, "0,1,2") AS t.column1 = Binary("0,1,2") +02)--TableScan: t projection=[column1] +physical_plan +01)ProjectionExec: expr=[column1@0 as column1, column1@0 = 000102 as t.column1 = Binary("0,1,2")] +02)--DataSourceExec: partitions=1, partition_sizes=[1] statement ok drop table t_source diff --git a/datafusion/sqllogictest/test_files/clickbench_extended.slt b/datafusion/sqllogictest/test_files/clickbench_extended.slt new file mode 100644 index 000000000000..ee3e33551ee3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/clickbench_extended.slt @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# DataFusion specific ClickBench "Extended" Queries +# See data provenance notes in clickbench.slt + +statement ok +CREATE EXTERNAL TABLE hits +STORED AS PARQUET +LOCATION '../core/tests/data/clickbench_hits_10.parquet'; + +# If you change any of these queries, please change the corresponding query in +# benchmarks/queries/clickbench/extended.sql and update the README.
+ +query III +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; +---- +1 1 1 + +query III +SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; +---- +1 1 1 + +query TIIII +SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; +---- +� 1 1 1 1 + +query IIIRRRR +SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10; +---- +0 839 6 0 0 0 0 +0 197 2 0 0 0 0 + +query IIIIII +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10; +---- + +query IIIIII +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY "ResponseStartTiming") tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; +---- + +query I +SELECT COUNT(*) AS ShareCount FROM hits WHERE "IsMobile" = 1 AND "MobilePhoneModel" LIKE 'iPhone%' AND "SocialAction" = 'share' AND "SocialSourceNetworkID" IN (5, 12) AND "ClientTimeZone" BETWEEN -5 AND 5 AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL AND CASE WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT ELSE 0 END > 1920 AND levenshtein(CAST("UTMSource" AS STRING), CAST("UTMCampaign" AS STRING)) < 3; +---- +0 + + +statement ok +drop table hits; diff --git a/datafusion/sqllogictest/test_files/coalesce.slt b/datafusion/sqllogictest/test_files/coalesce.slt index e7cf31dc690b..9740bade5e27 100644 --- a/datafusion/sqllogictest/test_files/coalesce.slt +++ b/datafusion/sqllogictest/test_files/coalesce.slt @@ -260,8 +260,8 @@ select arrow_typeof(coalesce(c, arrow_cast('b', 'Dictionary(Int32, Utf8)'))) from t; ---- -a Dictionary(Int32, Utf8) -b Dictionary(Int32, Utf8) +a Utf8View +b Utf8View statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 925f96bd4ac0..5eeb05e814ac 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -637,7 +637,7 @@ query error DataFusion error: SQL error: ParserError\("Expected: \), found: EOF" COPY (select col2, sum(col1) from source_table # Copy from table with non literal -query error DataFusion error: SQL error: ParserError\("Unexpected token \("\) +query error DataFusion error: SQL error: ParserError\("Expected: end of statement or ;, found: \( at Line: 1, Column: 44"\) COPY source_table to '/tmp/table.parquet' (row_group_size 55 + 102); # Copy using execution.keep_partition_by_columns with an invalid value diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index bb66aef2514c..03cb5edb5fcc 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -77,7 +77,7 @@ 
statement error DataFusion error: SQL error: ParserError\("Expected: HEADER, fou CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH LOCATION 'foo.csv'; # Unrecognized random clause -statement error DataFusion error: SQL error: ParserError\("Unexpected token FOOBAR"\) +statement error DataFusion error: SQL error: ParserError\("Expected: end of statement or ;, found: FOOBAR at Line: 1, Column: 47"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV FOOBAR BARBAR BARFOO LOCATION 'foo.csv'; # Missing partition column diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index e019af9775a4..32320a06f4fb 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -722,7 +722,7 @@ logical_plan 03)----Projection: Int64(1) AS val 04)------EmptyRelation 05)----Projection: Int64(2) AS val -06)------Cross Join: +06)------Cross Join: 07)--------Filter: recursive_cte.val < Int64(2) 08)----------TableScan: recursive_cte 09)--------SubqueryAlias: sub_cte diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index 088d0155a66f..1e95e426f3e0 100644 --- a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -819,7 +819,7 @@ show columns FROM table_with_pk; ---- datafusion public table_with_pk sn Int32 NO datafusion public table_with_pk ts Timestamp(Nanosecond, Some("+00:00")) NO -datafusion public table_with_pk currency Utf8 NO +datafusion public table_with_pk currency Utf8View NO datafusion public table_with_pk amount Float32 YES statement ok @@ -835,8 +835,8 @@ CREATE TABLE t1(c1 VARCHAR(10) NOT NULL, c2 VARCHAR); query TTT DESCRIBE t1; ---- -c1 Utf8 NO -c2 Utf8 YES +c1 Utf8View NO +c2 Utf8View YES statement ok set datafusion.sql_parser.map_varchar_to_utf8view = true; diff --git a/datafusion/sqllogictest/test_files/delete.slt b/datafusion/sqllogictest/test_files/delete.slt new file mode 100644 index 000000000000..258318f09423 --- /dev/null +++ b/datafusion/sqllogictest/test_files/delete.slt @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +########## +## Delete Tests +########## + +statement ok +create table t1(a int, b varchar, c double, d int); + +# Turn off the optimizer to make the logical plan closer to the initial one +statement ok +set datafusion.optimizer.max_passes = 0; + + +# Delete all +query TT +explain delete from t1; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Filtered by existing columns +query TT +explain delete from t1 where a = 1 and b = 2 and c > 3 and d != 4; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND t1.b = CAST(Int64(2) AS Utf8View) AND t1.c > CAST(Int64(3) AS Float64) AND CAST(t1.d AS Int64) != Int64(4) +03)----TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Filtered by existing columns, using qualified and unqualified names +query TT +explain delete from t1 where t1.a = 1 and b = 2 and t1.c > 3 and d != 4; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND t1.b = CAST(Int64(2) AS Utf8View) AND t1.c > CAST(Int64(3) AS Float64) AND CAST(t1.d AS Int64) != Int64(4) +03)----TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Filtered by a mix of columns and literal predicates +query TT +explain delete from t1 where a = 1 and 1 = 1 and true; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND Int64(1) = Int64(1) AND Boolean(true) +03)----TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Deleting by columns that do not exist returns an error +query error DataFusion error: Schema error: No field named e. Valid fields are t1.a, t1.b, t1.c, t1.d. 
+explain delete from t1 where e = 1; + + +# Filtering using subqueries + +statement ok +create table t2(a int, b varchar, c double, d int); + +query TT +explain delete from t1 where a = (select max(a) from t2 where t1.b = t2.b); +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: t1.a = () +03)----Subquery: +04)------Projection: max(t2.a) +05)--------Aggregate: groupBy=[[]], aggr=[[max(t2.a)]] +06)----------Filter: outer_ref(t1.b) = t2.b +07)------------TableScan: t2 +08)----TableScan: t1 +physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery() + +query TT +explain delete from t1 where a in (select a from t2); +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: t1.a IN () +03)----Subquery: +04)------Projection: t2.a +05)--------TableScan: t2 +06)----TableScan: t1 +physical_plan_error This feature is not implemented: Physical plan does not support logical expression InSubquery(InSubquery { expr: Column(Column { relation: Some(Bare { table: "t1" }), name: "a" }), subquery: , negated: false }) diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index 1769f42c2d2a..d241e61f33ff 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -456,4 +456,4 @@ statement ok CREATE TABLE test0 AS VALUES ('foo',1), ('bar',2), ('foo',3); statement ok -COPY (SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') AS column1, column2 FROM test0) TO 'test_files/scratch/copy/part_dict_test' STORED AS PARQUET PARTITIONED BY (column1); \ No newline at end of file +COPY (SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') AS column1, column2 FROM test0) TO 'test_files/scratch/copy/part_dict_test' STORED AS PARQUET PARTITIONED BY (column1); diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index deff793e5110..2df8a9dfbae4 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -183,6 +183,7 @@ logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE +logical_plan after decorrelate_lateral_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -204,6 +205,7 @@ logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE +logical_plan after decorrelate_lateral_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -229,6 +231,7 @@ physical_plan after OutputRequirements physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate 
SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE @@ -303,6 +306,7 @@ physical_plan after OutputRequirements physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE @@ -343,6 +347,7 @@ physical_plan after OutputRequirements physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE @@ -415,7 +420,7 @@ logical_plan 01)LeftSemi Join: 02)--TableScan: t1 projection=[a] 03)--SubqueryAlias: __correlated_sq_1 -04)----Projection: +04)----Projection: 05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 06)--------TableScan: t2 projection=[] physical_plan diff --git a/datafusion/sqllogictest/test_files/explain_tree.slt b/datafusion/sqllogictest/test_files/explain_tree.slt index 15bf61576571..22183195c3df 100644 --- a/datafusion/sqllogictest/test_files/explain_tree.slt +++ b/datafusion/sqllogictest/test_files/explain_tree.slt @@ -291,47 +291,40 @@ explain SELECT table1.string_col, table2.date_col FROM table1 JOIN table2 ON tab ---- physical_plan 01)┌───────────────────────────┐ -02)│ CoalesceBatchesExec │ +02)│ ProjectionExec │ 03)│ -------------------- │ -04)│ target_batch_size: │ -05)│ 8192 │ -06)└─────────────┬─────────────┘ -07)┌─────────────┴─────────────┐ -08)│ HashJoinExec │ -09)│ -------------------- │ -10)│ on: ├──────────────┐ -11)│ (int_col = int_col) │ │ -12)└─────────────┬─────────────┘ │ -13)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -14)│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -15)│ -------------------- ││ -------------------- │ -16)│ target_batch_size: ││ target_batch_size: │ -17)│ 8192 ││ 8192 │ -18)└─────────────┬─────────────┘└─────────────┬─────────────┘ -19)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -20)│ RepartitionExec ││ RepartitionExec │ -21)│ -------------------- ││ -------------------- │ -22)│ partition_count(in->out): ││ partition_count(in->out): │ -23)│ 4 -> 4 ││ 4 -> 4 │ -24)│ ││ │ -25)│ partitioning_scheme: ││ partitioning_scheme: │ -26)│ Hash([int_col@0], 4) ││ Hash([int_col@0], 4) │ -27)└─────────────┬─────────────┘└─────────────┬─────────────┘ -28)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -29)│ RepartitionExec ││ RepartitionExec │ -30)│ -------------------- ││ -------------------- │ -31)│ partition_count(in->out): ││ partition_count(in->out): │ -32)│ 1 -> 4 ││ 1 -> 4 │ -33)│ ││ │ -34)│ partitioning_scheme: ││ partitioning_scheme: │ -35)│ RoundRobinBatch(4) ││ RoundRobinBatch(4) │ -36)└─────────────┬─────────────┘└─────────────┬─────────────┘ -37)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -38)│ DataSourceExec ││ DataSourceExec │ -39)│ -------------------- ││ -------------------- │ -40)│ files: 1 ││ files: 1 │ -41)│ format: csv ││ format: parquet │ -42)└───────────────────────────┘└───────────────────────────┘ +04)│ date_col: date_col │ +05)│ │ +06)│ string_col: │ 
+07)│ string_col │ +08)└─────────────┬─────────────┘ +09)┌─────────────┴─────────────┐ +10)│ CoalesceBatchesExec │ +11)│ -------------------- │ +12)│ target_batch_size: │ +13)│ 8192 │ +14)└─────────────┬─────────────┘ +15)┌─────────────┴─────────────┐ +16)│ HashJoinExec │ +17)│ -------------------- │ +18)│ on: ├──────────────┐ +19)│ (int_col = int_col) │ │ +20)└─────────────┬─────────────┘ │ +21)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +22)│ DataSourceExec ││ RepartitionExec │ +23)│ -------------------- ││ -------------------- │ +24)│ files: 1 ││ partition_count(in->out): │ +25)│ format: parquet ││ 1 -> 4 │ +26)│ ││ │ +27)│ ││ partitioning_scheme: │ +28)│ ││ RoundRobinBatch(4) │ +29)└───────────────────────────┘└─────────────┬─────────────┘ +30)-----------------------------┌─────────────┴─────────────┐ +31)-----------------------------│ DataSourceExec │ +32)-----------------------------│ -------------------- │ +33)-----------------------------│ files: 1 │ +34)-----------------------------│ format: csv │ +35)-----------------------------└───────────────────────────┘ # 3 Joins query TT @@ -365,48 +358,41 @@ physical_plan 19)│ (int_col = int_col) │ │ 20)└─────────────┬─────────────┘ │ 21)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -22)│ DataSourceExec ││ CoalesceBatchesExec │ +22)│ DataSourceExec ││ ProjectionExec │ 23)│ -------------------- ││ -------------------- │ -24)│ bytes: 1560 ││ target_batch_size: │ -25)│ format: memory ││ 8192 │ +24)│ bytes: 1560 ││ date_col: date_col │ +25)│ format: memory ││ int_col: int_col │ 26)│ rows: 1 ││ │ -27)└───────────────────────────┘└─────────────┬─────────────┘ -28)-----------------------------┌─────────────┴─────────────┐ -29)-----------------------------│ HashJoinExec │ -30)-----------------------------│ -------------------- │ -31)-----------------------------│ on: ├──────────────┐ -32)-----------------------------│ (int_col = int_col) │ │ -33)-----------------------------└─────────────┬─────────────┘ │ -34)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -35)-----------------------------│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -36)-----------------------------│ -------------------- ││ -------------------- │ -37)-----------------------------│ target_batch_size: ││ target_batch_size: │ -38)-----------------------------│ 8192 ││ 8192 │ -39)-----------------------------└─────────────┬─────────────┘└─────────────┬─────────────┘ -40)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -41)-----------------------------│ RepartitionExec ││ RepartitionExec │ -42)-----------------------------│ -------------------- ││ -------------------- │ -43)-----------------------------│ partition_count(in->out): ││ partition_count(in->out): │ -44)-----------------------------│ 4 -> 4 ││ 4 -> 4 │ -45)-----------------------------│ ││ │ -46)-----------------------------│ partitioning_scheme: ││ partitioning_scheme: │ -47)-----------------------------│ Hash([int_col@0], 4) ││ Hash([int_col@0], 4) │ -48)-----------------------------└─────────────┬─────────────┘└─────────────┬─────────────┘ -49)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -50)-----------------------------│ RepartitionExec ││ RepartitionExec │ -51)-----------------------------│ -------------------- ││ -------------------- │ -52)-----------------------------│ partition_count(in->out): ││ partition_count(in->out): │ -53)-----------------------------│ 1 -> 4 ││ 
1 -> 4 │ -54)-----------------------------│ ││ │ -55)-----------------------------│ partitioning_scheme: ││ partitioning_scheme: │ -56)-----------------------------│ RoundRobinBatch(4) ││ RoundRobinBatch(4) │ -57)-----------------------------└─────────────┬─────────────┘└─────────────┬─────────────┘ -58)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -59)-----------------------------│ DataSourceExec ││ DataSourceExec │ -60)-----------------------------│ -------------------- ││ -------------------- │ -61)-----------------------------│ files: 1 ││ files: 1 │ -62)-----------------------------│ format: csv ││ format: parquet │ -63)-----------------------------└───────────────────────────┘└───────────────────────────┘ +27)│ ││ string_col: │ +28)│ ││ string_col │ +29)└───────────────────────────┘└─────────────┬─────────────┘ +30)-----------------------------┌─────────────┴─────────────┐ +31)-----------------------------│ CoalesceBatchesExec │ +32)-----------------------------│ -------------------- │ +33)-----------------------------│ target_batch_size: │ +34)-----------------------------│ 8192 │ +35)-----------------------------└─────────────┬─────────────┘ +36)-----------------------------┌─────────────┴─────────────┐ +37)-----------------------------│ HashJoinExec │ +38)-----------------------------│ -------------------- │ +39)-----------------------------│ on: ├──────────────┐ +40)-----------------------------│ (int_col = int_col) │ │ +41)-----------------------------└─────────────┬─────────────┘ │ +42)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +43)-----------------------------│ DataSourceExec ││ RepartitionExec │ +44)-----------------------------│ -------------------- ││ -------------------- │ +45)-----------------------------│ files: 1 ││ partition_count(in->out): │ +46)-----------------------------│ format: parquet ││ 1 -> 4 │ +47)-----------------------------│ ││ │ +48)-----------------------------│ ││ partitioning_scheme: │ +49)-----------------------------│ ││ RoundRobinBatch(4) │ +50)-----------------------------└───────────────────────────┘└─────────────┬─────────────┘ +51)----------------------------------------------------------┌─────────────┴─────────────┐ +52)----------------------------------------------------------│ DataSourceExec │ +53)----------------------------------------------------------│ -------------------- │ +54)----------------------------------------------------------│ files: 1 │ +55)----------------------------------------------------------│ format: csv │ +56)----------------------------------------------------------└───────────────────────────┘ # Long Filter (demonstrate what happens with wrapping) query TT @@ -1029,20 +1015,11 @@ physical_plan 11)│ bigint_col │ 12)└─────────────┬─────────────┘ 13)┌─────────────┴─────────────┐ -14)│ RepartitionExec │ +14)│ DataSourceExec │ 15)│ -------------------- │ -16)│ partition_count(in->out): │ -17)│ 1 -> 4 │ -18)│ │ -19)│ partitioning_scheme: │ -20)│ RoundRobinBatch(4) │ -21)└─────────────┬─────────────┘ -22)┌─────────────┴─────────────┐ -23)│ DataSourceExec │ -24)│ -------------------- │ -25)│ files: 1 │ -26)│ format: parquet │ -27)└───────────────────────────┘ +16)│ files: 1 │ +17)│ format: parquet │ +18)└───────────────────────────┘ # Query with projection on memory @@ -1186,69 +1163,64 @@ explain select * from table1 inner join table2 on table1.int_col = table2.int_co ---- physical_plan 01)┌───────────────────────────┐ -02)│ 
CoalesceBatchesExec │ +02)│ ProjectionExec │ 03)│ -------------------- │ -04)│ target_batch_size: │ -05)│ 8192 │ -06)└─────────────┬─────────────┘ -07)┌─────────────┴─────────────┐ -08)│ HashJoinExec │ -09)│ -------------------- │ -10)│ on: │ -11)│ (int_col = int_col), (CAST├──────────────┐ -12)│ (table1.string_col AS │ │ -13)│ Utf8View) = │ │ -14)│ string_col) │ │ -15)└─────────────┬─────────────┘ │ -16)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -17)│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -18)│ -------------------- ││ -------------------- │ -19)│ target_batch_size: ││ target_batch_size: │ -20)│ 8192 ││ 8192 │ -21)└─────────────┬─────────────┘└─────────────┬─────────────┘ -22)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -23)│ RepartitionExec ││ RepartitionExec │ -24)│ -------------------- ││ -------------------- │ -25)│ partition_count(in->out): ││ partition_count(in->out): │ -26)│ 4 -> 4 ││ 4 -> 4 │ -27)│ ││ │ -28)│ partitioning_scheme: ││ partitioning_scheme: │ -29)│ Hash([int_col@0, CAST ││ Hash([int_col@0, │ -30)│ (table1.string_col ││ string_col@1], │ -31)│ AS Utf8View)@4], 4) ││ 4) │ -32)└─────────────┬─────────────┘└─────────────┬─────────────┘ -33)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -34)│ ProjectionExec ││ RepartitionExec │ -35)│ -------------------- ││ -------------------- │ -36)│ CAST(table1.string_col AS ││ partition_count(in->out): │ -37)│ Utf8View): ││ 1 -> 4 │ -38)│ CAST(string_col AS ││ │ -39)│ Utf8View) ││ partitioning_scheme: │ -40)│ ││ RoundRobinBatch(4) │ -41)│ bigint_col: ││ │ -42)│ bigint_col ││ │ -43)│ ││ │ -44)│ date_col: date_col ││ │ -45)│ int_col: int_col ││ │ -46)│ ││ │ -47)│ string_col: ││ │ -48)│ string_col ││ │ -49)└─────────────┬─────────────┘└─────────────┬─────────────┘ -50)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -51)│ RepartitionExec ││ DataSourceExec │ -52)│ -------------------- ││ -------------------- │ -53)│ partition_count(in->out): ││ files: 1 │ -54)│ 1 -> 4 ││ format: parquet │ -55)│ ││ │ -56)│ partitioning_scheme: ││ │ -57)│ RoundRobinBatch(4) ││ │ -58)└─────────────┬─────────────┘└───────────────────────────┘ -59)┌─────────────┴─────────────┐ -60)│ DataSourceExec │ -61)│ -------------------- │ -62)│ files: 1 │ -63)│ format: csv │ -64)└───────────────────────────┘ +04)│ bigint_col: │ +05)│ bigint_col │ +06)│ │ +07)│ date_col: date_col │ +08)│ int_col: int_col │ +09)│ │ +10)│ string_col: │ +11)│ string_col │ +12)└─────────────┬─────────────┘ +13)┌─────────────┴─────────────┐ +14)│ CoalesceBatchesExec │ +15)│ -------------------- │ +16)│ target_batch_size: │ +17)│ 8192 │ +18)└─────────────┬─────────────┘ +19)┌─────────────┴─────────────┐ +20)│ HashJoinExec │ +21)│ -------------------- │ +22)│ on: │ +23)│ (int_col = int_col), ├──────────────┐ +24)│ (string_col = CAST │ │ +25)│ (table1.string_col AS │ │ +26)│ Utf8View)) │ │ +27)└─────────────┬─────────────┘ │ +28)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +29)│ DataSourceExec ││ ProjectionExec │ +30)│ -------------------- ││ -------------------- │ +31)│ files: 1 ││ CAST(table1.string_col AS │ +32)│ format: parquet ││ Utf8View): │ +33)│ ││ CAST(string_col AS │ +34)│ ││ Utf8View) │ +35)│ ││ │ +36)│ ││ bigint_col: │ +37)│ ││ bigint_col │ +38)│ ││ │ +39)│ ││ date_col: date_col │ +40)│ ││ int_col: int_col │ +41)│ ││ │ +42)│ ││ string_col: │ +43)│ ││ string_col │ +44)└───────────────────────────┘└─────────────┬─────────────┘ +45)-----------------------------┌─────────────┴─────────────┐ 
+46)-----------------------------│ RepartitionExec │ +47)-----------------------------│ -------------------- │ +48)-----------------------------│ partition_count(in->out): │ +49)-----------------------------│ 1 -> 4 │ +50)-----------------------------│ │ +51)-----------------------------│ partitioning_scheme: │ +52)-----------------------------│ RoundRobinBatch(4) │ +53)-----------------------------└─────────────┬─────────────┘ +54)-----------------------------┌─────────────┴─────────────┐ +55)-----------------------------│ DataSourceExec │ +56)-----------------------------│ -------------------- │ +57)-----------------------------│ files: 1 │ +58)-----------------------------│ format: csv │ +59)-----------------------------└───────────────────────────┘ # Query with outer hash join. query TT @@ -1256,71 +1228,66 @@ explain select * from table1 left outer join table2 on table1.int_col = table2.i ---- physical_plan 01)┌───────────────────────────┐ -02)│ CoalesceBatchesExec │ +02)│ ProjectionExec │ 03)│ -------------------- │ -04)│ target_batch_size: │ -05)│ 8192 │ -06)└─────────────┬─────────────┘ -07)┌─────────────┴─────────────┐ -08)│ HashJoinExec │ -09)│ -------------------- │ -10)│ join_type: Left │ -11)│ │ -12)│ on: ├──────────────┐ -13)│ (int_col = int_col), (CAST│ │ -14)│ (table1.string_col AS │ │ -15)│ Utf8View) = │ │ -16)│ string_col) │ │ -17)└─────────────┬─────────────┘ │ -18)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -19)│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -20)│ -------------------- ││ -------------------- │ -21)│ target_batch_size: ││ target_batch_size: │ -22)│ 8192 ││ 8192 │ -23)└─────────────┬─────────────┘└─────────────┬─────────────┘ -24)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -25)│ RepartitionExec ││ RepartitionExec │ -26)│ -------------------- ││ -------------------- │ -27)│ partition_count(in->out): ││ partition_count(in->out): │ -28)│ 4 -> 4 ││ 4 -> 4 │ -29)│ ││ │ -30)│ partitioning_scheme: ││ partitioning_scheme: │ -31)│ Hash([int_col@0, CAST ││ Hash([int_col@0, │ -32)│ (table1.string_col ││ string_col@1], │ -33)│ AS Utf8View)@4], 4) ││ 4) │ -34)└─────────────┬─────────────┘└─────────────┬─────────────┘ -35)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -36)│ ProjectionExec ││ RepartitionExec │ -37)│ -------------------- ││ -------------------- │ -38)│ CAST(table1.string_col AS ││ partition_count(in->out): │ -39)│ Utf8View): ││ 1 -> 4 │ -40)│ CAST(string_col AS ││ │ -41)│ Utf8View) ││ partitioning_scheme: │ -42)│ ││ RoundRobinBatch(4) │ -43)│ bigint_col: ││ │ -44)│ bigint_col ││ │ -45)│ ││ │ -46)│ date_col: date_col ││ │ -47)│ int_col: int_col ││ │ -48)│ ││ │ -49)│ string_col: ││ │ -50)│ string_col ││ │ -51)└─────────────┬─────────────┘└─────────────┬─────────────┘ -52)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -53)│ RepartitionExec ││ DataSourceExec │ -54)│ -------------------- ││ -------------------- │ -55)│ partition_count(in->out): ││ files: 1 │ -56)│ 1 -> 4 ││ format: parquet │ -57)│ ││ │ -58)│ partitioning_scheme: ││ │ -59)│ RoundRobinBatch(4) ││ │ -60)└─────────────┬─────────────┘└───────────────────────────┘ -61)┌─────────────┴─────────────┐ -62)│ DataSourceExec │ -63)│ -------------------- │ -64)│ files: 1 │ -65)│ format: csv │ -66)└───────────────────────────┘ +04)│ bigint_col: │ +05)│ bigint_col │ +06)│ │ +07)│ date_col: date_col │ +08)│ int_col: int_col │ +09)│ │ +10)│ string_col: │ +11)│ string_col │ +12)└─────────────┬─────────────┘ +13)┌─────────────┴─────────────┐ +14)│ 
CoalesceBatchesExec │ +15)│ -------------------- │ +16)│ target_batch_size: │ +17)│ 8192 │ +18)└─────────────┬─────────────┘ +19)┌─────────────┴─────────────┐ +20)│ HashJoinExec │ +21)│ -------------------- │ +22)│ join_type: Right │ +23)│ │ +24)│ on: ├──────────────┐ +25)│ (int_col = int_col), │ │ +26)│ (string_col = CAST │ │ +27)│ (table1.string_col AS │ │ +28)│ Utf8View)) │ │ +29)└─────────────┬─────────────┘ │ +30)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +31)│ DataSourceExec ││ ProjectionExec │ +32)│ -------------------- ││ -------------------- │ +33)│ files: 1 ││ CAST(table1.string_col AS │ +34)│ format: parquet ││ Utf8View): │ +35)│ ││ CAST(string_col AS │ +36)│ ││ Utf8View) │ +37)│ ││ │ +38)│ ││ bigint_col: │ +39)│ ││ bigint_col │ +40)│ ││ │ +41)│ ││ date_col: date_col │ +42)│ ││ int_col: int_col │ +43)│ ││ │ +44)│ ││ string_col: │ +45)│ ││ string_col │ +46)└───────────────────────────┘└─────────────┬─────────────┘ +47)-----------------------------┌─────────────┴─────────────┐ +48)-----------------------------│ RepartitionExec │ +49)-----------------------------│ -------------------- │ +50)-----------------------------│ partition_count(in->out): │ +51)-----------------------------│ 1 -> 4 │ +52)-----------------------------│ │ +53)-----------------------------│ partitioning_scheme: │ +54)-----------------------------│ RoundRobinBatch(4) │ +55)-----------------------------└─────────────┬─────────────┘ +56)-----------------------------┌─────────────┴─────────────┐ +57)-----------------------------│ DataSourceExec │ +58)-----------------------------│ -------------------- │ +59)-----------------------------│ files: 1 │ +60)-----------------------------│ format: csv │ +61)-----------------------------└───────────────────────────┘ # Query with nested loop join. 
query TT @@ -1339,35 +1306,8 @@ physical_plan 10)│ format: csv ││ │ 11)└───────────────────────────┘└─────────────┬─────────────┘ 12)-----------------------------┌─────────────┴─────────────┐ -13)-----------------------------│ AggregateExec │ -14)-----------------------------│ -------------------- │ -15)-----------------------------│ aggr: count(1) │ -16)-----------------------------│ mode: Final │ -17)-----------------------------└─────────────┬─────────────┘ -18)-----------------------------┌─────────────┴─────────────┐ -19)-----------------------------│ CoalescePartitionsExec │ -20)-----------------------------└─────────────┬─────────────┘ -21)-----------------------------┌─────────────┴─────────────┐ -22)-----------------------------│ AggregateExec │ -23)-----------------------------│ -------------------- │ -24)-----------------------------│ aggr: count(1) │ -25)-----------------------------│ mode: Partial │ -26)-----------------------------└─────────────┬─────────────┘ -27)-----------------------------┌─────────────┴─────────────┐ -28)-----------------------------│ RepartitionExec │ -29)-----------------------------│ -------------------- │ -30)-----------------------------│ partition_count(in->out): │ -31)-----------------------------│ 1 -> 4 │ -32)-----------------------------│ │ -33)-----------------------------│ partitioning_scheme: │ -34)-----------------------------│ RoundRobinBatch(4) │ -35)-----------------------------└─────────────┬─────────────┘ -36)-----------------------------┌─────────────┴─────────────┐ -37)-----------------------------│ DataSourceExec │ -38)-----------------------------│ -------------------- │ -39)-----------------------------│ files: 1 │ -40)-----------------------------│ format: parquet │ -41)-----------------------------└───────────────────────────┘ +13)-----------------------------│ PlaceholderRowExec │ +14)-----------------------------└───────────────────────────┘ # Query with cross join. query TT @@ -1378,20 +1318,11 @@ physical_plan 02)│ CrossJoinExec ├──────────────┐ 03)└─────────────┬─────────────┘ │ 04)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -05)│ DataSourceExec ││ RepartitionExec │ +05)│ DataSourceExec ││ DataSourceExec │ 06)│ -------------------- ││ -------------------- │ -07)│ files: 1 ││ partition_count(in->out): │ -08)│ format: csv ││ 1 -> 4 │ -09)│ ││ │ -10)│ ││ partitioning_scheme: │ -11)│ ││ RoundRobinBatch(4) │ -12)└───────────────────────────┘└─────────────┬─────────────┘ -13)-----------------------------┌─────────────┴─────────────┐ -14)-----------------------------│ DataSourceExec │ -15)-----------------------------│ -------------------- │ -16)-----------------------------│ files: 1 │ -17)-----------------------------│ format: parquet │ -18)-----------------------------└───────────────────────────┘ +07)│ files: 1 ││ files: 1 │ +08)│ format: csv ││ format: parquet │ +09)└───────────────────────────┘└───────────────────────────┘ # Query with sort merge join. 
diff --git a/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt index d96044fda8c0..a09d8ce26ddf 100644 --- a/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt +++ b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt @@ -34,7 +34,7 @@ ORDER BY "date", "time"; ---- logical_plan 01)Sort: data.date ASC NULLS LAST, data.time ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") +02)--Filter: data.ticker = Utf8View("A") 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [date@0 ASC NULLS LAST, time@2 ASC NULLS LAST] @@ -51,7 +51,7 @@ ORDER BY "time" ---- logical_plan 01)Sort: data.time ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [time@2 ASC NULLS LAST] @@ -68,7 +68,7 @@ ORDER BY "date" ---- logical_plan 01)Sort: data.date ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [date@0 ASC NULLS LAST] @@ -85,7 +85,7 @@ ORDER BY "ticker" ---- logical_plan 01)Sort: data.ticker ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)CoalescePartitionsExec @@ -102,7 +102,7 @@ ORDER BY "time", "date"; ---- logical_plan 01)Sort: data.time ASC NULLS LAST, data.date ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [time@2 ASC NULLS LAST, date@0 ASC NULLS LAST] @@ -120,7 +120,7 @@ ORDER BY "time" ---- logical_plan 01)Sort: data.time ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) != data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) != data.date 03)----TableScan: data projection=[date, ticker, time] # no relation between time & date @@ -132,7 +132,7 @@ ORDER BY "time" ---- logical_plan 01)Sort: data.time ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") +02)--Filter: data.ticker = Utf8View("A") 03)----TableScan: data projection=[date, ticker, time] # query diff --git a/datafusion/sqllogictest/test_files/float16.slt b/datafusion/sqllogictest/test_files/float16.slt new file mode 100644 index 000000000000..5e59c730f078 --- /dev/null +++ b/datafusion/sqllogictest/test_files/float16.slt @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Basic tests for Float16 Type
+
+statement ok
+create table floats as values (1.0), (2.0), (3.0), (NULL), ('Nan');
+
+statement ok
+create table float16s as select arrow_cast(column1, 'Float16') as column1 from floats;
+
+query RT
+select column1, arrow_typeof(column1) as type from float16s;
+----
+1 Float16
+2 Float16
+3 Float16
+NULL Float16
+NaN Float16
+
+# Test coercions with arithmetic
+
+query RRRRRR
+SELECT
+  column1 + 1::tinyint as column1_plus_int8,
+  column1 + 1::smallint as column1_plus_int16,
+  column1 + 1::int as column1_plus_int32,
+  column1 + 1::bigint as column1_plus_int64,
+  column1 + 1.0::float as column1_plus_float32,
+  column1 + 1.0 as column1_plus_float64
+FROM float16s;
+----
+2 2 2 2 2 2
+3 3 3 3 3 3
+4 4 4 4 4 4
+NULL NULL NULL NULL NULL NULL
+NaN NaN NaN NaN NaN NaN
+
+# Try coercing with literal NULL
+query error
+select column1 + NULL from float16s;
+----
+DataFusion error: type_coercion
+caused by
+Error during planning: Cannot automatically convert Null to Float16
+
+
+# Test coercions with equality
+query BBBBBB
+SELECT
+  column1 = 1::tinyint as column1_equals_int8,
+  column1 = 1::smallint as column1_equals_int16,
+  column1 = 1::int as column1_equals_int32,
+  column1 = 1::bigint as column1_equals_int64,
+  column1 = 1.0::float as column1_equals_float32,
+  column1 = 1.0 as column1_equals_float64
+FROM float16s;
+----
+true true true true true true
+false false false false false false
+false false false false false false
+NULL NULL NULL NULL NULL NULL
+false false false false false false
+
+
+# Try coercing with literal NULL
+query error
+select column1 = NULL from float16s;
+----
+DataFusion error: Error during planning: Cannot infer common argument type for comparison operation Float16 = Null
+
+
+# Cleanup
+statement ok
+drop table floats;
+
+statement ok
+drop table float16s;
diff --git a/datafusion/sqllogictest/test_files/imdb.slt b/datafusion/sqllogictest/test_files/imdb.slt
new file mode 100644
index 000000000000..c17f9c47c745
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/imdb.slt
@@ -0,0 +1,4040 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This file contains IMDB test queries against a small sample dataset.
+# The test creates tables with sample data and runs all the IMDB benchmark queries.
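For orientation, the "IMDB benchmark queries" mentioned in the header above are presumably the Join Order Benchmark (JOB) queries, which join these tables through their id / foreign-key columns and filter on the small dictionary tables (company_type.kind, info_type.info, and so on). The sketch below shows that query shape over the schema defined in this file, in the style of JOB q1a; it is only an illustration and is not copied from the test file itself.

# Illustrative sketch (not part of the test file): a JOB q1a-style query
# over the tables created below, filtering on company_type and info_type
# and joining through movie_companies and movie_info_idx back to title.
SELECT MIN(mc.note)           AS production_note,
       MIN(t.title)           AS movie_title,
       MIN(t.production_year) AS movie_year
FROM company_type ct,
     info_type it,
     movie_companies mc,
     movie_info_idx mi_idx,
     title t
WHERE ct.kind = 'production companies'
  AND it.info = 'top 250 rank'
  AND mc.note NOT LIKE '%(as Metro-Goldwyn-Mayer Pictures)%'
  AND (mc.note LIKE '%(co-production)%' OR mc.note LIKE '%(presents)%')
  AND ct.id = mc.company_type_id
  AND t.id = mc.movie_id
  AND t.id = mi_idx.movie_id
  AND mi_idx.info_type_id = it.id;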
+ +# company_type table +statement ok +CREATE TABLE company_type ( + id INT NOT NULL, + kind VARCHAR NOT NULL +); + +statement ok +INSERT INTO company_type VALUES + (1, 'production companies'), + (2, 'distributors'), + (3, 'special effects companies'), + (4, 'other companies'), + (5, 'miscellaneous companies'), + (6, 'film distributors'), + (7, 'theaters'), + (8, 'sales companies'), + (9, 'producers'), + (10, 'publishers'), + (11, 'visual effects companies'), + (12, 'makeup departments'), + (13, 'costume designers'), + (14, 'movie studios'), + (15, 'sound departments'), + (16, 'talent agencies'), + (17, 'casting companies'), + (18, 'film commissions'), + (19, 'production services'), + (20, 'digital effects studios'); + +# info_type table +statement ok +CREATE TABLE info_type ( + id INT NOT NULL, + info VARCHAR NOT NULL +); + +statement ok +INSERT INTO info_type VALUES + (1, 'runtimes'), + (2, 'color info'), + (3, 'genres'), + (4, 'languages'), + (5, 'certificates'), + (6, 'sound mix'), + (7, 'countries'), + (8, 'top 250 rank'), + (9, 'bottom 10 rank'), + (10, 'release dates'), + (11, 'filming locations'), + (12, 'production companies'), + (13, 'technical info'), + (14, 'trivia'), + (15, 'goofs'), + (16, 'martial-arts'), + (17, 'quotes'), + (18, 'movie connections'), + (19, 'plot description'), + (20, 'biography'), + (21, 'plot summary'), + (22, 'box office'), + (23, 'ratings'), + (24, 'taglines'), + (25, 'keywords'), + (26, 'soundtrack'), + (27, 'votes'), + (28, 'height'), + (30, 'mini biography'), + (31, 'budget'), + (32, 'rating'); + +# title table +statement ok +CREATE TABLE title ( + id INT NOT NULL, + title VARCHAR NOT NULL, + imdb_index VARCHAR, + kind_id INT NOT NULL, + production_year INT, + imdb_id INT, + phonetic_code VARCHAR, + episode_of_id INT, + season_nr INT, + episode_nr INT, + series_years VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO title VALUES + (1, 'The Shawshank Redemption', NULL, 1, 1994, 111161, NULL, NULL, NULL, NULL, NULL, NULL), + (2, 'The Godfather', NULL, 1, 1985, 68646, NULL, NULL, NULL, NULL, NULL, NULL), + (3, 'The Dark Knight', NULL, 1, 2008, 468569, NULL, NULL, NULL, NULL, NULL, NULL), + (4, 'The Godfather Part II', NULL, 1, 2012, 71562, NULL, NULL, NULL, NULL, NULL, NULL), + (5, 'Pulp Fiction', NULL, 1, 1994, 110912, NULL, NULL, NULL, NULL, NULL, NULL), + (6, 'Schindler''s List', NULL, 1, 1993, 108052, NULL, NULL, NULL, NULL, NULL, NULL), + (7, 'The Lord of the Rings: The Return of the King', NULL, 1, 2003, 167260, NULL, NULL, NULL, NULL, NULL, NULL), + (8, '12 Angry Men', NULL, 1, 1957, 50083, NULL, NULL, NULL, NULL, NULL, NULL), + (9, 'Inception', NULL, 1, 2010, 1375666, NULL, NULL, NULL, NULL, NULL, NULL), + (10, 'Fight Club', NULL, 1, 1999, 137523, NULL, NULL, NULL, NULL, NULL, NULL), + (11, 'The Matrix', NULL, 1, 2014, 133093, NULL, NULL, NULL, NULL, NULL, NULL), + (12, 'Goodfellas', NULL, 1, 1990, 99685, NULL, NULL, NULL, NULL, NULL, NULL), + (13, 'Avengers: Endgame', NULL, 1, 2019, 4154796, NULL, NULL, NULL, NULL, NULL, NULL), + (14, 'Interstellar', NULL, 1, 2014, 816692, NULL, NULL, NULL, NULL, NULL, NULL), + (15, 'The Silence of the Lambs', NULL, 1, 1991, 102926, NULL, NULL, NULL, NULL, NULL, NULL), + (16, 'Saving Private Ryan', NULL, 1, 1998, 120815, NULL, NULL, NULL, NULL, NULL, NULL), + (17, 'The Green Mile', NULL, 1, 1999, 120689, NULL, NULL, NULL, NULL, NULL, NULL), + (18, 'Forrest Gump', NULL, 1, 1994, 109830, NULL, NULL, NULL, NULL, NULL, NULL), + (19, 'Joker', NULL, 1, 2019, 7286456, NULL, NULL, NULL, NULL, NULL, NULL), + 
(20, 'Parasite', NULL, 1, 2019, 6751668, NULL, NULL, NULL, NULL, NULL, NULL), + (21, 'The Iron Giant', NULL, 1, 1999, 129167, NULL, NULL, NULL, NULL, NULL, NULL), + (22, 'Spider-Man: Into the Spider-Verse', NULL, 1, 2018, 4633694, NULL, NULL, NULL, NULL, NULL, NULL), + (23, 'Iron Man', NULL, 1, 2008, 371746, NULL, NULL, NULL, NULL, NULL, NULL), + (24, 'Black Panther', NULL, 1, 2018, 1825683, NULL, NULL, NULL, NULL, NULL, NULL), + (25, 'Titanic', NULL, 1, 1997, 120338, NULL, NULL, NULL, NULL, NULL, NULL), + (26, 'Kung Fu Panda 2', NULL, 1, 2011, 0441773, NULL, NULL, NULL, NULL, NULL, NULL), + (27, 'Halloween', NULL, 1, 2008, 1311067, NULL, NULL, NULL, NULL, NULL, NULL), + (28, 'Breaking Bad', NULL, 2, 2003, 903254, NULL, NULL, NULL, NULL, NULL, NULL), + (29, 'Breaking Bad: The Final Season', NULL, 2, 2007, 903255, NULL, NULL, NULL, NULL, NULL, NULL), + (30, 'Amsterdam Detective', NULL, 2, 2005, 905001, NULL, NULL, NULL, NULL, NULL, NULL), + (31, 'Amsterdam Detective: Cold Case', NULL, 2, 2007, 905002, NULL, NULL, NULL, NULL, NULL, NULL), + (32, 'Saw IV', NULL, 1, 2007, 905003, NULL, NULL, NULL, NULL, NULL, NULL), + (33, 'Shrek 2', NULL, 1, 2004, 906001, NULL, NULL, NULL, NULL, NULL, NULL), + (35, 'Dark Blood', NULL, 1, 2005, 907001, NULL, NULL, NULL, NULL, NULL, NULL), + (36, 'The Nordic Murders', NULL, 1, 2008, 908002, NULL, NULL, NULL, NULL, NULL, NULL), + (37, 'Scandinavian Crime', NULL, 1, 2009, 909001, NULL, NULL, NULL, NULL, NULL, NULL), + (38, 'The Western Sequel', NULL, 1, 1998, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (39, 'Marvel Superhero Epic', NULL, 1, 2010, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (40, 'The Champion', NULL, 1, 2016, 999555, NULL, NULL, NULL, NULL, NULL, NULL), + (41, 'Champion Boxer', NULL, 1, 2018, 999556, NULL, NULL, NULL, NULL, NULL, NULL), + (42, 'Avatar', NULL, 5, 2010, 499549, NULL, NULL, NULL, NULL, NULL, NULL), + (43, 'The Godfather Connection', NULL, 1, 1985, 68647, NULL, NULL, NULL, NULL, NULL, NULL), + (44, 'Digital Connection', NULL, 1, 2005, 888999, NULL, NULL, NULL, NULL, NULL, NULL), + (45, 'Berlin Noir', NULL, 1, 2010, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (46, 'YouTube Documentary', NULL, 1, 2008, 777999, NULL, NULL, NULL, NULL, NULL, NULL), + (47, 'The Swedish Murder Case', NULL, 1, 2012, 666777, NULL, NULL, NULL, NULL, NULL, NULL), + (48, 'Nordic Noir', NULL, 1, 2015, 555666, NULL, NULL, NULL, NULL, NULL, NULL), + (49, 'Derek Jacobi Story', NULL, 1, 1982, 444555, NULL, NULL, NULL, NULL, NULL, NULL), + (50, 'Woman in Black', NULL, 1, 2010, 987654, NULL, NULL, NULL, NULL, NULL, NULL), + (51, 'Kung Fu Panda', NULL, 1, 2008, 441772, NULL, NULL, NULL, NULL, NULL, NULL), + (52, 'Bruno', NULL, 1, 2009, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (53, 'Character Series', NULL, 2, 2020, 999888, NULL, NULL, NULL, 55, NULL, NULL), + (54, 'Vampire Chronicles', NULL, 1, 2015, 999999, NULL, NULL, NULL, NULL, NULL, NULL), + (55, 'Alien Invasion', NULL, 1, 2020, 888888, NULL, NULL, NULL, NULL, NULL, NULL), + (56, 'Dragon Warriors', NULL, 1, 2015, 888889, NULL, NULL, NULL, NULL, NULL, NULL), + (57, 'One Piece: Grand Adventure', NULL, 1, 2007, 777777, NULL, NULL, NULL, NULL, NULL, NULL), + (58, 'Moscow Nights', NULL, 1, 2010, 777778, NULL, NULL, NULL, NULL, NULL, NULL), + (59, 'Money Talks', NULL, 1, 1998, 888888, NULL, NULL, NULL, NULL, NULL, NULL), + (60, 'Fox Novel Movie', NULL, 1, 2005, 777888, NULL, NULL, NULL, NULL, NULL, NULL), + (61, 'Bad Movie Sequel', NULL, 1, 2010, 888777, NULL, NULL, NULL, NULL, NULL, NULL); + +# movie_companies 
table +statement ok +CREATE TABLE movie_companies ( + id INT NOT NULL, + movie_id INT NOT NULL, + company_id INT NOT NULL, + company_type_id INT NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO movie_companies VALUES + (1, 1, 4, 1, '(presents) (co-production)'), + (2, 2, 5, 1, '(presents)'), + (3, 3, 6, 1, '(co-production)'), + (4, 4, 7, 1, '(as Metro-Goldwyn-Mayer Pictures)'), + (5, 5, 8, 1, '(presents) (co-production)'), + (6, 6, 9, 1, '(presents)'), + (7, 7, 10, 1, '(co-production)'), + (8, 8, 11, 2, '(distributor)'), + (9, 9, 12, 1, '(presents) (co-production)'), + (10, 10, 13, 1, '(presents)'), + (11, 11, 14, 1, '(presents) (co-production)'), + (12, 12, 15, 1, '(presents)'), + (13, 13, 16, 1, '(co-production)'), + (14, 14, 17, 1, '(presents)'), + (15, 15, 18, 1, '(co-production)'), + (16, 16, 19, 1, '(presents)'), + (17, 17, 20, 1, '(co-production)'), + (18, 18, 21, 1, '(presents)'), + (19, 19, 22, 1, '(co-production)'), + (20, 20, 23, 1, '(presents)'), + (21, 21, 24, 1, '(presents) (co-production)'), + (22, 22, 25, 1, '(presents)'), + (23, 23, 26, 1, '(co-production)'), + (24, 24, 27, 1, '(presents)'), + (25, 25, 28, 1, '(presents) (co-production)'), + (26, 3, 35, 1, '(as Warner Bros. Pictures)'), + (27, 9, 35, 1, '(as Warner Bros. Pictures)'), + (28, 23, 14, 1, '(as Marvel Studios)'), + (29, 24, 14, 1, '(as Marvel Studios)'), + (30, 13, 14, 1, '(as Marvel Studios)'), + (31, 26, 23, 1, '(as DreamWorks Animation)'), + (32, 3, 6, 2, '(distributor)'), + (33, 2, 8, 2, '(distributor)'), + (34, 3, 6, 1, '(as Warner Bros.) (2008) (USA) (worldwide)'), + (35, 44, 36, 1, NULL), + (36, 40, 9, 1, '(production) (USA) (2016)'), + (37, 56, 18, 1, '(production)'), + (38, 2, 6, 1, NULL), + (39, 13, 14, 2, '(as Marvel Studios)'), + (40, 19, 25, 1, '(co-production)'), + (41, 23, 26, 1, '(co-production)'), + (42, 19, 27, 1, '(co-production)'), + (43, 11, 18, 1, '(theatrical) (France)'), + (44, 11, 8, 1, '(VHS) (USA) (1994)'), + (45, 11, 4, 1, '(USA)'), + (46, 9, 28, 1, '(co-production)'), + (47, 28, 5, 1, '(production)'), + (48, 29, 5, 1, '(production)'), + (49, 30, 29, 1, '(production)'), + (50, 31, 30, 1, '(production)'), + (51, 27, 22, 1, '(production)'), + (52, 32, 22, 1, '(distribution) (Blu-ray)'), + (53, 33, 31, 1, '(production)'), + (54, 33, 31, 2, '(distribution)'), + (55, 35, 32, 1, NULL), + (56, 36, 33, 1, '(production) (2008)'), + (57, 37, 34, 1, '(production) (2009) (Norway)'), + (58, 38, 35, 1, NULL), + (59, 25, 9, 1, '(production)'), + (60, 52, 19, 1, NULL), + (61, 26, 37, 1, '(voice: English version)'), + (62, 21, 3, 1, '(production) (Japan) (anime)'), + (63, 57, 2, 1, '(production) (Japan) (2007) (anime)'), + (64, 58, 1, 1, '(production) (Russia) (2010)'), + (65, 59, 35, 1, NULL), + (66, 60, 13, 2, '(distribution) (DVD) (US)'), + (67, 61, 14, 1, '(production)'), + (68, 41, 9, 1, '(production) (USA) (2018)'), + (69, 46, 16, 1, '(production) (2008) (worldwide)'), + (70, 51, 31, 1, '(production) (2008) (USA) (worldwide)'), + (71, 45, 32, 1, 'Studio (2000) Berlin'), + (72, 53, 6, 1, '(production) (2020) (USA)'), + (73, 62, 9, 1, '(production) (USA) (2010) (worldwide)'); + +# movie_info_idx table +statement ok +CREATE TABLE movie_info_idx ( + id INT NOT NULL, + movie_id INT NOT NULL, + info_type_id INT NOT NULL, + info VARCHAR NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO movie_info_idx VALUES + (1, 1, 8, '1', NULL), + (2, 2, 8, '2', NULL), + (3, 3, 8, '3', NULL), + (4, 4, 8, '4', NULL), + (5, 5, 8, '5', NULL), + (6, 6, 8, '6', NULL), + (7, 7, 8, '7', NULL), + (8, 8, 8, 
'8', NULL), + (9, 9, 8, '9', NULL), + (10, 10, 8, '10', NULL), + (11, 11, 8, '11', NULL), + (12, 12, 8, '12', NULL), + (13, 13, 8, '13', NULL), + (14, 14, 8, '14', NULL), + (15, 15, 8, '15', NULL), + (16, 16, 8, '16', NULL), + (17, 17, 8, '17', NULL), + (18, 18, 8, '18', NULL), + (19, 19, 8, '19', NULL), + (20, 20, 8, '20', NULL), + (21, 21, 8, '21', NULL), + (22, 22, 8, '22', NULL), + (23, 23, 8, '23', NULL), + (24, 24, 8, '24', NULL), + (25, 25, 8, '25', NULL), + (26, 40, 32, '8.6', NULL), + (27, 41, 32, '7.5', NULL), + (28, 45, 32, '6.8', NULL), + (29, 45, 22, '$10,000,000', NULL), + (30, 1, 22, '9.3', NULL), + (31, 2, 22, '9.2', NULL), + (32, 1, 27, '2,345,678', NULL), + (33, 3, 22, '9.0', NULL), + (34, 9, 22, '8.8', NULL), + (35, 23, 22, '8.5', NULL), + (36, 20, 9, '1', NULL), + (37, 25, 9, '2', NULL), + (38, 3, 9, '10', NULL), + (39, 28, 32, '8.2', NULL), + (40, 29, 32, '2.8', NULL), + (41, 30, 32, '8.5', NULL), + (42, 31, 32, '2.5', NULL), + (43, 27, 27, '45000', NULL), + (44, 32, 27, '52000', NULL), + (45, 33, 27, '120000', NULL), + (46, 35, 32, '7.2', NULL), + (47, 36, 32, '7.8', NULL), + (48, 37, 32, '7.5', NULL), + (49, 37, 27, '100000', NULL), + (50, 39, 32, '8.5', NULL), + (51, 54, 27, '1000', NULL), + (52, 3, 3002, '500', NULL), + (53, 3, 999, '9.5', NULL), + (54, 4, 999, '9.1', NULL), + (55, 13, 999, '8.9', NULL), + (56, 3, 32, '9.5', NULL), + (57, 4, 32, '9.1', NULL), + (58, 13, 32, '8.9', NULL), + (59, 4, 32, '9.3', NULL), + (60, 61, 9, '3', NULL), + (61, 35, 22, '8.4', NULL), + (62, 50, 32, '8.5', NULL), + (63, 48, 32, '7.5', NULL), + (64, 48, 27, '85000', NULL), + (65, 47, 32, '7.8', NULL), + (66, 46, 3, 'Documentary', NULL), + (67, 46, 10, 'USA: 2008-05-15', 'internet release'); + +# movie_info table +statement ok +CREATE TABLE movie_info ( + id INT NOT NULL, + movie_id INT NOT NULL, + info_type_id INT NOT NULL, + info VARCHAR NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO movie_info VALUES + (1, 1, 1, '113', NULL), + (2, 4, 7, 'Germany', NULL), + (3, 3, 7, 'Bulgaria', NULL), + (4, 2, 1, '175', NULL), + (5, 3, 1, '152', NULL), + (6, 4, 1, '202', NULL), + (7, 5, 1, '154', NULL), + (8, 6, 1, '195', NULL), + (9, 7, 1, '201', NULL), + (10, 8, 1, '139', NULL), + (11, 9, 1, '148', NULL), + (12, 10, 1, '139', NULL), + (13, 11, 1, '136', NULL), + (14, 12, 1, '146', NULL), + (15, 13, 1, '181', NULL), + (16, 14, 1, '141', NULL), + (17, 15, 1, '159', NULL), + (18, 16, 1, '150', NULL), + (19, 17, 1, '156', NULL), + (20, 18, 1, '164', NULL), + (21, 19, 1, '122', NULL), + (22, 20, 1, '140', NULL), + (23, 40, 1, '125', NULL), + (24, 21, 1, '86', NULL), + (25, 22, 1, '117', NULL), + (26, 23, 1, '126', NULL), + (27, 24, 1, '134', NULL), + (28, 25, 1, '194', NULL), + (29, 1, 10, '1994-10-14', 'internet release'), + (30, 2, 10, '1972-03-24', 'internet release'), + (31, 3, 10, '2008-07-18', 'internet release'), + (32, 9, 10, '2010-07-16', 'internet release'), + (33, 13, 10, '2019-04-26', 'internet release'), + (34, 23, 10, '2008-05-02', 'internet release'), + (35, 24, 10, '2018-02-16', 'internet release'), + (36, 1, 2, 'Color', NULL), + (37, 3, 2, 'Color', NULL), + (38, 8, 2, 'Black and White', NULL), + (39, 9, 2, 'Color', NULL), + (40, 1, 19, 'Story about hope and redemption', NULL), + (41, 3, 19, 'Batman faces his greatest challenge', NULL), + (42, 19, 19, 'Origin story of the Batman villain', NULL), + (43, 1, 3, 'Drama', NULL), + (44, 3, 3, 'Action', NULL), + (45, 3, 3, 'Crime', NULL), + (46, 3, 3, 'Drama', NULL), + (47, 9, 3, 'Action', NULL), + (48, 9, 3, 'Adventure', NULL), 
+ (49, 9, 3, 'Sci-Fi', NULL), + (50, 23, 3, 'Action', NULL), + (51, 23, 3, 'Adventure', NULL), + (52, 23, 3, 'Sci-Fi', NULL), + (53, 24, 3, 'Action', NULL), + (54, 24, 3, 'Adventure', NULL), + (55, 9, 7, 'Germany', NULL), + (56, 19, 7, 'German', NULL), + (57, 24, 7, 'Germany', NULL), + (58, 13, 7, 'USA', NULL), + (59, 3, 7, 'USA', NULL), + (60, 3, 22, '2343110', NULL), + (61, 3, 27, '2343110', NULL), + (62, 26, 10, 'USA:2011-05-26', NULL), + (63, 19, 20, 'Batman faces his greatest challenge', NULL), + (64, 3, 3, 'Drama', NULL), + (65, 13, 3, 'Action', NULL), + (66, 13, 19, 'Epic conclusion to the Infinity Saga', NULL), + (67, 2, 8, '1972-03-24', 'Released via internet in 2001'), + (68, 13, 4, 'English', NULL), + (69, 13, 3, 'Animation', NULL), + (70, 26, 3, 'Animation', NULL), + (71, 27, 3, '$15 million', NULL), + (72, 27, 3, 'Horror', NULL), + (73, 32, 3, 'Horror', NULL), + (74, 33, 10, 'USA: 2004', NULL), + (75, 33, 3, 'Animation', NULL), + (76, 35, 7, 'Germany', NULL), + (77, 35, 10, '2005-09-15', NULL), + (78, 44, 10, 'USA: 15 May 2005', 'This movie explores internet culture and digital connections that emerged in the early 2000s.'), + (79, 40, 10, '2016-08-12', 'internet release'), + (80, 1, 31, '$25,000,000', NULL), + (81, 45, 7, 'Germany', NULL), + (82, 45, 32, 'Germany', NULL), + (83, 13, 32, '8.5', NULL), + (84, 3, 32, '9.2', NULL), + (85, 3, 102, '9.2', NULL), + (86, 3, 25, 'sequel', NULL), + (87, 3, 102, '9.2', NULL), + (88, 3, 102, '9.2', NULL), + (89, 4, 102, '9.5', NULL), + (90, 33, 102, '8.7', NULL), + (91, 4, 32, '9.5', NULL), + (92, 11, 32, '8.7', NULL), + (93, 3, 32, '9.2', NULL), + (94, 3, 102, '9.2', NULL), + (95, 3, 32, '9.0', NULL), + (96, 26, 32, '8.2', NULL), + (97, 26, 32, '8.5', NULL), + (98, 27, 27, '8231', NULL), + (99, 27, 10, '2008-10-31', NULL), + (100, 13, 1, '182', NULL), + (101, 11, 2, 'Germany', NULL), + (102, 11, 1, '120', NULL), + (103, 3, 3, 'Drama', NULL), + (104, 11, 7, 'USA', NULL), + (105, 11, 7, 'Bulgaria', NULL), + (106, 50, 3, 'Horror', NULL), + (107, 36, 7, 'Sweden', NULL), + (108, 37, 7, 'Norway', NULL), + (109, 38, 7, 'Sweden', NULL), + (110, 54, 3, 'Horror', NULL), + (111, 55, 3, 'Sci-Fi', NULL), + (112, 56, 30, 'Japan:2015-06-15', NULL), + (113, 56, 30, 'USA:2015-07-20', NULL), + (114, 26, 10, 'Japan:2011-05-29', NULL), + (115, 26, 10, 'USA:2011-05-26', NULL), + (116, 61, 31, '$500,000', NULL), + (117, 41, 10, '2018-05-25', 'USA theatrical release'), + (118, 41, 7, 'Germany', 'Filmed on location'), + (119, 48, 7, 'Sweden', 'Filmed on location'), + (120, 48, 10, '2015-06-15', 'theatrical release'), + (121, 48, 3, 'Thriller', NULL), + (122, 47, 7, 'Sweden', 'Principal filming location'), + (123, 47, 10, '2012-09-21', 'theatrical release'), + (124, 47, 3, 'Crime', NULL), + (125, 47, 3, 'Thriller', NULL), + (126, 47, 7, 'Sweden', NULL), + (127, 3, 10, 'USA: 2008-07-14', 'internet release'), + (128, 46, 10, 'USA: 2008-05-15', 'internet release'), + (129, 40, 10, 'USA:\ 2006', 'internet release'), + (130, 51, 10, 'USA: 2008-06-06', 'theatrical release'), + (131, 51, 10, 'Japan: 2007-12-20', 'preview screening'); + +# kind_type table +statement ok +CREATE TABLE kind_type ( + id INT NOT NULL, + kind VARCHAR NOT NULL +); + +statement ok +INSERT INTO kind_type VALUES + (1, 'movie'), + (2, 'tv series'), + (3, 'video movie'), + (4, 'tv movie'), + (5, 'video game'), + (6, 'episode'), + (7, 'documentary'), + (8, 'short movie'), + (9, 'tv mini series'), + (10, 'reality-tv'); + +# cast_info table +statement ok +CREATE TABLE cast_info ( + id INT NOT NULL, + 
person_id INT NOT NULL, + movie_id INT NOT NULL, + person_role_id INT, + note VARCHAR, + nr_order INT, + role_id INT NOT NULL +); + +statement ok +INSERT INTO cast_info VALUES + (1, 29, 53, NULL, NULL, 1, 1), + (2, 3, 1, 54, NULL, 1, 1), + (3, 3, 1, NULL, '(producer)', 1, 3), + (4, 4, 2, 2, NULL, 1, 1), + (5, 5, 3, 3, NULL, 1, 1), + (6, 6, 4, 4, NULL, 1, 1), + (7, 2, 50, NULL, '(writer)', 1, 4), + (8, 18, 51, 15, '(voice)', 1, 2), + (9, 1, 19, NULL, NULL, 1, 1), + (10, 6, 100, 1985, '(as Special Actor)', 1, 1), + (11, 15, 19, NULL, NULL, 1, 1), + (12, 8, 5, 5, NULL, 1, 1), + (13, 9, 6, 6, NULL, 1, 1), + (14, 10, 7, 7, NULL, 1, 1), + (15, 11, 8, 8, NULL, 1, 1), + (16, 12, 9, 9, NULL, 1, 1), + (17, 13, 10, 10, NULL, 1, 1), + (18, 14, 9, 55, NULL, 1, 1), + (19, 14, 14, 29, NULL, 1, 1), + (20, 27, 58, 28, '(producer)', 1, 1), + (21, 16, 3, 23, '(producer)', 2, 1), + (22, 20, 49, NULL, NULL, 1, 1), + (23, 13, 23, 14, NULL, 1, 1), + (24, 28, 13, NULL, '(costume design)', 1, 7), + (25, 25, 58, 31, '(voice) (uncredited)', 1, 1), + (26, 18, 3, 24, '(voice)', 1, 2), + (27, 29, 26, 24, '(voice)', 1, 2), + (28, 13, 13, 47, '(writer)', 1, 1), + (29, 17, 3, 25, '(producer)', 3, 8), + (30, 18, 3, 11, '(voice)', 1, 2), + (31, 18, 26, 11, '(voice)', 1, 2), + (32, 18, 26, 12, '(voice: original film)', 1, 2), + (33, 22, 27, 12, '(writer)', 4, 8), + (34, 23, 32, 12, '(writer)', 4, 8), + (35, 21, 33, 13, '(voice)', 2, 2), + (36, 21, 33, 13, '(voice: English version)', 2, 2), + (37, 21, 33, 13, '(voice) (uncredited)', 2, 2), + (38, 22, 39, 25, 'Superman', 1, 1), + (39, 22, 39, 26, 'Ironman', 1, 1), + (40, 22, 39, 27, 'Spiderman', 1, 1), + (41, 19, 52, NULL, NULL, 2, 1), + (42, 14, 19, NULL, NULL, 3, 1), + (43, 6, 2, 2, NULL, 1, 1), + (44, 16, 54, NULL, '(writer)', 1, 4), + (45, 24, 55, NULL, '(director)', 1, 8), + (46, 25, 56, 29, '(voice: English version)', 1, 2), + (47, 18, 26, 30, '(voice: English version)', 1, 2), + (48, 26, 21, 24, '(voice: English version)', 1, 2), + (49, 26, 57, 25, '(voice: English version)', 1, 2), + (50, 27, 25, NULL, NULL, 1, 4), + (51, 18, 62, 32, '(voice)', 1, 2); + +# char_name table +statement ok +CREATE TABLE char_name ( + id INT NOT NULL, + name VARCHAR NOT NULL, + imdb_index VARCHAR, + imdb_id INT, + name_pcode_nf VARCHAR, + surname_pcode VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO char_name VALUES + (1, 'Andy Dufresne', NULL, NULL, NULL, NULL, NULL), + (2, 'Don Vito Corleone', NULL, NULL, NULL, NULL, NULL), + (3, 'Joker', NULL, NULL, NULL, NULL, NULL), + (4, 'Michael Corleone', NULL, NULL, NULL, NULL, NULL), + (5, 'Vincent Vega', NULL, NULL, NULL, NULL, NULL), + (6, 'Oskar Schindler', NULL, NULL, NULL, NULL, NULL), + (7, 'Gandalf', NULL, NULL, NULL, NULL, NULL), + (8, 'Juror 8', NULL, NULL, NULL, NULL, NULL), + (9, 'Cobb', NULL, NULL, NULL, NULL, NULL), + (10, 'Tyler Durden', NULL, NULL, NULL, NULL, NULL), + (11, 'Batman''s Assistant', NULL, NULL, NULL, NULL, NULL), + (12, 'Tiger', NULL, NULL, NULL, NULL, NULL), + (13, 'Queen', NULL, NULL, NULL, NULL, NULL), + (14, 'Iron Man', NULL, NULL, NULL, NULL, NULL), + (15, 'Master Tigress', NULL, NULL, NULL, NULL, NULL), + (16, 'Dom Cobb', NULL, NULL, NULL, NULL, NULL), + (17, 'Rachel Dawes', NULL, NULL, NULL, NULL, NULL), + (18, 'Arthur Fleck', NULL, NULL, NULL, NULL, NULL), + (19, 'Pepper Potts', NULL, NULL, NULL, NULL, NULL), + (20, 'T''Challa', NULL, NULL, NULL, NULL, NULL), + (21, 'Steve Rogers', NULL, NULL, NULL, NULL, NULL), + (22, 'Ellis Boyd Redding', NULL, NULL, NULL, NULL, NULL), + (23, 'Bruce Wayne', NULL, 
NULL, NULL, NULL, NULL), + (24, 'Tigress', NULL, NULL, NULL, NULL, NULL), + (25, 'Superman', NULL, NULL, NULL, NULL, NULL), + (26, 'Ironman', NULL, NULL, NULL, NULL, NULL), + (27, 'Spiderman', NULL, NULL, NULL, NULL, NULL), + (28, 'Director', NULL, NULL, NULL, NULL, NULL), + (29, 'Tiger Warrior', NULL, NULL, NULL, NULL, NULL), + (30, 'Tigress', NULL, NULL, NULL, NULL, NULL), + (31, 'Nikolai', NULL, NULL, NULL, NULL, NULL), + (32, 'Princess Dragon', NULL, NULL, NULL, NULL, NULL); + +# keyword table +statement ok +CREATE TABLE keyword ( + id INT NOT NULL, + keyword VARCHAR NOT NULL, + phonetic_code VARCHAR +); + +statement ok +INSERT INTO keyword VALUES + (1, 'prison', NULL), + (2, 'mafia', NULL), + (3, 'superhero', NULL), + (4, 'sequel', NULL), + (5, 'crime', NULL), + (6, 'holocaust', NULL), + (7, 'fantasy', NULL), + (8, 'jury', NULL), + (9, 'dream', NULL), + (10, 'fight', NULL), + (11, 'marvel-cinematic-universe', NULL), + (12, 'character-name-in-title', NULL), + (13, 'female-name-in-title', NULL), + (14, 'murder', NULL), + (15, 'noir', NULL), + (16, 'space', NULL), + (17, 'time-travel', NULL), + (18, 'artificial-intelligence', NULL), + (19, 'robot', NULL), + (20, 'alien', NULL), + (21, '10,000-mile-club', NULL), + (22, 'martial-arts', NULL), + (23, 'computer-animation', NULL), + (24, 'violence', NULL), + (25, 'based-on-novel', NULL), + (26, 'nerd', NULL), + (27, 'marvel-comics', NULL), + (28, 'based-on-comic', NULL), + (29, 'superhero-movie', NULL); + +# movie_keyword table +statement ok +CREATE TABLE movie_keyword ( + id INT NOT NULL, + movie_id INT NOT NULL, + keyword_id INT NOT NULL +); + +statement ok +INSERT INTO movie_keyword VALUES + (1, 1, 1), + (2, 2, 2), + (3, 3, 3), + (4, 4, 4), + (5, 5, 5), + (6, 6, 6), + (7, 7, 7), + (8, 8, 8), + (9, 9, 9), + (10, 10, 10), + (11, 3, 5), + (12, 19, 3), + (13, 19, 12), + (14, 23, 11), + (15, 13, 11), + (16, 24, 11), + (17, 11, 1), + (18, 11, 20), + (19, 11, 20), + (20, 14, 16), + (21, 9, 3), + (22, 3, 14), + (23, 25, 13), + (24, 23, 12), + (25, 2, 4), + (26, 23, 19), + (27, 19, 5), + (28, 23, 3), + (29, 23, 28), + (30, 3, 4), + (31, 3, 4), + (32, 2, 4), + (33, 4, 4), + (34, 11, 4), + (35, 3, 3), + (36, 26, 16), + (37, 13, 11), + (38, 13, 3), + (39, 13, 4), + (40, 9, 17), + (41, 9, 18), + (42, 3, 12), + (43, 13, 13), + (44, 26, 21), + (45, 24, 3), + (46, 9, 14), + (47, 2, 4), + (48, 14, 21), + (49, 27, 14), + (50, 32, 14), + (51, 33, 23), + (52, 33, 23), + (55, 35, 24), + (56, 36, 14), + (57, 36, 25), + (58, 35, 4), + (59, 37, 14), + (60, 37, 25), + (61, 45, 24), + (62, 2, 4), + (63, 14, 21), + (64, 27, 14), + (65, 32, 14), + (66, 33, 23), + (67, 33, 23), + (68, 35, 24), + (69, 38, 4), + (70, 39, 3), + (71, 39, 27), + (72, 39, 28), + (73, 39, 29), + (74, 44, 26), + (75, 52, 12), + (76, 54, 14), + (77, 55, 20), + (78, 55, 16), + (79, 56, 22), + (80, 26, 22), + (81, 3, 4), + (82, 4, 4), + (83, 13, 4), + (84, 3, 4), + (85, 40, 29), + (86, 4, 4), + (87, 13, 4), + (88, 59, 4), + (89, 60, 25), + (90, 48, 14), + (91, 47, 14), + (92, 45, 24), + (93, 46, 3), + (94, 53, 12); + +# company_name table +statement ok +CREATE TABLE company_name ( + id INT NOT NULL, + name VARCHAR NOT NULL, + country_code VARCHAR, + imdb_id INT, + name_pcode_nf VARCHAR, + name_pcode_sf VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO company_name VALUES + (1, 'Mosfilm', '[ru]', NULL, NULL, NULL, NULL), + (2, 'Toei Animation', '[jp]', NULL, NULL, NULL, NULL), + (3, 'Tokyo Animation Studio', '[jp]', NULL, NULL, NULL, NULL), + (4, 'Castle Rock Entertainment', '[us]', 
NULL, NULL, NULL, NULL), + (5, 'Paramount Pictures', '[us]', NULL, NULL, NULL, NULL), + (6, 'Warner Bros.', '[us]', NULL, NULL, NULL, NULL), + (7, 'Metro-Goldwyn-Mayer', '[us]', NULL, NULL, NULL, NULL), + (8, 'Miramax Films', '[us]', NULL, NULL, NULL, NULL), + (9, 'Universal Pictures', '[us]', NULL, NULL, NULL, NULL), + (10, 'New Line Cinema', '[us]', NULL, NULL, NULL, NULL), + (11, 'United Artists', '[us]', NULL, NULL, NULL, NULL), + (12, 'Columbia Pictures', '[us]', NULL, NULL, NULL, NULL), + (13, 'Twentieth Century Fox', '[us]', NULL, NULL, NULL, NULL), + (14, 'Marvel Studios', '[us]', NULL, NULL, NULL, NULL), + (15, 'DC Films', '[us]', NULL, NULL, NULL, NULL), + (16, 'YouTube', '[us]', NULL, NULL, NULL, NULL), + (17, 'DreamWorks Pictures', '[us]', NULL, NULL, NULL, NULL), + (18, 'Walt Disney Pictures', '[us]', NULL, NULL, NULL, NULL), + (19, 'Netflix', '[us]', NULL, NULL, NULL, NULL), + (20, 'Amazon Studios', '[us]', NULL, NULL, NULL, NULL), + (21, 'A24', '[us]', NULL, NULL, NULL, NULL), + (22, 'Lionsgate Films', '[us]', NULL, NULL, NULL, NULL), + (23, 'DreamWorks Animation', '[us]', NULL, NULL, NULL, NULL), + (24, 'Sony Pictures', '[us]', NULL, NULL, NULL, NULL), + (25, 'Bavaria Film', '[de]', NULL, NULL, NULL, NULL), + (26, 'Dutch FilmWorks', '[nl]', NULL, NULL, NULL, NULL), + (27, 'San Marino Films', '[sm]', NULL, NULL, NULL, NULL), + (28, 'Legendary Pictures', '[us]', NULL, NULL, NULL, NULL), + (29, 'Dutch Entertainment Group', '[nl]', NULL, NULL, NULL, NULL), + (30, 'Amsterdam Studios', '[nl]', NULL, NULL, NULL, NULL), + (31, 'DreamWorks Animation', '[us]', NULL, NULL, NULL, NULL), + (32, 'Berlin Film Studio', '[de]', NULL, NULL, NULL, NULL), + (33, 'Stockholm Productions', '[se]', NULL, NULL, NULL, NULL), + (34, 'Oslo Films', '[no]', NULL, NULL, NULL, NULL), + (35, 'Warner Bros. 
Pictures', '[us]', NULL, NULL, NULL, NULL), + (36, 'Silicon Entertainment', '[us]', NULL, NULL, NULL, NULL), + (37, 'DreamWorks Animation', '[us]', NULL, NULL, NULL, NULL); + +# name table for actors/directors information +statement ok +CREATE TABLE name ( + id INT NOT NULL, + name VARCHAR NOT NULL, + imdb_index VARCHAR, + imdb_id INT, + gender VARCHAR, + name_pcode_cf VARCHAR, + name_pcode_nf VARCHAR, + surname_pcode VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO name VALUES + (1, 'Xavier Thompson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (2, 'Susan Hill', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (3, 'Tim Robbins', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (4, 'Marlon Brando', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (5, 'Heath Ledger', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (6, 'Al Pacino', NULL, NULL, 'm', 'A', NULL, NULL, NULL), + (7, 'Downey Pacino', NULL, NULL, 'm', 'D', NULL, NULL, NULL), + (8, 'John Travolta', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (9, 'Liam Neeson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (10, 'Ian McKellen', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (11, 'Henry Fonda', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (12, 'Leonardo DiCaprio', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (13, 'Downey Robert Jr.', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (14, 'Zach Wilson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (15, 'Bert Wilson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (29, 'Alex Morgan', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (16, 'Christian Bale', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (17, 'Christopher Nolan', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (18, 'Angelina Jolie', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (19, 'Brad Wilson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (20, 'Derek Jacobi', NULL, NULL, 'm', 'D624', NULL, NULL, NULL), + (21, 'Anne Hathaway', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (22, 'John Carpenter', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (23, 'James Wan', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (24, 'Ridley Scott', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (25, 'Angelina Jolie', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (26, 'Yoko Tanaka', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (27, 'James Cameron', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (28, 'Edith Head', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (29, 'Anne Hathaway', NULL, NULL, 'f', NULL, NULL, NULL, NULL); + +# aka_name table +statement ok +CREATE TABLE aka_name ( + id INT NOT NULL, + person_id INT NOT NULL, + name VARCHAR NOT NULL, + imdb_index VARCHAR, + name_pcode_cf VARCHAR, + name_pcode_nf VARCHAR, + surname_pcode VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO aka_name VALUES + (1, 2, 'Marlon Brando Jr.', NULL, NULL, NULL, NULL, NULL), + (2, 2, 'Marlon Brando', NULL, NULL, NULL, NULL, NULL), + (3, 3, 'Heath Andrew Ledger', NULL, NULL, NULL, NULL, NULL), + (4, 6, 'Alfredo James Pacino', NULL, NULL, NULL, NULL, NULL), + (5, 5, 'John Joseph Travolta', NULL, NULL, NULL, NULL, NULL), + (6, 6, 'Liam John Neeson', NULL, NULL, NULL, NULL, NULL), + (7, 7, 'Ian Murray McKellen', NULL, NULL, NULL, NULL, NULL), + (8, 8, 'Henry Jaynes Fonda', NULL, NULL, NULL, NULL, NULL), + (9, 9, 'Leonardo Wilhelm DiCaprio', NULL, NULL, NULL, NULL, NULL), + (10, 10, 'Robert John Downey Jr.', NULL, NULL, NULL, NULL, NULL), + (11, 16, 'Christian Charles Philip Bale', NULL, NULL, NULL, NULL, NULL), + (12, 29, 'Christopher Jonathan James Nolan', NULL, NULL, NULL, NULL, NULL), + (13, 47, 'Joaquin Rafael 
Bottom', NULL, NULL, NULL, NULL, NULL), + (14, 26, 'Yoko Shimizu', NULL, NULL, NULL, NULL, NULL), + (15, 48, 'Chadwick Aaron Boseman', NULL, NULL, NULL, NULL, NULL), + (16, 29, 'Scarlett Ingrid Johansson', NULL, NULL, NULL, NULL, NULL), + (17, 31, 'Christopher Robert Evans', NULL, NULL, NULL, NULL, NULL), + (18, 32, 'Christopher Hemsworth', NULL, NULL, NULL, NULL, NULL), + (19, 33, 'Mark Alan Ruffalo', NULL, NULL, NULL, NULL, NULL), + (20, 20, 'Sir Derek Jacobi', NULL, NULL, NULL, NULL, NULL), + (21, 34, 'Samuel Leroy Jackson', NULL, NULL, NULL, NULL, NULL), + (22, 35, 'Gwyneth Kate Paltrow', NULL, NULL, NULL, NULL, NULL), + (23, 36, 'Thomas William Hiddleston', NULL, NULL, NULL, NULL, NULL), + (24, 37, 'Morgan Porterfield Freeman', NULL, NULL, NULL, NULL, NULL), + (25, 38, 'William Bradley Pitt', NULL, NULL, NULL, NULL, NULL), + (26, 39, 'Edward John Norton Jr.', NULL, NULL, NULL, NULL, NULL), + (27, 40, 'Marion Cotillard', NULL, NULL, NULL, NULL, NULL), + (28, 41, 'Joseph Leonard Gordon-Levitt', NULL, NULL, NULL, NULL, NULL), + (29, 42, 'Matthew David McConaughey', NULL, NULL, NULL, NULL, NULL), + (30, 43, 'Anne Jacqueline Hathaway', NULL, NULL, NULL, NULL, NULL), + (31, 44, 'Kevin Feige', NULL, NULL, NULL, NULL, NULL), + (32, 45, 'Margaret Ruth Gyllenhaal', NULL, NULL, NULL, NULL, NULL), + (33, 46, 'Kate Elizabeth Winslet', NULL, NULL, NULL, NULL, NULL), + (34, 28, 'E. Head', NULL, NULL, NULL, NULL, NULL), + (35, 29, 'Anne Jacqueline Hathaway', NULL, NULL, NULL, NULL, NULL), + (36, 29, 'Alexander Morgan', NULL, NULL, NULL, NULL, NULL), + (37, 2, 'Brando, M.', NULL, NULL, NULL, NULL, NULL), + (38, 21, 'Annie Hathaway', NULL, NULL, NULL, NULL, NULL), + (39, 21, 'Annie H', NULL, NULL, NULL, NULL, NULL), + (40, 25, 'Angie Jolie', NULL, NULL, NULL, NULL, NULL), + (41, 27, 'Jim Cameron', NULL, NULL, NULL, NULL, NULL), + (42, 18, 'Angelina Jolie', NULL, NULL, NULL, NULL, NULL); + +# role_type table +statement ok +CREATE TABLE role_type ( + id INT NOT NULL, + role VARCHAR NOT NULL +); + +statement ok +INSERT INTO role_type VALUES + (1, 'actor'), + (2, 'actress'), + (3, 'producer'), + (4, 'writer'), + (5, 'cinematographer'), + (6, 'composer'), + (7, 'costume designer'), + (8, 'director'), + (9, 'editor'), + (10, 'miscellaneous crew'); + +# link_type table +statement ok +CREATE TABLE link_type ( + id INT NOT NULL, + link VARCHAR NOT NULL +); + +statement ok +INSERT INTO link_type VALUES + (1, 'sequel'), + (2, 'follows'), + (3, 'remake of'), + (4, 'version of'), + (5, 'spin off from'), + (6, 'reference to'), + (7, 'featured in'), + (8, 'spoofed in'), + (9, 'edited into'), + (10, 'alternate language version of'), + (11, 'features'); + +# movie_link table +statement ok +CREATE TABLE movie_link ( + id INT NOT NULL, + movie_id INT NOT NULL, + linked_movie_id INT NOT NULL, + link_type_id INT NOT NULL +); + +statement ok +INSERT INTO movie_link VALUES + (1, 2, 4, 1), + (2, 3, 5, 6), + (3, 6, 7, 4), + (4, 8, 9, 8), + (5, 10, 1, 3), + (6, 28, 29, 1), + (7, 30, 31, 2), + (8, 1, 3, 6), + (9, 23, 13, 1), + (10, 13, 24, 2), + (11, 20, 3, 1), + (12, 3, 22, 1), + (13, 2, 4, 2), + (14, 19, 19, 6), + (15, 14, 16, 6), + (16, 13, 23, 2), + (17, 25, 9, 4), + (18, 17, 1, 8), + (19, 24, 23, 2), + (20, 21, 22, 1), + (21, 15, 9, 6), + (22, 11, 13, 1), + (23, 13, 11, 2), + (24, 100, 100, 7), + (25, 1, 2, 7), + (26, 23, 2, 7), + (27, 14, 25, 9), + (28, 4, 6, 4), + (29, 5, 8, 6), + (30, 7, 10, 6), + (31, 9, 2, 8), + (32, 38, 39, 2), + (33, 59, 5, 2), + (34, 60, 9, 2), + (35, 49, 49, 11), + (36, 35, 36, 2); + +# 
complete_cast table +statement ok +CREATE TABLE complete_cast ( + id INT NOT NULL, + movie_id INT NOT NULL, + subject_id INT NOT NULL, + status_id INT NOT NULL +); + +statement ok +INSERT INTO complete_cast VALUES + (1, 1, 1, 1), + (2, 2, 1, 1), + (3, 3, 1, 1), + (4, 4, 1, 1), + (5, 5, 1, 1), + (6, 6, 1, 1), + (7, 7, 1, 1), + (8, 8, 1, 1), + (9, 9, 1, 1), + (10, 10, 1, 1), + (11, 11, 1, 1), + (12, 12, 1, 1), + (13, 13, 1, 1), + (14, 14, 1, 1), + (15, 15, 1, 1), + (16, 16, 1, 1), + (17, 17, 1, 1), + (18, 18, 1, 1), + (19, 19, 1, 2), + (20, 20, 2, 1), + (21, 21, 1, 1), + (22, 22, 1, 1), + (23, 23, 1, 3), + (24, 24, 1, 1), + (25, 25, 1, 1), + (26, 26, 1, 1), + (27, 13, 2, 4), + (28, 44, 1, 4), + (29, 33, 1, 4), + (30, 31, 1, 1), + (31, 32, 1, 4), + (32, 33, 1, 4), + (33, 35, 2, 3), + (34, 36, 2, 3), + (35, 37, 1, 4), + (36, 37, 1, 3), + (37, 38, 1, 3), + (38, 39, 1, 3), + (39, 39, 1, 11), + (40, 40, 1, 4); + +# comp_cast_type table +statement ok +CREATE TABLE comp_cast_type ( + id INT NOT NULL, + kind VARCHAR NOT NULL +); + +statement ok +INSERT INTO comp_cast_type VALUES + (1, 'cast'), + (2, 'crew'), + (3, 'complete'), + (4, 'complete+verified'), + (5, 'pending'), + (6, 'unverified'), + (7, 'uncredited cast'), + (8, 'uncredited crew'), + (9, 'unverified cast'), + (10, 'unverified crew'), + (11, 'complete cast'); + +# person_info table +statement ok +CREATE TABLE person_info ( + id INT NOT NULL, + person_id INT NOT NULL, + info_type_id INT NOT NULL, + info VARCHAR NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO person_info VALUES + (1, 1, 3, 'actor,producer', NULL), + (2, 2, 3, 'actor,director', NULL), + (3, 3, 3, 'actor', NULL), + (4, 6, 3, 'actor,producer', NULL), + (5, 5, 3, 'actor', NULL), + (6, 6, 3, 'actor', NULL), + (7, 7, 3, 'actor', NULL), + (8, 8, 3, 'actor', NULL), + (9, 20, 30, 'Renowned Shakespearean actor and stage performer', 'Volker Boehm'), + (10, 10, 3, 'actor,producer', 'marvel-cinematic-universe'), + (11, 3, 1, 'Won Academy Award for portrayal of Joker', NULL), + (12, 10, 1, 'Played Iron Man in the Marvel Cinematic Universe', NULL), + (13, 16, 3, 'actor', NULL), + (14, 16, 1, 'Played Batman in The Dark Knight trilogy', NULL), + (15, 29, 3, 'director,producer,writer', NULL), + (16, 29, 1, 'Directed The Dark Knight trilogy', NULL), + (17, 47, 3, 'actor', NULL), + (18, 47, 1, 'Won Academy Award for portrayal of Joker', NULL), + (19, 48, 3, 'actor', NULL), + (20, 48, 1, 'Played Black Panther in the Marvel Cinematic Universe', NULL), + (21, 29, 3, 'actress', NULL), + (22, 29, 1, 'Played Black Widow in the Marvel Cinematic Universe', NULL), + (23, 31, 3, 'actor', NULL), + (24, 31, 1, 'Played Captain America in the Marvel Cinematic Universe', NULL), + (25, 32, 3, 'actor', NULL), + (26, 32, 1, 'Played Thor in the Marvel Cinematic Universe', NULL), + (27, 9, 1, 'Won Academy Award for The Revenant', NULL), + (28, 9, 7, '1974-11-11', NULL), + (29, 10, 7, '1965-04-04', NULL), + (30, 16, 7, '1974-01-30', NULL), + (31, 47, 7, '1974-10-28', NULL), + (32, 48, 7, '1976-11-29', NULL), + (33, 29, 7, '1984-11-22', NULL), + (34, 31, 7, '1981-06-13', NULL), + (35, 32, 7, '1983-08-11', NULL), + (36, 21, 14, 'Won an Oscar for Les Miserables.', 'IMDB staff'), + (37, 21, 14, 'Voiced Queen in Shrek 2.', 'IMDB staff'), + (38, 21, 28, '5 ft 8 in (1.73 m)', 'IMDB staff'), + (39, 6, 30, 'Famous for his role in The Godfather', 'Volker Boehm'); + +# aka_title table +statement ok +CREATE TABLE aka_title ( + id INT NOT NULL, + movie_id INT NOT NULL, + title VARCHAR, + imdb_index VARCHAR, + kind_id 
INT NOT NULL, + production_year INT, + phonetic_code VARCHAR, + episode_of_id INT, + season_nr INT, + episode_nr INT, + note VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO aka_title VALUES + (1, 1, 'Shawshank', NULL, 1, 1994, NULL, NULL, NULL, NULL, NULL, NULL), + (2, 2, 'Der Pate', NULL, 1, 1972, NULL, NULL, NULL, NULL, 'German title', NULL), + (3, 3, 'The Dark Knight', NULL, 1, 2008, NULL, NULL, NULL, NULL, NULL, NULL), + (4, 4, 'Der Pate II', NULL, 1, 1974, NULL, NULL, NULL, NULL, 'German title', NULL), + (5, 5, 'Pulp Fiction', NULL, 1, 1994, NULL, NULL, NULL, NULL, NULL, NULL), + (6, 6, 'La lista di Schindler', NULL, 1, 1993, NULL, NULL, NULL, NULL, 'Italian title', NULL), + (7, 7, 'LOTR: ROTK', NULL, 1, 2003, NULL, NULL, NULL, NULL, 'Abbreviated', NULL), + (8, 8, '12 Angry Men', NULL, 1, 1957, NULL, NULL, NULL, NULL, NULL, NULL), + (9, 9, 'Dream Heist', NULL, 1, 2010, NULL, NULL, NULL, NULL, 'Working title', NULL), + (10, 10, 'Fight Club', NULL, 1, 1999, NULL, NULL, NULL, NULL, NULL, NULL), + (11, 3, 'Batman: The Dark Knight', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Full title', NULL), + (12, 13, 'Avengers 4', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Abbreviated', NULL), + (13, 19, 'The Joker', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Working title', NULL), + (14, 23, 'Iron Man: Birth of a Hero', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Extended title', NULL), + (15, 24, 'Black Panther: Wakanda Forever', NULL, 1, 2018, NULL, NULL, NULL, NULL, 'Alternate title', NULL), + (16, 11, 'Avengers 3', NULL, 1, 2018, NULL, NULL, NULL, NULL, 'Abbreviated', NULL), + (17, 3, 'Batman 2', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Sequel numbering', NULL), + (18, 20, 'Batman: Year One', NULL, 1, 2005, NULL, NULL, NULL, NULL, 'Working title', NULL), + (19, 14, 'Journey to the Stars', NULL, 1, 2014, NULL, NULL, NULL, NULL, 'Working title', NULL), + (20, 25, 'Rose and Jack', NULL, 1, 1997, NULL, NULL, NULL, NULL, 'Character-based title', NULL), + (21, 19, 'Joker: A Descent Into Madness', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Extended title', NULL), + (22, 22, 'Batman 3', NULL, 1, 2012, NULL, NULL, NULL, NULL, 'Sequel numbering', NULL), + (23, 1, 'The Shawshank Redemption', NULL, 1, 1994, NULL, NULL, NULL, NULL, 'Full title', NULL), + (24, 19, 'El Joker', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Spanish title', NULL), + (25, 13, 'Los Vengadores: Endgame', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Spanish title', NULL), + (26, 19, 'The Batman', NULL, 1, 2022, NULL, NULL, NULL, NULL, 'Working title', NULL), + (27, 41, 'Champion Boxer: The Rise of a Legend', NULL, 1, 2018, NULL, NULL, NULL, NULL, 'Extended title', NULL), + (28, 47, 'The Swedish Murder Case', NULL, 1, 2012, NULL, NULL, NULL, NULL, 'Full title', NULL), + (29, 46, 'Viral Documentary', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Alternate title', NULL), + (30, 45, 'Berlin Noir', NULL, 1, 2010, 989898, NULL, NULL, NULL, NULL, NULL), + (31, 44, 'Digital Connection', NULL, 1, 2005, NULL, NULL, NULL, NULL, NULL, NULL), + (32, 62, 'Animated Feature', NULL, 1, 2010, 123456, NULL, NULL, NULL, NULL, NULL); + +# 1a - Query with production companies and top 250 rank +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'top 250 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' and (mc.note like '%(co-production)%' or 
mc.note like '%(presents)%') + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(co-production) Avengers: Endgame 1985 + +# 1b - Query with production companies and bottom 10 rank +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'bottom 10 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' + AND t.production_year between 2005 and 2010 + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(as Warner Bros. Pictures) Bad Movie Sequel 2008 + +# 1c - Query with production companies and top 250 rank (co-productions, production year after 2010) +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'top 250 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' and (mc.note like '%(co-production)%') + AND t.production_year >2010 + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(co-production) Avengers: Endgame 2014 + +# 1d - Query with production companies and bottom 10 rank (production year after 2000) +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'bottom 10 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' + AND t.production_year >2000 + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(as Warner Bros. 
Pictures) Bad Movie Sequel 2008 + +# 2a - Query with German companies and character-name-in-title +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[de]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Joker + +# 2b - Query with Dutch companies and character-name-in-title +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[nl]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Iron Man + +# 2c - Query with Slovenian companies and female name in title +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[sm]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Joker + +# 2d - Query with US companies and murder movies +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Bruno + +# 3a - Query with runtimes > 100 +query T +SELECT MIN(t.title) AS movie_title +FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE k.keyword like '%sequel%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') + AND t.production_year > 2005 + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi.movie_id + AND k.id = mk.keyword_id +---- +The Godfather Part II + +# 3b - Query with Bulgarian movies +query T +SELECT MIN(t.title) AS movie_title +FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE k.keyword like '%sequel%' + AND mi.info IN ('Bulgaria') + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi.movie_id + AND k.id = mk.keyword_id +---- +The Dark Knight + +# 3c - Query with biographies +query T +SELECT MIN(t.title) AS movie_title +FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE k.keyword like '%sequel%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND t.production_year > 1990 + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi.movie_id + AND k.id = mk.keyword_id +---- +Avengers: Endgame + +# 4a - Query with certain actor names +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title +FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it.info ='rating' + AND k.keyword like '%sequel%' + AND mi_idx.info > '5.0' + AND t.production_year > 2005 + AND t.id = mi_idx.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it.id = mi_idx.info_type_id +---- +8.9 Avengers: Endgame + +# 
4b - Query with certain actor names (revised) +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title +FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it.info ='rating' + AND k.keyword like '%sequel%' + AND mi_idx.info > '9.0' + AND t.production_year > 2000 + AND t.id = mi_idx.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it.id = mi_idx.info_type_id +---- +9.1 The Dark Knight + +# 4c - Query with actors from certain period +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title +FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it.info ='rating' + AND k.keyword like '%sequel%' + AND mi_idx.info > '2.0' + AND t.production_year > 1990 + AND t.id = mi_idx.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it.id = mi_idx.info_type_id +---- +7.2 Avengers: Endgame + +# 5a - Query with keyword and movie links +query T +SELECT MIN(t.title) AS typical_european_movie +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t +WHERE ct.kind = 'production companies' + AND mc.note like '%(theatrical)%' and mc.note like '%(France)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') + AND t.production_year > 2005 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND mc.movie_id = mi.movie_id + AND ct.id = mc.company_type_id + AND it.id = mi.info_type_id +---- +The Matrix + +# 5b - Query with keyword and directors +query T +SELECT MIN(t.title) AS american_vhs_movie +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t +WHERE ct.kind = 'production companies' + AND mc.note like '%(VHS)%' and mc.note like '%(USA)%' and mc.note like '%(1994)%' + AND mi.info IN ('USA', 'America') + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND mc.movie_id = mi.movie_id + AND ct.id = mc.company_type_id + AND it.id = mi.info_type_id +---- +The Matrix + +# 5c - Query with female leading roles +query T +SELECT MIN(t.title) AS american_movie +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t +WHERE ct.kind = 'production companies' + AND mc.note not like '%(TV)%' and mc.note like '%(USA)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND t.production_year > 1990 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND mc.movie_id = mi.movie_id + AND ct.id = mc.company_type_id + AND it.id = mi.info_type_id +---- +Champion Boxer + +# 6a - Query for Marvel movies with Robert Downey +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword = 'marvel-cinematic-universe' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2010 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +marvel-cinematic-universe Downey Robert Jr. 
Avengers: Endgame + +# 6b - Query for Robert Downey in hero-keyword movies after 2014 +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2014 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +sequel Downey Robert Jr. Avengers: Endgame + +# 6c - Query for superhero movies from specific year +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword = 'marvel-cinematic-universe' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2014 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +marvel-cinematic-universe Downey Robert Jr. Avengers: Endgame + +# 6d - Query for Robert Downey in hero-keyword movies after 2000 +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2000 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +based-on-comic Downey Robert Jr. Avengers: Endgame + +# 6e - Query for advanced superhero movies +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword = 'marvel-cinematic-universe' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2000 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +marvel-cinematic-universe Downey Robert Jr. 
Avengers: Endgame + +# 6f - Query for complex superhero movies +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND t.production_year > 2000 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +based-on-comic Al Pacino Avengers: Endgame + +# 7a - Query about character names +query TT +SELECT MIN(n.name) AS of_person, MIN(t.title) AS biography_movie +FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t +WHERE an.name LIKE '%a%' + AND it.info ='mini biography' + AND lt.link ='features' + AND n.name_pcode_cf BETWEEN 'A' + AND 'F' + AND (n.gender='m' OR (n.gender = 'f' + AND n.name LIKE 'B%')) + AND pi.note ='Volker Boehm' + AND t.production_year BETWEEN 1980 + AND 1995 + AND n.id = an.person_id + AND n.id = pi.person_id + AND ci.person_id = n.id + AND t.id = ci.movie_id + AND ml.linked_movie_id = t.id + AND lt.id = ml.link_type_id + AND it.id = pi.info_type_id + AND pi.person_id = an.person_id + AND pi.person_id = ci.person_id + AND an.person_id = ci.person_id + AND ci.movie_id = ml.linked_movie_id -- #Al Pacino The Godfather +---- +Derek Jacobi Derek Jacobi Story + +# 7b - Query for person with biography +query TT +SELECT MIN(n.name) AS of_person, MIN(t.title) AS biography_movie +FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t +WHERE an.name LIKE '%a%' + AND it.info ='mini biography' + AND lt.link ='features' + AND n.name_pcode_cf LIKE 'D%' + AND n.gender='m' + AND pi.note ='Volker Boehm' + AND t.production_year BETWEEN 1980 + AND 1984 + AND n.id = an.person_id + AND n.id = pi.person_id + AND ci.person_id = n.id + AND t.id = ci.movie_id + AND ml.linked_movie_id = t.id + AND lt.id = ml.link_type_id + AND it.id = pi.info_type_id + AND pi.person_id = an.person_id + AND pi.person_id = ci.person_id + AND an.person_id = ci.person_id + AND ci.movie_id = ml.linked_movie_id +---- +Derek Jacobi Derek Jacobi Story + +# 7c - Query for extended character names and biographies +query TT +SELECT MIN(n.name) AS cast_member_name, MIN(pi.info) AS cast_member_info +FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t +WHERE an.name is not NULL and (an.name LIKE '%a%' or an.name LIKE 'A%') + AND it.info ='mini biography' + AND lt.link in ('references', 'referenced in', 'features', 'featured in') + AND n.name_pcode_cf BETWEEN 'A' + AND 'F' + AND (n.gender='m' OR (n.gender = 'f' + AND n.name LIKE 'A%')) + AND pi.note is not NULL + AND t.production_year BETWEEN 1980 + AND 2010 + AND n.id = an.person_id + AND n.id = pi.person_id + AND ci.person_id = n.id + AND t.id = ci.movie_id + AND ml.linked_movie_id = t.id + AND lt.id = ml.link_type_id + AND it.id = pi.info_type_id + AND pi.person_id = an.person_id + AND pi.person_id = ci.person_id + AND an.person_id = ci.person_id + AND ci.movie_id = ml.linked_movie_id +---- +Al Pacino Famous for his role in The Godfather + +# 8a - Find movies by keyword +query TT +SELECT MIN(an1.name) AS actress_pseudonym, MIN(t.title) AS japanese_movie_dubbed +FROM aka_name AS an1, cast_info AS ci, 
company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t +WHERE ci.note ='(voice: English version)' + AND cn.country_code ='[jp]' + AND mc.note like '%(Japan)%' and mc.note not like '%(USA)%' + AND n1.name like '%Yo%' and n1.name not like '%Yu%' + AND rt.role ='actress' + AND an1.person_id = n1.id + AND n1.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND an1.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +Yoko Shimizu One Piece: Grand Adventure + +# 8b - Query for anime voice actors +query TT +SELECT MIN(an.name) AS acress_pseudonym, MIN(t.title) AS japanese_anime_movie +FROM aka_name AS an, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note ='(voice: English version)' + AND cn.country_code ='[jp]' + AND mc.note like '%(Japan)%' and mc.note not like '%(USA)%' and (mc.note like '%(2006)%' or mc.note like '%(2007)%') + AND n.name like '%Yo%' and n.name not like '%Yu%' + AND rt.role ='actress' + AND t.production_year between 2006 and 2007 and (t.title like 'One Piece%' or t.title like 'Dragon Ball Z%') + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +Yoko Shimizu One Piece: Grand Adventure + +# 8c - Query for extended movies by keyword and voice actors +query TT +SELECT MIN(a1.name) AS writer_pseudo_name, MIN(t.title) AS movie_title +FROM aka_name AS a1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t +WHERE cn.country_code ='[us]' + AND rt.role ='writer' + AND a1.person_id = n1.id + AND n1.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND a1.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +Jim Cameron Titanic + +# 8d - Query for specialized movies by keyword and voice actors +query TT +SELECT MIN(an1.name) AS costume_designer_pseudo, MIN(t.title) AS movie_with_costumes +FROM aka_name AS an1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t +WHERE cn.country_code ='[us]' + AND rt.role ='costume designer' + AND an1.person_id = n1.id + AND n1.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND an1.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +E. 
Head Avengers: Endgame + +# 9a - Query for movie sequels +query TTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS character_name, MIN(t.title) AS movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND mc.note is not NULL and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND n.gender ='f' and n.name like '%Ang%' + AND rt.role ='actress' + AND t.production_year between 2005 and 2015 + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND an.person_id = ci.person_id +---- +Angelina Jolie Batman's Assistant Kung Fu Panda + +# 9b - Query for voice actors in American movies +query TTTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_character, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note = '(voice)' + AND cn.country_code ='[us]' + AND mc.note like '%(200%)%' and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND n.gender ='f' and n.name like '%Angel%' + AND rt.role ='actress' + AND t.production_year between 2007 and 2010 + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND an.person_id = ci.person_id +---- +Angelina Jolie Batman's Assistant Angelina Jolie Kung Fu Panda + +# 9c - Query for extended movie sequels and voice actors +query TTTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_character_name, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND an.person_id = ci.person_id +---- +Alexander Morgan Batman's Assistant Angelina Jolie Dragon Warriors + +# 9d - Query for specialized movie sequels and voice actors +query TTTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND n.gender ='f' + AND rt.role ='actress' + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND 
an.person_id = ci.person_id +---- +Alexander Morgan Batman's Assistant Angelina Jolie Dragon Warriors + +# 10a - Query for cast combinations +query TT +SELECT MIN(chn.name) AS uncredited_voiced_character, MIN(t.title) AS russian_movie +FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t +WHERE ci.note like '%(voice)%' and ci.note like '%(uncredited)%' + AND cn.country_code = '[ru]' + AND rt.role = 'actor' + AND t.production_year > 2005 + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mc.movie_id + AND chn.id = ci.person_role_id + AND rt.id = ci.role_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Nikolai Moscow Nights + +# 10b - Query for Russian movie producers who are also actors +query TT +SELECT MIN(chn.name) AS character, MIN(t.title) AS russian_mov_with_actor_producer +FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t +WHERE ci.note like '%(producer)%' + AND cn.country_code = '[ru]' + AND rt.role = 'actor' + AND t.production_year > 2000 + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mc.movie_id + AND chn.id = ci.person_role_id + AND rt.id = ci.role_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Director Moscow Nights + +# 10c - Query for American producers in movies +query TT +SELECT MIN(chn.name) AS character, MIN(t.title) AS movie_with_american_producer +FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t +WHERE ci.note like '%(producer)%' + AND cn.country_code = '[us]' + AND t.production_year > 1990 + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mc.movie_id + AND chn.id = ci.person_role_id + AND rt.id = ci.role_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Bruce Wayne The Dark Knight + +# 11a - Query for non-Polish companies with sequels +query TTT +SELECT MIN(cn.name) AS from_company, MIN(lt.link) AS movie_link_type, MIN(t.title) AS non_polish_sequel_movie +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND t.production_year BETWEEN 1950 + AND 2000 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Warner Bros. 
follows Money Talks + +# 11b - Query for non-Polish companies with Money sequels from 1998 +query TTT +SELECT MIN(cn.name) AS from_company, MIN(lt.link) AS movie_link_type, MIN(t.title) AS sequel_movie +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follows%' + AND mc.note IS NULL + AND t.production_year = 1998 and t.title like '%Money%' + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Warner Bros. Pictures follows Money Talks + +# 11c - Query for Fox movies based on novels +query TTT +SELECT MIN(cn.name) AS from_company, MIN(mc.note) AS production_note, MIN(t.title) AS movie_based_on_book +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' and (cn.name like '20th Century Fox%' or cn.name like 'Twentieth Century Fox%') + AND ct.kind != 'production companies' and ct.kind is not NULL + AND k.keyword in ('sequel', 'revenge', 'based-on-novel') + AND mc.note is not NULL + AND t.production_year > 1950 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Twentieth Century Fox (distribution) (DVD) (US) Fox Novel Movie + +# 11d - Query for movies based on novels from non-Polish companies +query TTT +SELECT MIN(cn.name) AS from_company, MIN(mc.note) AS production_note, MIN(t.title) AS movie_based_on_book +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND ct.kind != 'production companies' and ct.kind is not NULL + AND k.keyword in ('sequel', 'revenge', 'based-on-novel') + AND mc.note is not NULL + AND t.production_year > 1950 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Marvel Studios (as Marvel Studios) Avengers: Endgame + +# 12a - Query for cast in movies with specific genres +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS drama_horror_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t +WHERE cn.country_code = '[us]' + AND ct.kind = 'production companies' + AND it1.info = 'genres' + AND it2.info = 'rating' + AND mi.info in ('Drama', 'Horror') + AND mi_idx.info > '8.0' + AND t.production_year between 2005 and 2008 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND mi.info_type_id = it1.id + AND mi_idx.info_type_id = it2.id + AND t.id = mc.movie_id + AND ct.id = 
mc.company_type_id + AND cn.id = mc.company_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id +---- +Warner Bros. 9.5 The Dark Knight + +# 12b - Query for unsuccessful movies with specific budget criteria +query TT +SELECT MIN(mi.info) AS budget, MIN(t.title) AS unsuccsessful_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind is not NULL and (ct.kind ='production companies' or ct.kind = 'distributors') + AND it1.info ='budget' + AND it2.info ='bottom 10 rank' + AND t.production_year >2000 + AND (t.title LIKE 'Birdemic%' OR t.title LIKE '%Movie%') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND mi.info_type_id = it1.id + AND mi_idx.info_type_id = it2.id + AND t.id = mc.movie_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id +---- +$500,000 Bad Movie Sequel + +# 12c - Query for highly rated mainstream movies +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS mainstream_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t +WHERE cn.country_code = '[us]' + AND ct.kind = 'production companies' + AND it1.info = 'genres' + AND it2.info = 'rating' + AND mi.info in ('Drama', 'Horror', 'Western', 'Family') + AND mi_idx.info > '7.0' + AND t.production_year between 2000 and 2010 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND mi.info_type_id = it1.id + AND mi_idx.info_type_id = it2.id + AND t.id = mc.movie_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id +---- +Warner Bros. 
9.5 The Dark Knight + +# 13a - Query for movies with specific genre combinations +query TTT +SELECT MIN(mi.info) AS release_date, MIN(miidx.info) AS rating, MIN(t.title) AS german_movie +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[de]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = mc.movie_id +---- +2005-09-15 7.2 Dark Blood + +# 13b - Query for movies about winning with specific criteria +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie_about_winning +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND t.title != '' + AND (t.title LIKE '%Champion%' OR t.title LIKE '%Loser%') + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = mc.movie_id +---- +Universal Pictures 7.5 Champion Boxer + +# 13c - Query for movies with Champion in the title +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie_about_winning +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND t.title != '' + AND (t.title LIKE 'Champion%' OR t.title LIKE 'Loser%') + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = mc.movie_id +---- +Universal Pictures 7.5 Champion Boxer + +# 13d - Query for all US movies +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = 
mc.movie_id +---- +Marvel Studios 7.5 Avengers: Endgame + +# 14a - Query for actors in specific movie types +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS northern_dark_movie +FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind = 'movie' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2010 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +7.5 Nordic Noir + +# 14b - Query for dark western productions with specific criteria +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS western_dark_production +FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title') + AND kt.kind = 'movie' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info > '6.0' + AND t.production_year > 2010 and (t.title like '%murder%' or t.title like '%Murder%' or t.title like '%Mord%') + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +7.8 The Swedish Murder Case + +# 14c - Query for extended movie types and dark themes +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS north_european_dark_production +FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword is not null and k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +6.8 Berlin Noir + +# 15a - Query for US movies with internet releases +query TT +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS internet_movie +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' + AND it1.info = 'release dates' + AND mc.note like '%(200%)%' and mc.note like '%(worldwide)%' + AND mi.note like '%internet%' + AND mi.info like 'USA:% 200%' + AND t.production_year > 2000 + AND t.id = 
at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +USA: 2008-05-15 The Dark Knight + +# 15b - Query for YouTube movies with specific release criteria +query TT +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS youtube_movie +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' and cn.name = 'YouTube' + AND it1.info = 'release dates' + AND mc.note like '%(200%)%' and mc.note like '%(worldwide)%' + AND mi.note like '%internet%' + AND mi.info like 'USA:% 200%' + AND t.production_year between 2005 and 2010 + AND t.id = at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +USA: 2008-05-15 YouTube Documentary + +# 15c - Query for extended internet releases +query TT +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS modern_american_internet_movie +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' + AND it1.info = 'release dates' + AND mi.note like '%internet%' + AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') + AND t.production_year > 1990 + AND t.id = at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +USA: 15 May 2005 Digital Connection + +# 15d - Query for specialized internet releases +query TT +SELECT MIN(at.title) AS aka_title, MIN(t.title) AS internet_movie_title +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' + AND it1.info = 'release dates' + AND mi.note like '%internet%' + AND t.production_year > 1990 + AND t.id = at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Avengers 4 Avengers: Endgame + +# 16a - Query for movies in specific languages +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, 
movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND t.episode_nr >= 50 + AND t.episode_nr < 100 + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 16b - Query for series named after characters +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 16c - Query for extended languages and character-named series +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND t.episode_nr < 100 + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 16d - Query for specialized languages and character-named series +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND t.episode_nr >= 5 + AND t.episode_nr < 100 + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 17a - Query for actor/actress combinations +query TT +SELECT MIN(n.name) AS member_in_charnamed_american_movie, MIN(n.name) AS a1 +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND n.name LIKE 'B%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Bert Wilson Bert Wilson + +# 17b - Query for actors with names starting with Z in character-named movies +query TT +SELECT MIN(n.name) AS 
member_in_charnamed_movie, MIN(n.name) AS a1 +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE 'Z%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Zach Wilson Zach Wilson + +# 17c - Query for extended actor/actress combinations +query TT +SELECT MIN(n.name) AS member_in_charnamed_movie, MIN(n.name) AS a1 +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE 'X%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Xavier Thompson Xavier Thompson + +# 17d - Query for specialized actor/actress combinations +query T +SELECT MIN(n.name) AS member_in_charnamed_movie +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE '%Bert%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Bert Wilson + +# 17e - Query for advanced actor/actress combinations +query T +SELECT MIN(n.name) AS member_in_charnamed_movie +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alex Morgan + +# 17f - Query for complex actor/actress combinations +query T +SELECT MIN(n.name) AS member_in_charnamed_movie +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE '%B%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Bert Wilson + +# 18a - Query with complex genre filtering +query TTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t +WHERE ci.note in ('(producer)', '(executive producer)') + AND it1.info = 'budget' + AND it2.info = 'votes' + AND n.gender = 'm' and n.name like '%Tim%' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + 
AND it2.id = mi_idx.info_type_id +---- +$25,000,000 2,345,678 The Shawshank Redemption + +# 18b - Query for horror movies by female writers with high ratings +query TTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'rating' + AND mi.info in ('Horror', 'Thriller') and mi.note is NULL + AND mi_idx.info > '8.0' + AND n.gender is not null and n.gender = 'f' + AND t.production_year between 2008 and 2014 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +Horror 8.5 Woman in Black + +# 18c - Query for extended genre filtering +query TTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +Horror 1000 Halloween + +# 19a - Query for character name patterns +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND mc.note is not NULL and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%Ang%' + AND rt.role ='actress' + AND t.production_year between 2005 and 2009 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 19b - Query for Angelina Jolie as voice actress in Kung Fu Panda series +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS kung_fu_panda +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note = '(voice)' + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND mc.note like '%(200%)%' and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND mi.info is not 
null and (mi.info like 'Japan:%2007%' or mi.info like 'USA:%2008%') + AND n.gender ='f' and n.name like '%Angel%' + AND rt.role ='actress' + AND t.production_year between 2007 and 2008 and t.title like '%Kung%Fu%Panda%' + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 19c - Query for extended character patterns +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS jap_engl_voiced_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 19d - Query for specialized character patterns +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS jap_engl_voiced_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND n.gender ='f' + AND rt.role ='actress' + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 20a - Query for movies with specific actor roles +query T +SELECT MIN(t.title) AS complete_downey_ironman_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name not like '%Sherlock%' and (chn.name like '%Tony%Stark%' or chn.name like '%Iron%Man%') + AND k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND kt.kind = 'movie' + AND t.production_year > 1950 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND ci.movie_id = cc.movie_id + AND chn.id = ci.person_role_id 
+ AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Iron Man + +# 20b - Query for complete Downey Iron Man movies +query T +SELECT MIN(t.title) AS complete_downey_ironman_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name not like '%Sherlock%' and (chn.name like '%Tony%Stark%' or chn.name like '%Iron%Man%') + AND k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND kt.kind = 'movie' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND ci.movie_id = cc.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Iron Man + +# 20c - Query for extended specific actor roles +query TT +SELECT MIN(n.name) AS cast_member, MIN(t.title) AS complete_dynamic_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') + AND kt.kind = 'movie' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND ci.movie_id = cc.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Downey Robert Jr. Iron Man + +# 21a - Query for movies with specific production years +query TTT +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS western_follow_up +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') + AND t.production_year BETWEEN 1950 + AND 2000 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id +---- +Warner Bros. 
Pictures follows The Western Sequel + +# 21b - Query for German follow-up movies +query TTT +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS german_follow_up +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Germany', 'German') + AND t.production_year BETWEEN 2000 + AND 2010 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id +---- +Berlin Film Studio follows Dark Blood + +# 21c - Query for extended specific production years +query TTT +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS western_follow_up +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'English') + AND t.production_year BETWEEN 1950 + AND 2010 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id +---- +Berlin Film Studio follows Dark Blood + +# 22a - Query for movies with specific actor roles +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Germany', 'German', 'USA', 'American') + AND mi_idx.info < '7.0' + AND t.production_year > 2008 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 22b - 
Query for western violent movies by non-US companies +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Germany', 'German', 'USA', 'American') + AND mi_idx.info < '7.0' + AND t.production_year > 2009 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 22c - Query for extended actor roles +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 22d - Query for specialized actor roles +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = 
mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 23a - Query for sequels with specific character names +query TT +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_us_internet_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'complete+verified' + AND cn.country_code = '[us]' + AND it1.info = 'release dates' + AND kt.kind in ('movie') + AND mi.note like '%internet%' + AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND cct1.id = cc.status_id +---- +movie Digital Connection + +# 23b - Query for complete nerdy internet movies +query TT +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_nerdy_internet_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'complete+verified' + AND cn.country_code = '[us]' + AND it1.info = 'release dates' + AND k.keyword in ('nerd', 'loner', 'alienation', 'dignity') + AND kt.kind in ('movie') + AND mi.note like '%internet%' + AND mi.info like 'USA:% 200%' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND cct1.id = cc.status_id +---- +movie Digital Connection + +# 23c - Query for extended sequels with specific attributes +query TT +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_us_internet_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'complete+verified' + AND cn.country_code = '[us]' + AND it1.info = 'release dates' + AND kt.kind in ('movie', 'tv movie', 'video movie', 'video game') + AND mi.note like '%internet%' + AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') + AND t.production_year > 1990 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = 
mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND cct1.id = cc.status_id +---- +movie Digital Connection + +# 24a - Query for movies with specific budgets +query TTT +SELECT MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress_name, MIN(t.title) AS voiced_action_movie_jap_eng +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND k.keyword in ('hero', 'martial-arts', 'hand-to-hand-combat') + AND mi.info is not null and (mi.info like 'Japan:%201%' or mi.info like 'USA:%201%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year > 2010 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND ci.movie_id = mk.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND k.id = mk.keyword_id +---- +Batman's Assistant Angelina Jolie Kung Fu Panda 2 + +# 24b - Query for voiced characters in Kung Fu Panda +query TTT +SELECT MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress_name, MIN(t.title) AS kung_fu_panda +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND cn.name = 'DreamWorks Animation' + AND it.info = 'release dates' + AND k.keyword in ('hero', 'martial-arts', 'hand-to-hand-combat', 'computer-animated-movie') + AND mi.info is not null and (mi.info like 'Japan:%201%' or mi.info like 'USA:%201%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year > 2010 + AND t.title like 'Kung Fu Panda%' + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND ci.movie_id = mk.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND k.id = mk.keyword_id +---- +Batman's Assistant Angelina Jolie Kung Fu Panda 2 + +# 25a - Query for cast combinations in specific movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head 
writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'blood', 'gore', 'death', 'female-nudity') + AND mi.info = 'Horror' + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi_idx.movie_id = mk.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id +---- +Horror 1000 Christian Bale Halloween + +# 25b - Query for violent horror films with male writers +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'blood', 'gore', 'death', 'female-nudity') + AND mi.info = 'Horror' + AND n.gender = 'm' + AND t.production_year > 2010 + AND t.title like 'Vampire%' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi_idx.movie_id = mk.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id +---- +Horror 1000 Christian Bale Vampire Chronicles + +# 25c - Query for extended cast combinations +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi_idx.movie_id = mk.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id +---- +Horror 1000 Christian Bale Halloween + +# 26a - Query for specific movie genres with ratings +query TTTT +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(n.name) AS playing_actor, MIN(t.title) AS complete_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and 
(chn.name like '%man%' or chn.name like '%Man%') + AND it2.info = 'rating' + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') + AND kt.kind = 'movie' + AND mi_idx.info > '7.0' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND mk.movie_id = mi_idx.movie_id + AND ci.movie_id = cc.movie_id + AND ci.movie_id = mi_idx.movie_id + AND cc.movie_id = mi_idx.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND it2.id = mi_idx.info_type_id +---- +Ironman 8.5 John Carpenter Marvel Superhero Epic + +# 26b - Query for complete hero movies with Man in character name +query TTT +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') + AND it2.info = 'rating' + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'fight') + AND kt.kind = 'movie' + AND mi_idx.info > '8.0' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND mk.movie_id = mi_idx.movie_id + AND ci.movie_id = cc.movie_id + AND ci.movie_id = mi_idx.movie_id + AND cc.movie_id = mi_idx.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND it2.id = mi_idx.info_type_id +---- +Ironman 8.5 Marvel Superhero Epic + +# 26c - Query for extended movie genres and ratings +query TTT +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') + AND it2.info = 'rating' + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') + AND kt.kind = 'movie' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND mk.movie_id = mi_idx.movie_id + AND ci.movie_id = cc.movie_id + AND ci.movie_id = mi_idx.movie_id + AND cc.movie_id = mi_idx.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND it2.id = mi_idx.info_type_id +---- +Ironman 8.5 Marvel Superhero Epic + +# 27a - Query for movies with specific person roles +query TTT +SELECT MIN(cn.name) AS 
producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind = 'complete' + AND cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Germany','Swedish', 'German') + AND t.production_year BETWEEN 1950 + AND 2000 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND t.id = cc.movie_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id + AND ml.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = cc.movie_id +---- +Warner Bros. Pictures follows The Western Sequel + +# 27b - Query for complete western sequel films by non-Polish companies +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind = 'complete' + AND cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Germany','Swedish', 'German') + AND t.production_year = 1998 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND t.id = cc.movie_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id + AND ml.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = cc.movie_id +---- +Warner Bros. 
Pictures follows The Western Sequel + +# 27c - Query for extended person roles +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like 'complete%' + AND cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'English') + AND t.production_year BETWEEN 1950 + AND 2010 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND t.id = cc.movie_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id + AND ml.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = cc.movie_id +---- +Warner Bros. Pictures follows The Western Sequel + +# 28a - Query for movies with specific production years +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'crew' + AND cct2.kind != 'complete+verified' + AND cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mc.movie_id = cc.movie_id + AND mi_idx.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Stockholm Productions 7.8 The Nordic Murders + +# 28b - Query for Euro dark movies with complete crew +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, 
info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'crew' + AND cct2.kind != 'complete+verified' + AND cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Germany', 'Swedish', 'German') + AND mi_idx.info > '6.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mc.movie_id = cc.movie_id + AND mi_idx.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Stockholm Productions 7.8 The Nordic Murders + +# 28c - Query for extended movies with specific criteria +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind = 'complete' + AND cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mc.movie_id = cc.movie_id + AND mi_idx.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Oslo Films 7.5 Scandinavian Crime + +# 29a - Query for movies with specific combinations +query TTT +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation +FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t +WHERE cct1.kind 
='cast' + AND cct2.kind ='complete+verified' + AND chn.name = 'Queen' + AND ci.note in ('(voice)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND it3.info = 'trivia' + AND k.keyword = 'computer-animation' + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.title = 'Shrek 2' + AND t.production_year between 2000 and 2010 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND n.id = pi.person_id + AND ci.person_id = pi.person_id + AND it3.id = pi.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Queen Anne Hathaway Shrek 2 + +# 29b - Query for specific Queen character voice actress +query TTT +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation +FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t +WHERE cct1.kind ='cast' + AND cct2.kind ='complete+verified' + AND chn.name = 'Queen' + AND ci.note in ('(voice)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND it3.info = 'height' + AND k.keyword = 'computer-animation' + AND mi.info like 'USA:%200%' + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.title = 'Shrek 2' + AND t.production_year between 2000 and 2005 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND n.id = pi.person_id + AND ci.person_id = pi.person_id + AND it3.id = pi.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Queen Anne Hathaway Shrek 2 + +# 29c - Query for extended specific combinations +query TTT +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation +FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies 
AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t +WHERE cct1.kind ='cast' + AND cct2.kind ='complete+verified' + AND ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND it3.info = 'trivia' + AND k.keyword = 'computer-animation' + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year between 2000 and 2010 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND n.id = pi.person_id + AND ci.person_id = pi.person_id + AND it3.id = pi.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Queen Anne Hathaway Shrek 2 + +# 30a - Query for complete violent horror/thriller movies with male writers +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_violent_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind ='complete+verified' + AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Horror 52000 James Wan Saw IV + +# 30b - Query for complete gore movies (Freddy/Jason/Saw) with male writers +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_gore_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind ='complete+verified' + AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story 
editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.production_year > 2000 and (t.title like '%Freddy%' or t.title like '%Jason%' or t.title like 'Saw%') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Horror 52000 James Wan Saw IV + +# 30c - Query for extended action movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_violent_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind ='complete+verified' + AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Horror 52000 James Wan Saw IV + +# 31a - Query for movies with specific language and production values +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie +FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND cn.name like 'Lionsgate%' + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND 
mi.movie_id = mk.movie_id + AND mi.movie_id = mc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cn.id = mc.company_id +---- +Horror 45000 James Wan Halloween + +# 31b - Query for Lionsgate Blu-ray horror/thriller releases (Freddy/Jason/Saw) with male writers +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie +FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND cn.name like 'Lionsgate%' + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mc.note like '%(Blu-ray)%' + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.production_year > 2000 and (t.title like '%Freddy%' or t.title like '%Jason%' or t.title like 'Saw%') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = mc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cn.id = mc.company_id +---- +Horror 52000 James Wan Saw IV + +# 31c - Query for extended language and production values +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie +FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND cn.name like 'Lionsgate%' + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = mc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cn.id = mc.company_id +---- +Horror 45000 James Wan Halloween + +# 32a - Query for movies linked to films tagged with the 10,000-mile-club keyword +query TTT +SELECT MIN(lt.link) AS link_type, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM keyword AS k, link_type AS lt, movie_keyword AS mk, movie_link AS ml, title AS t1, title AS t2 
+WHERE k.keyword ='10,000-mile-club' + AND mk.keyword_id = k.id + AND t1.id = mk.movie_id + AND ml.movie_id = t1.id + AND ml.linked_movie_id = t2.id + AND lt.id = ml.link_type_id + AND mk.movie_id = t1.id +---- +edited into Interstellar Saving Private Ryan + +# 32b - Query for character-name-in-title movies and their connections +query TTT +SELECT MIN(lt.link) AS link_type, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM keyword AS k, link_type AS lt, movie_keyword AS mk, movie_link AS ml, title AS t1, title AS t2 +WHERE k.keyword ='character-name-in-title' + AND mk.keyword_id = k.id + AND t1.id = mk.movie_id + AND ml.movie_id = t1.id + AND ml.linked_movie_id = t2.id + AND lt.id = ml.link_type_id + AND mk.movie_id = t1.id +---- +featured in Iron Man Avengers: Endgame + +# 33a - Query for directors of sequels with specific ratings +query TTTTTT +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 +WHERE cn1.country_code = '[us]' + AND it1.info = 'rating' + AND it2.info = 'rating' + AND kt1.kind in ('tv series') + AND kt2.kind in ('tv series') + AND lt.link in ('sequel', 'follows', 'followed by') + AND mi_idx2.info < '3.0' + AND t2.production_year between 2005 and 2008 + AND lt.id = ml.link_type_id + AND t1.id = ml.movie_id + AND t2.id = ml.linked_movie_id + AND it1.id = mi_idx1.info_type_id + AND t1.id = mi_idx1.movie_id + AND kt1.id = t1.kind_id + AND cn1.id = mc1.company_id + AND t1.id = mc1.movie_id + AND ml.movie_id = mi_idx1.movie_id + AND ml.movie_id = mc1.movie_id + AND mi_idx1.movie_id = mc1.movie_id + AND it2.id = mi_idx2.info_type_id + AND t2.id = mi_idx2.movie_id + AND kt2.id = t2.kind_id + AND cn2.id = mc2.company_id + AND t2.id = mc2.movie_id + AND ml.linked_movie_id = mi_idx2.movie_id + AND ml.linked_movie_id = mc2.movie_id + AND mi_idx2.movie_id = mc2.movie_id +---- +Paramount Pictures Paramount Pictures 8.2 2.8 Breaking Bad Breaking Bad: The Final Season + +# 33b - Query for linked TV series by country code +query TTTTTT +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 +WHERE cn1.country_code = '[nl]' + AND it1.info = 'rating' + AND it2.info = 'rating' + AND kt1.kind in ('tv series') + AND kt2.kind in ('tv series') + AND lt.link LIKE '%follow%' + AND mi_idx2.info < '3.0' + AND t2.production_year = 2007 + AND lt.id = ml.link_type_id + AND t1.id = ml.movie_id + AND t2.id = ml.linked_movie_id + AND it1.id = mi_idx1.info_type_id + AND t1.id = mi_idx1.movie_id + AND kt1.id = t1.kind_id + AND cn1.id = mc1.company_id + AND t1.id = mc1.movie_id + AND ml.movie_id = mi_idx1.movie_id + AND ml.movie_id = mc1.movie_id + AND mi_idx1.movie_id = mc1.movie_id + AND it2.id = mi_idx2.info_type_id + AND t2.id = 
mi_idx2.movie_id + AND kt2.id = t2.kind_id + AND cn2.id = mc2.company_id + AND t2.id = mc2.movie_id + AND ml.linked_movie_id = mi_idx2.movie_id + AND ml.linked_movie_id = mc2.movie_id + AND mi_idx2.movie_id = mc2.movie_id +---- +Dutch Entertainment Group Amsterdam Studios 8.5 2.5 Amsterdam Detective Amsterdam Detective: Cold Case + +# 33c - Query for linked TV series and episodes with specific ratings +query TTTTTT +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 +WHERE cn1.country_code != '[us]' + AND it1.info = 'rating' + AND it2.info = 'rating' + AND kt1.kind in ('tv series', 'episode') + AND kt2.kind in ('tv series', 'episode') + AND lt.link in ('sequel', 'follows', 'followed by') + AND mi_idx2.info < '3.5' + AND t2.production_year between 2000 and 2010 + AND lt.id = ml.link_type_id + AND t1.id = ml.movie_id + AND t2.id = ml.linked_movie_id + AND it1.id = mi_idx1.info_type_id + AND t1.id = mi_idx1.movie_id + AND kt1.id = t1.kind_id + AND cn1.id = mc1.company_id + AND t1.id = mc1.movie_id + AND ml.movie_id = mi_idx1.movie_id + AND ml.movie_id = mc1.movie_id + AND mi_idx1.movie_id = mc1.movie_id + AND it2.id = mi_idx2.info_type_id + AND t2.id = mi_idx2.movie_id + AND kt2.id = t2.kind_id + AND cn2.id = mc2.company_id + AND t2.id = mc2.movie_id + AND ml.linked_movie_id = mi_idx2.movie_id + AND ml.linked_movie_id = mc2.movie_id + AND mi_idx2.movie_id = mc2.movie_id +---- +Dutch Entertainment Group Amsterdam Studios 8.5 2.5 Amsterdam Detective Amsterdam Detective: Cold Case + +# Clean up all tables +statement ok +DROP TABLE company_type; + +statement ok +DROP TABLE info_type; + +statement ok +DROP TABLE title; + +statement ok +DROP TABLE movie_companies; + +statement ok +DROP TABLE movie_info_idx; + +statement ok +DROP TABLE movie_info; + +statement ok +DROP TABLE kind_type; + +statement ok +DROP TABLE cast_info; + +statement ok +DROP TABLE char_name; + +statement ok +DROP TABLE keyword; + +statement ok +DROP TABLE movie_keyword; + +statement ok +DROP TABLE company_name; + +statement ok +DROP TABLE name; + +statement ok +DROP TABLE role_type; + +statement ok +DROP TABLE link_type; + +statement ok +DROP TABLE movie_link; + +statement ok +DROP TABLE complete_cast; + +statement ok +DROP TABLE comp_cast_type; + +statement ok +DROP TABLE person_info; + +statement ok +DROP TABLE aka_title; + +statement ok +DROP TABLE aka_name; diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 4964bcbc735c..2ce64ffc6836 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -149,6 +149,39 @@ drop table t statement ok drop table t2 + +############ +## 0 to represent the default value (target_partitions and planning_concurrency) +########### + +statement ok +SET datafusion.execution.target_partitions = 3; + +statement ok +SET datafusion.execution.planning_concurrency = 3; + +# when setting target_partitions and planning_concurrency to 3, their values will be 3 +query TB rowsort +SELECT name, value = 3 FROM 
information_schema.df_settings WHERE name IN ('datafusion.execution.target_partitions', 'datafusion.execution.planning_concurrency'); +---- +datafusion.execution.planning_concurrency true +datafusion.execution.target_partitions true + +statement ok +SET datafusion.execution.target_partitions = 0; + +statement ok +SET datafusion.execution.planning_concurrency = 0; + +# when setting target_partitions and planning_concurrency to 0, their values will be equal to the +# default values, which are different from 0 (which is invalid) +query TB rowsort +SELECT name, value = 0 FROM information_schema.df_settings WHERE name IN ('datafusion.execution.target_partitions', 'datafusion.execution.planning_concurrency'); +---- +datafusion.execution.planning_concurrency false +datafusion.execution.target_partitions false + + ############ ## SHOW VARIABLES should work ########### @@ -183,7 +216,7 @@ datafusion.catalog.location NULL datafusion.catalog.newlines_in_values false datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true -datafusion.execution.collect_statistics false +datafusion.execution.collect_statistics true datafusion.execution.enable_recursive_ctes true datafusion.execution.enforce_batch_size_in_joins false datafusion.execution.keep_partition_by_columns false @@ -191,6 +224,7 @@ datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 datafusion.execution.minimum_parallel_output_files 4 +datafusion.execution.objectstore_writer_buffer_size 10485760 datafusion.execution.parquet.allow_single_file_parallelism true datafusion.execution.parquet.binary_as_string false datafusion.execution.parquet.bloom_filter_fpp NULL @@ -239,6 +273,15 @@ datafusion.explain.physical_plan_only false datafusion.explain.show_schema false datafusion.explain.show_sizes true datafusion.explain.show_statistics false +datafusion.format.date_format %Y-%m-%d +datafusion.format.datetime_format %Y-%m-%dT%H:%M:%S%.f +datafusion.format.duration_format pretty +datafusion.format.null (empty) +datafusion.format.safe true +datafusion.format.time_format %H:%M:%S%.f +datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f +datafusion.format.timestamp_tz_format NULL +datafusion.format.types_info false datafusion.optimizer.allow_symmetric_joins_without_pruning true datafusion.optimizer.default_filter_selectivity 20 datafusion.optimizer.enable_distinct_aggregation_soft_limit true @@ -264,7 +307,7 @@ datafusion.sql_parser.collect_spans false datafusion.sql_parser.dialect generic datafusion.sql_parser.enable_ident_normalization true datafusion.sql_parser.enable_options_value_normalization false -datafusion.sql_parser.map_varchar_to_utf8view false +datafusion.sql_parser.map_varchar_to_utf8view true datafusion.sql_parser.parse_float_as_decimal false datafusion.sql_parser.recursion_limit 50 datafusion.sql_parser.support_varchar_with_length true @@ -283,7 +326,7 @@ datafusion.catalog.location NULL Location scanned to load tables for `default` s datafusion.catalog.newlines_in_values false Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. 
datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting -datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files +datafusion.execution.collect_statistics true Should DataFusion collect statistics when first creating a table. Has no effect after the table is created. Applies to the default `ListingTableProvider` in DataFusion. Defaults to true. datafusion.execution.enable_recursive_ctes true Should DataFusion support recursive CTEs datafusion.execution.enforce_batch_size_in_joins false Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. datafusion.execution.keep_partition_by_columns false Should DataFusion keep the columns used for partition_by in the output RecordBatches @@ -291,11 +334,12 @@ datafusion.execution.listing_table_ignore_subdirectory true Should sub directori datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. +datafusion.execution.objectstore_writer_buffer_size 10485760 Size (bytes) of data buffer DataFusion uses when writing output files. This affects the size of the data chunks that are uploaded to remote object stores (e.g. AWS S3). If very large (>= 100 GiB) output files are being written, it may be necessary to increase this size to avoid errors from the remote end point. datafusion.execution.parquet.allow_single_file_parallelism true (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. datafusion.execution.parquet.binary_as_string false (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. datafusion.execution.parquet.bloom_filter_fpp NULL (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting datafusion.execution.parquet.bloom_filter_ndv NULL (writing) Sets bloom filter number of distinct values. 
If NULL, uses default parquet writer setting -datafusion.execution.parquet.bloom_filter_on_read true (writing) Use any available bloom filters when reading parquet files +datafusion.execution.parquet.bloom_filter_on_read true (reading) Use any available bloom filters when reading parquet files datafusion.execution.parquet.bloom_filter_on_write false (writing) Write bloom filters for all columns when creating parquet files datafusion.execution.parquet.coerce_int96 NULL (reading) If true, parquet reader will read columns of physical type int96 as originating from a different resolution than nanosecond. This is useful for reading data from systems like Spark which stores microsecond resolution timestamps in an int96 allowing it to write values with a larger date range than 64-bit timestamps with nanosecond resolution. datafusion.execution.parquet.column_index_truncate_length 64 (writing) Sets column index truncate length @@ -339,6 +383,15 @@ datafusion.explain.physical_plan_only false When set to true, the explain statem datafusion.explain.show_schema false When set to true, the explain statement will print schema information datafusion.explain.show_sizes true When set to true, the explain statement will print the partition sizes datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans +datafusion.format.date_format %Y-%m-%d Date format for date arrays +datafusion.format.datetime_format %Y-%m-%dT%H:%M:%S%.f Format for DateTime arrays +datafusion.format.duration_format pretty Duration format. Can be either `"pretty"` or `"ISO8601"` +datafusion.format.null (empty) Format string for nulls +datafusion.format.safe true If set to `true` any formatting errors will be written to the output instead of being converted into a [`std::fmt::Error`] +datafusion.format.time_format %H:%M:%S%.f Time format for time arrays +datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f Timestamp format for timestamp arrays +datafusion.format.timestamp_tz_format NULL Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. +datafusion.format.types_info false Show types in visual representation batches datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. 
@@ -354,7 +407,7 @@ datafusion.optimizer.prefer_existing_union false When set to true, the optimizer datafusion.optimizer.prefer_hash_join true When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory datafusion.optimizer.repartition_aggregations true Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level datafusion.optimizer.repartition_file_min_size 10485760 Minimum total files size in bytes to perform file scan repartitioning. -datafusion.optimizer.repartition_file_scans true When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. +datafusion.optimizer.repartition_file_scans true When set to `true`, datasource partitions will be repartitioned to achieve maximum parallelism. This applies to both in-memory partitions and FileSource's file groups (1 group is 1 partition). For FileSources, only Parquet and CSV formats are currently supported. If set to `true` for a FileSource, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false` for a FileSource, different files will be read in parallel, but repartitioning won't happen within a single file. If set to `true` for an in-memory source, all memtable's partitions will have their batches repartitioned evenly to the desired number of `target_partitions`. Repartitioning can change the total number of partitions and batches per partition, but does not slice the initial record tables provided to the MemTable on creation. datafusion.optimizer.repartition_joins true Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level datafusion.optimizer.repartition_sorts true Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below ```text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ``` would turn into the plan below which performs better in multithreaded environments ```text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ``` datafusion.optimizer.repartition_windows true Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level @@ -364,7 +417,7 @@ datafusion.sql_parser.collect_spans false When set to true, the source locations datafusion.sql_parser.dialect generic Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. 
datafusion.sql_parser.enable_ident_normalization true When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) datafusion.sql_parser.enable_options_value_normalization false When set to true, SQL parser will normalize options value (convert value to lowercase). Note that this option is ignored and will be removed in the future. All case-insensitive values are normalized automatically. -datafusion.sql_parser.map_varchar_to_utf8view false If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. If false, `VARCHAR` is mapped to `Utf8` during SQL planning. Default is false. +datafusion.sql_parser.map_varchar_to_utf8view true If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. If false, `VARCHAR` is mapped to `Utf8` during SQL planning. Default is false. datafusion.sql_parser.parse_float_as_decimal false When set to true, SQL parser will parse float as decimal type datafusion.sql_parser.recursion_limit 50 Specifies the recursion depth limit when parsing complex SQL Queries datafusion.sql_parser.support_varchar_with_length true If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. @@ -636,7 +689,7 @@ datafusion public abc CREATE EXTERNAL TABLE abc STORED AS CSV LOCATION ../../tes query TTT select routine_name, data_type, function_type from information_schema.routines where routine_name = 'string_agg'; ---- -string_agg LargeUtf8 AGGREGATE +string_agg String AGGREGATE # test every function type are included in the result query TTTTTTTBTTTT rowsort @@ -651,7 +704,7 @@ datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestam datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Second, None) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Second, Some("+TZ")) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) datafusion public rank datafusion public rank FUNCTION true NULL WINDOW Returns the rank of the current row within its partition, allowing gaps between ranks. This function provides a ranking similar to `row_number`, but skips ranks for identical values. rank() -datafusion public string_agg datafusion public string_agg FUNCTION true LargeUtf8 AGGREGATE Concatenates the values of string expressions and places separator values between them. If ordering is required, strings are concatenated in the specified order. This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression. string_agg([DISTINCT] expression, delimiter [ORDER BY expression]) +datafusion public string_agg datafusion public string_agg FUNCTION true String AGGREGATE Concatenates the values of string expressions and places separator values between them. If ordering is required, strings are concatenated in the specified order. This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression. 
string_agg([DISTINCT] expression, delimiter [ORDER BY expression]) query B select is_deterministic from information_schema.routines where routine_name = 'now'; @@ -660,119 +713,65 @@ false # test every function type are included in the result query TTTITTTTBI -select * from information_schema.parameters where specific_name = 'date_trunc' OR specific_name = 'string_agg' OR specific_name = 'rank' ORDER BY specific_name, rid; ----- -datafusion public date_trunc 1 IN precision Utf8 NULL false 0 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false 0 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false 0 -datafusion public date_trunc 1 IN precision Utf8View NULL false 1 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false 1 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false 1 -datafusion public date_trunc 1 IN precision Utf8 NULL false 2 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false 2 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false 2 -datafusion public date_trunc 1 IN precision Utf8View NULL false 3 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false 3 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false 3 -datafusion public date_trunc 1 IN precision Utf8 NULL false 4 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, None) NULL false 4 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, None) NULL false 4 -datafusion public date_trunc 1 IN precision Utf8View NULL false 5 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, None) NULL false 5 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, None) NULL false 5 -datafusion public date_trunc 1 IN precision Utf8 NULL false 6 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, Some("+TZ")) NULL false 6 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, Some("+TZ")) NULL false 6 -datafusion public date_trunc 1 IN precision Utf8View NULL false 7 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, Some("+TZ")) NULL false 7 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, Some("+TZ")) NULL false 7 -datafusion public date_trunc 1 IN precision Utf8 NULL false 8 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, None) NULL false 8 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, None) NULL false 8 -datafusion public date_trunc 1 IN precision Utf8View NULL false 9 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, None) NULL false 9 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, None) NULL false 9 -datafusion public date_trunc 1 IN precision Utf8 NULL false 10 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, Some("+TZ")) NULL false 10 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, Some("+TZ")) NULL false 10 -datafusion public date_trunc 1 IN precision Utf8View NULL false 11 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, Some("+TZ")) NULL false 11 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, Some("+TZ")) NULL false 11 -datafusion public date_trunc 1 IN precision Utf8 NULL false 12 -datafusion public date_trunc 2 IN expression Timestamp(Second, None) NULL false 12 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, 
None) NULL false 12 -datafusion public date_trunc 1 IN precision Utf8View NULL false 13 -datafusion public date_trunc 2 IN expression Timestamp(Second, None) NULL false 13 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, None) NULL false 13 -datafusion public date_trunc 1 IN precision Utf8 NULL false 14 -datafusion public date_trunc 2 IN expression Timestamp(Second, Some("+TZ")) NULL false 14 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, Some("+TZ")) NULL false 14 -datafusion public date_trunc 1 IN precision Utf8View NULL false 15 -datafusion public date_trunc 2 IN expression Timestamp(Second, Some("+TZ")) NULL false 15 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, Some("+TZ")) NULL false 15 -datafusion public string_agg 1 IN expression LargeUtf8 NULL false 0 -datafusion public string_agg 2 IN delimiter Utf8 NULL false 0 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 0 -datafusion public string_agg 1 IN expression LargeUtf8 NULL false 1 -datafusion public string_agg 2 IN delimiter LargeUtf8 NULL false 1 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 1 -datafusion public string_agg 1 IN expression LargeUtf8 NULL false 2 -datafusion public string_agg 2 IN delimiter Null NULL false 2 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 2 -datafusion public string_agg 1 IN expression Utf8 NULL false 3 -datafusion public string_agg 2 IN delimiter Utf8 NULL false 3 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 3 -datafusion public string_agg 1 IN expression Utf8 NULL false 4 -datafusion public string_agg 2 IN delimiter LargeUtf8 NULL false 4 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 4 -datafusion public string_agg 1 IN expression Utf8 NULL false 5 -datafusion public string_agg 2 IN delimiter Null NULL false 5 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 5 +select * from information_schema.parameters where specific_name = 'date_trunc' OR specific_name = 'string_agg' OR specific_name = 'rank' ORDER BY specific_name, rid, data_type; +---- +datafusion public date_trunc 1 IN precision String NULL false 0 +datafusion public date_trunc 2 IN expression Timestamp(Microsecond, None) NULL false 0 +datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, None) NULL false 0 +datafusion public date_trunc 1 IN precision String NULL false 1 +datafusion public date_trunc 2 IN expression Timestamp(Microsecond, Some("+TZ")) NULL false 1 +datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, Some("+TZ")) NULL false 1 +datafusion public date_trunc 1 IN precision String NULL false 2 +datafusion public date_trunc 2 IN expression Timestamp(Millisecond, None) NULL false 2 +datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, None) NULL false 2 +datafusion public date_trunc 1 IN precision String NULL false 3 +datafusion public date_trunc 2 IN expression Timestamp(Millisecond, Some("+TZ")) NULL false 3 +datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, Some("+TZ")) NULL false 3 +datafusion public date_trunc 1 IN precision String NULL false 4 +datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false 4 +datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false 4 +datafusion public date_trunc 1 IN precision String NULL false 5 +datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false 5 +datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false 5 
+datafusion public date_trunc 1 IN precision String NULL false 6 +datafusion public date_trunc 2 IN expression Timestamp(Second, None) NULL false 6 +datafusion public date_trunc 1 OUT NULL Timestamp(Second, None) NULL false 6 +datafusion public date_trunc 1 IN precision String NULL false 7 +datafusion public date_trunc 2 IN expression Timestamp(Second, Some("+TZ")) NULL false 7 +datafusion public date_trunc 1 OUT NULL Timestamp(Second, Some("+TZ")) NULL false 7 +datafusion public string_agg 2 IN delimiter Null NULL false 0 +datafusion public string_agg 1 IN expression String NULL false 0 +datafusion public string_agg 1 OUT NULL String NULL false 0 +datafusion public string_agg 1 IN expression String NULL false 1 +datafusion public string_agg 2 IN delimiter String NULL false 1 +datafusion public string_agg 1 OUT NULL String NULL false 1 # test variable length arguments query TTTBI rowsort select specific_name, data_type, parameter_mode, is_variadic, rid from information_schema.parameters where specific_name = 'concat'; ---- -concat LargeUtf8 IN true 2 -concat LargeUtf8 OUT false 2 -concat Utf8 IN true 1 -concat Utf8 OUT false 1 -concat Utf8View IN true 0 -concat Utf8View OUT false 0 +concat String IN true 0 +concat String OUT false 0 # test ceorcion signature query TTITI rowsort select specific_name, data_type, ordinal_position, parameter_mode, rid from information_schema.parameters where specific_name = 'repeat'; ---- repeat Int64 2 IN 0 -repeat Int64 2 IN 1 -repeat Int64 2 IN 2 -repeat LargeUtf8 1 IN 1 -repeat LargeUtf8 1 OUT 1 -repeat Utf8 1 IN 0 -repeat Utf8 1 OUT 0 -repeat Utf8 1 OUT 2 -repeat Utf8View 1 IN 2 +repeat String 1 IN 0 +repeat String 1 OUT 0 query TT??TTT rowsort show functions like 'date_trunc'; ---- -date_trunc Timestamp(Microsecond, None) [precision, expression] [Utf8, Timestamp(Microsecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Microsecond, None) [precision, expression] [Utf8View, Timestamp(Microsecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Microsecond, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Microsecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Microsecond, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Microsecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, None) [precision, expression] [Utf8, Timestamp(Millisecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, None) [precision, expression] [Utf8View, Timestamp(Millisecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Millisecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Millisecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. 
date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, None) [precision, expression] [Utf8, Timestamp(Nanosecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, None) [precision, expression] [Utf8View, Timestamp(Nanosecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Nanosecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Nanosecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, None) [precision, expression] [Utf8, Timestamp(Second, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, None) [precision, expression] [Utf8View, Timestamp(Second, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Second, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Second, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Microsecond, None) [precision, expression] [String, Timestamp(Microsecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Microsecond, Some("+TZ")) [precision, expression] [String, Timestamp(Microsecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Millisecond, None) [precision, expression] [String, Timestamp(Millisecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Millisecond, Some("+TZ")) [precision, expression] [String, Timestamp(Millisecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Nanosecond, None) [precision, expression] [String, Timestamp(Nanosecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Nanosecond, Some("+TZ")) [precision, expression] [String, Timestamp(Nanosecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Second, None) [precision, expression] [String, Timestamp(Second, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Second, Some("+TZ")) [precision, expression] [String, Timestamp(Second, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. 
date_trunc(precision, expression) statement ok show functions diff --git a/datafusion/sqllogictest/test_files/join.slt.part b/datafusion/sqllogictest/test_files/join.slt.part index 972dd2265343..19763ab0083f 100644 --- a/datafusion/sqllogictest/test_files/join.slt.part +++ b/datafusion/sqllogictest/test_files/join.slt.part @@ -842,7 +842,7 @@ LEFT JOIN department AS d ON (e.name = 'Alice' OR e.name = 'Bob'); ---- logical_plan -01)Left Join: Filter: e.name = Utf8("Alice") OR e.name = Utf8("Bob") +01)Left Join: Filter: e.name = Utf8View("Alice") OR e.name = Utf8View("Bob") 02)--SubqueryAlias: e 03)----TableScan: employees projection=[emp_id, name] 04)--SubqueryAlias: d @@ -929,7 +929,7 @@ ON (e.name = 'Alice' OR e.name = 'Bob'); logical_plan 01)Cross Join: 02)--SubqueryAlias: e -03)----Filter: employees.name = Utf8("Alice") OR employees.name = Utf8("Bob") +03)----Filter: employees.name = Utf8View("Alice") OR employees.name = Utf8View("Bob") 04)------TableScan: employees projection=[emp_id, name] 05)--SubqueryAlias: d 06)----TableScan: department projection=[dept_name] @@ -974,11 +974,11 @@ ON e.emp_id = d.emp_id WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); ---- logical_plan -01)Filter: d.dept_name != Utf8("Engineering") AND e.name = Utf8("Alice") OR e.name != Utf8("Alice") AND e.name = Utf8("Carol") +01)Filter: d.dept_name != Utf8View("Engineering") AND e.name = Utf8View("Alice") OR e.name != Utf8View("Alice") AND e.name = Utf8View("Carol") 02)--Projection: e.emp_id, e.name, d.dept_name 03)----Left Join: e.emp_id = d.emp_id 04)------SubqueryAlias: e -05)--------Filter: employees.name = Utf8("Alice") OR employees.name != Utf8("Alice") AND employees.name = Utf8("Carol") +05)--------Filter: employees.name = Utf8View("Alice") OR employees.name != Utf8View("Alice") AND employees.name = Utf8View("Carol") 06)----------TableScan: employees projection=[emp_id, name] 07)------SubqueryAlias: d 08)--------TableScan: department projection=[emp_id, dept_name] @@ -1404,3 +1404,102 @@ set datafusion.execution.target_partitions = 4; statement ok set datafusion.optimizer.repartition_joins = false; + +statement ok +CREATE TABLE t1(v0 BIGINT, v1 BIGINT); + +statement ok +CREATE TABLE t0(v0 BIGINT, v1 BIGINT); + +statement ok +INSERT INTO t0(v0, v1) VALUES (1, 1), (1, 2), (3, 3), (4, 4); + +statement ok +INSERT INTO t1(v0, v1) VALUES (1, 1), (3, 2), (3, 5); + +query TT +explain SELECT * +FROM t0, +LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0); +---- +logical_plan +01)Projection: t0.v0, t0.v1, sum(t1.v1) +02)--Left Join: t0.v0 = t1.v0 +03)----TableScan: t0 projection=[v0, v1] +04)----Projection: sum(t1.v1), t1.v0 +05)------Aggregate: groupBy=[[t1.v0]], aggr=[[sum(t1.v1)]] +06)--------TableScan: t1 projection=[v0, v1] +physical_plan +01)ProjectionExec: expr=[v0@1 as v0, v1@2 as v1, sum(t1.v1)@0 as sum(t1.v1)] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----HashJoinExec: mode=CollectLeft, join_type=Right, on=[(v0@1, v0@0)], projection=[sum(t1.v1)@0, v0@2, v1@3] +04)------CoalescePartitionsExec +05)--------ProjectionExec: expr=[sum(t1.v1)@1 as sum(t1.v1), v0@0 as v0] +06)----------AggregateExec: mode=FinalPartitioned, gby=[v0@0 as v0], aggr=[sum(t1.v1)] +07)------------CoalesceBatchesExec: target_batch_size=8192 +08)--------------RepartitionExec: partitioning=Hash([v0@0], 4), input_partitions=4 +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------AggregateExec: mode=Partial, gby=[v0@0 as 
v0], aggr=[sum(t1.v1)] +11)--------------------DataSourceExec: partitions=1, partition_sizes=[1] +12)------DataSourceExec: partitions=1, partition_sizes=[1] + +query III +SELECT * +FROM t0, +LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0); +---- +1 1 1 +1 2 1 +3 3 7 +4 4 NULL + +query TT +explain SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0); +---- +logical_plan +01)Inner Join: t0.v0 = t1.v0 +02)--TableScan: t0 projection=[v0, v1] +03)--TableScan: t1 projection=[v0, v1] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(v0@0, v0@0)] +03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)----DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0); +---- +1 1 1 1 +1 2 1 1 +3 3 3 2 +3 3 3 5 + +query III +SELECT * FROM t0, LATERAL (SELECT 1); +---- +1 1 1 +1 2 1 +3 3 1 +4 4 1 + +query IIII +SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t1.v0 = 1); +---- +1 1 1 1 +1 2 1 1 +3 3 1 1 +4 4 1 1 + +query IIII +SELECT * FROM t0 JOIN LATERAL (SELECT * FROM t1 WHERE t1.v0 = 1) on true; +---- +1 1 1 1 +1 2 1 1 +3 3 1 1 +4 4 1 1 + +statement ok +drop table t1; + +statement ok +drop table t0; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index ddf701ba04ef..ccecb9494331 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -1067,9 +1067,9 @@ LEFT JOIN join_t2 on join_t1.t1_id = join_t2.t2_id WHERE join_t2.t2_int < 10 or (join_t1.t1_int > 2 and join_t2.t2_name != 'w') ---- logical_plan -01)Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t2.t2_int < UInt32(10) OR join_t1.t1_int > UInt32(2) AND join_t2.t2_name != Utf8("w") +01)Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t2.t2_int < UInt32(10) OR join_t1.t1_int > UInt32(2) AND join_t2.t2_name != Utf8View("w") 02)--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] -03)--Filter: join_t2.t2_int < UInt32(10) OR join_t2.t2_name != Utf8("w") +03)--Filter: join_t2.t2_int < UInt32(10) OR join_t2.t2_name != Utf8View("w") 04)----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] # Reduce left join 3 (to inner join) @@ -1153,7 +1153,7 @@ WHERE join_t1.t1_name != 'b' ---- logical_plan 01)Left Join: join_t1.t1_id = join_t2.t2_id -02)--Filter: join_t1.t1_name != Utf8("b") +02)--Filter: join_t1.t1_name != Utf8View("b") 03)----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] 04)--TableScan: join_t2 projection=[t2_id, t2_name, t2_int] @@ -1168,9 +1168,9 @@ WHERE join_t1.t1_name != 'b' and join_t2.t2_name = 'x' ---- logical_plan 01)Inner Join: join_t1.t1_id = join_t2.t2_id -02)--Filter: join_t1.t1_name != Utf8("b") +02)--Filter: join_t1.t1_name != Utf8View("b") 03)----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] -04)--Filter: join_t2.t2_name = Utf8("x") +04)--Filter: join_t2.t2_name = Utf8View("x") 05)----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] ### @@ -4087,7 +4087,7 @@ logical_plan 07)------------TableScan: sales_global projection=[ts, sn, amount, currency] 08)----------SubqueryAlias: e 09)------------Projection: exchange_rates.ts, exchange_rates.currency_from, exchange_rates.rate -10)--------------Filter: exchange_rates.currency_to = Utf8("USD") +10)--------------Filter: exchange_rates.currency_to = Utf8View("USD") 11)----------------TableScan: exchange_rates projection=[ts, currency_from, currency_to, rate] physical_plan 
01)SortExec: expr=[sn@1 ASC NULLS LAST], preserve_partitioning=[false] @@ -4385,7 +4385,7 @@ JOIN my_catalog.my_schema.table_with_many_types AS r ON l.binary_col = r.binary_ logical_plan 01)Projection: count(Int64(1)) AS count(*) 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -03)----Projection: +03)----Projection: 04)------Inner Join: l.binary_col = r.binary_col 05)--------SubqueryAlias: l 06)----------TableScan: my_catalog.my_schema.table_with_many_types projection=[binary_col] @@ -4644,7 +4644,7 @@ logical_plan 08)----Subquery: 09)------Filter: j3.j3_string = outer_ref(j2.j2_string) 10)--------TableScan: j3 projection=[j3_string, j3_id] -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Utf8, Column { relation: Some(Bare { table: "j2" }), name: "j2_string" }) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Utf8View, Column { relation: Some(Bare { table: "j2" }), name: "j2_string" }) query TT explain SELECT * FROM j1, LATERAL (SELECT * FROM j1, LATERAL (SELECT * FROM j2 WHERE j1_id = j2_id) as j2) as j2; diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 93ffa313b8f7..1af14a52e2bc 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -365,7 +365,7 @@ EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); logical_plan 01)Projection: count(Int64(1)) AS count(*) 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -03)----Projection: +03)----Projection: 04)------Limit: skip=6, fetch=3 05)--------Filter: t1.a > Int32(3) 06)----------TableScan: t1 projection=[a] @@ -854,7 +854,7 @@ physical_plan 02)--SortExec: TopK(fetch=1000), expr=[part_key@1 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[1 as foo, part_key@0 as part_key] 04)------CoalescePartitionsExec: fetch=1 -05)--------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-0.parquet:0..794], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-1.parquet:0..794], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-2.parquet:0..794]]}, projection=[part_key], limit=1, file_type=parquet +05)--------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-2.parquet:0..265], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-2.parquet:265..530], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-2.parquet:530..794]]}, projection=[part_key], limit=1, file_type=parquet query I with selection as ( diff --git a/datafusion/sqllogictest/test_files/listing_table_statistics.slt b/datafusion/sqllogictest/test_files/listing_table_statistics.slt new file mode 100644 index 000000000000..890d1f2e9250 --- /dev/null +++ b/datafusion/sqllogictest/test_files/listing_table_statistics.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test file with different schema order but generating correct statistics for table +statement ok +COPY (SELECT * FROM values (1, 'a'), (2, 'b') t(int_col, str_col)) to 'test_files/scratch/table/1.parquet'; + +statement ok +COPY (SELECT * FROM values ('c', 3), ('d', -1) t(str_col, int_col)) to 'test_files/scratch/table/2.parquet'; + +statement ok +set datafusion.execution.collect_statistics = true; + +statement ok +set datafusion.explain.show_statistics = true; + +statement ok +create external table t stored as parquet location 'test_files/scratch/table'; + +query TT +explain format indent select * from t; +---- +logical_plan TableScan: t projection=[int_col, str_col] +physical_plan DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/table/2.parquet]]}, projection=[int_col, str_col], file_type=parquet, statistics=[Rows=Exact(4), Bytes=Exact(288), [(Col[0]: Min=Exact(Int64(-1)) Max=Exact(Int64(3)) Null=Exact(0)),(Col[1]: Min=Exact(Utf8View("a")) Max=Exact(Utf8View("d")) Null=Exact(0))]] + +statement ok +drop table t; + +statement ok +set datafusion.execution.collect_statistics = false; + +statement ok +set datafusion.explain.show_statistics = false; diff --git a/datafusion/sqllogictest/test_files/min_max/fixed_size_list.slt b/datafusion/sqllogictest/test_files/min_max/fixed_size_list.slt new file mode 100644 index 000000000000..aa623b63cdc7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/fixed_size_list.slt @@ -0,0 +1,133 @@ +# Min/Max with FixedSizeList over integers +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)')), +(arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')); +---- +[1, 2] [1, 2, 3, 4] + +# Min/Max with FixedSizeList over strings +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array('a', 'b', 'c'), 'FixedSizeList(3, Utf8)')), +(arrow_cast(make_array('a', 'b'), 'LargeList(Utf8)')); +---- +[a, b] [a, b, c] + +# Min/Max with FixedSizeList over booleans +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(true, false, true), 'FixedSizeList(3, Boolean)')), +(arrow_cast(make_array(true, false), 'FixedSizeList(2, Boolean)')); +---- +[true, false] [true, false, true] + +# Min/Max with FixedSizeList over nullable integers +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(NULL, 1, 2), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')); +---- +[1, 2] [NULL, 1, 2] + +# Min/Max FixedSizeList with different lengths and nulls +query ?? 
+SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)')), +(arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(1, NULL, 3), 'FixedSizeList(3, Int64)')); +---- +[1, 2] [1, NULL, 3] + +# Min/Max FixedSizeList with only NULLs +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')); +---- +[NULL] [NULL, NULL] + + +# Min/Max FixedSizeList of varying types (integers and NULLs) +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(NULL, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(1, 2, NULL), 'FixedSizeList(3, Int64)')); +---- +[1, 2, 3] [NULL, 2, 3] + +# Min/Max FixedSizeList grouped by key with NULLs and differing lengths +query I?? rowsort +SELECT column1, MIN(column2), MAX(column2) FROM VALUES +(0, arrow_cast(make_array(1, NULL, 3), 'FixedSizeList(3, Int64)')), +(0, arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)')), +(1, arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')), +(1, arrow_cast(make_array(NULL, 5), 'FixedSizeList(2, Int64)')) +GROUP BY column1; +---- +0 [1, 2, 3, 4] [1, NULL, 3] +1 [1, 2] [NULL, 5] + +# Min/Max FixedSizeList grouped by key with NULLs and differing lengths +query I?? rowsort +SELECT column1, MIN(column2), MAX(column2) FROM VALUES +(0, arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')), +(0, arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Int64)')), +(1, arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')) +GROUP BY column1; +---- +0 [NULL] [NULL, NULL] +1 [NULL] [NULL] + +# Min/Max grouped FixedSizeList with empty and non-empty +query I?? rowsort +SELECT column1, MIN(column2), MAX(column2) FROM VALUES +(0, arrow_cast(make_array(1), 'FixedSizeList(1, Int64)')), +(1, arrow_cast(make_array(5, 6), 'FixedSizeList(2, Int64)')) +GROUP BY column1; +---- +0 [1] [1] +1 [5, 6] [5, 6] + +# Min/Max over FixedSizeList with a window function +query ? +SELECT min(column1) OVER (ORDER BY column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[1, 2, 3] +[1, 2, 3] +[1, 2, 3] + +# Min/Max over FixedSizeList with a window function and nulls +query ? +SELECT min(column1) OVER (ORDER BY column1) FROM VALUES +(arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')), +(arrow_cast(make_array(4, 5), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[2, 3] +[2, 3] +[2, 3] + +# Min/Max over FixedSizeList with a window function, nulls and ROWS BETWEEN statement +query ? +SELECT min(column1) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM VALUES +(arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')), +(arrow_cast(make_array(4, 5), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[2, 3] +[2, 3] +[4, 5] + +# Min/Max over FixedSizeList with a window function using a different column +query ? 
+SELECT max(column2) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM VALUES +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(4, 5), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[4, 5] +[4, 5] diff --git a/datafusion/sqllogictest/test_files/min_max/init_data.slt.part b/datafusion/sqllogictest/test_files/min_max/init_data.slt.part new file mode 100644 index 000000000000..57e14f6993d4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/init_data.slt.part @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# -------------------------------------- +# 1. Min/Max over integers +# -------------------------------------- +statement ok +create table min_max_base_int as values + (make_array(1, 2, 3, 4)), + (make_array(1, 2)) +; + +# -------------------------------------- +# 2. Min/Max over strings +# -------------------------------------- +statement ok +create table min_max_base_string as values + (make_array('a', 'b', 'c')), + (make_array('a', 'b')) +; + +# -------------------------------------- +# 3. Min/Max over booleans +# -------------------------------------- +statement ok +create table min_max_base_bool as values + (make_array(true, false, true)), + (make_array(true, false)) +; + +# -------------------------------------- +# 4. Min/Max over nullable integers +# -------------------------------------- +statement ok +create table min_max_base_nullable_int as values + (make_array(NULL, 1, 2)), + (make_array(1, 2)) +; + +# -------------------------------------- +# 5. Min/Max with mixed lengths and nulls +# -------------------------------------- +statement ok +create table min_max_base_mixed_lengths_nulls as values + (make_array(1, 2, 3, 4)), + (make_array(1, 2)), + (make_array(1, NULL, 3)) +; + +# -------------------------------------- +# 6. Min/Max with only NULLs +# -------------------------------------- +statement ok +create table min_max_base_all_nulls as values + (make_array(NULL, NULL)), + (make_array(NULL)) +; + +# -------------------------------------- +# 7. Min/Max with partial NULLs +# -------------------------------------- +statement ok +create table min_max_base_null_variants as values + (make_array(1, 2, 3)), + (make_array(NULL, 2, 3)), + (make_array(1, 2, NULL)) +; + +# -------------------------------------- +# 8. Min/Max grouped by key with NULLs and differing lengths +# -------------------------------------- +statement ok +create table min_max_base_grouped_nulls as values + (0, make_array(1, NULL, 3)), + (0, make_array(1, 2, 3, 4)), + (1, make_array(1, 2)), + (1, make_array(NULL, 5)), + (1, make_array()) +; + +# -------------------------------------- +# 9. 
Min/Max grouped by key with only NULLs +# -------------------------------------- +statement ok +create table min_max_base_grouped_all_null as values + (0, make_array(NULL)), + (0, make_array(NULL, NULL)), + (1, make_array(NULL)) +; + +# -------------------------------------- +# 10. Min/Max grouped with empty and non-empty lists +# -------------------------------------- +statement ok +create table min_max_base_grouped_simple as values + (0, make_array()), + (0, make_array(1)), + (0, make_array()), + (1, make_array()), + (1, make_array(5, 6)) +; + +# -------------------------------------- +# 11. Min over with window function +# -------------------------------------- +statement ok +create table min_base_window_simple as values + (make_array(1, 2, 3)), + (make_array(1, 2, 3)), + (make_array(2, 3)) +; + +# -------------------------------------- +# 12. Min over with window + NULLs +# -------------------------------------- +statement ok +create table min_base_window_with_null as values + (make_array(NULL)), + (make_array(4, 5)), + (make_array(2, 3)) +; + +# -------------------------------------- +# 13. Min over with ROWS BETWEEN clause +# -------------------------------------- +statement ok +create table min_base_window_rows_between as values + (make_array(NULL)), + (make_array(4, 5)), + (make_array(2, 3)) +; + +# -------------------------------------- +# 14. Max over using different order column +# -------------------------------------- +statement ok +create table max_base_window_different_column as values + (make_array(1, 2, 3), make_array(4, 5)), + (make_array(2, 3), make_array(2, 3)), + (make_array(2, 3), NULL) +; diff --git a/datafusion/sqllogictest/test_files/min_max/large_list.slt b/datafusion/sqllogictest/test_files/min_max/large_list.slt new file mode 100644 index 000000000000..44789e9dd786 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/large_list.slt @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include ./init_data.slt.part + +## -------------------------------------- +## 1. Min/Max over integers +## -------------------------------------- +statement ok +create table min_max_int as ( + select + arrow_cast(column1, 'LargeList(Int64)') as column1 + from min_max_base_int + ); + +## -------------------------------------- +## 2. Min/Max over strings +## -------------------------------------- +statement ok +create table min_max_string as ( + select + arrow_cast(column1, 'LargeList(Utf8)') as column1 +from min_max_base_string); + +## -------------------------------------- +## 3. 
Min/Max over booleans +## -------------------------------------- +statement ok +create table min_max_bool as +( + select + arrow_cast(column1, 'LargeList(Boolean)') as column1 +from min_max_base_bool); + +## -------------------------------------- +## 4. Min/Max over nullable integers +## -------------------------------------- +statement ok +create table min_max_nullable_int as ( + select + arrow_cast(column1, 'LargeList(Int64)') as column1 + from min_max_base_nullable_int +); + +## -------------------------------------- +## 5. Min/Max with mixed lengths and nulls +## -------------------------------------- +statement ok +create table min_max_mixed_lengths_nulls as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_max_base_mixed_lengths_nulls); + +## -------------------------------------- +## 6. Min/Max with only NULLs +## -------------------------------------- +statement ok +create table min_max_all_nulls as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_max_base_all_nulls); + +## -------------------------------------- +## 7. Min/Max with partial NULLs +## -------------------------------------- +statement ok +create table min_max_null_variants as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_max_base_null_variants); + +## -------------------------------------- +## 8. Min/Max grouped by key with NULLs and differing lengths +## -------------------------------------- +statement ok +create table min_max_grouped_nulls as (select + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from min_max_base_grouped_nulls); + +## -------------------------------------- +## 9. Min/Max grouped by key with only NULLs +## -------------------------------------- +statement ok +create table min_max_grouped_all_null as (select + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from min_max_base_grouped_all_null); + +## -------------------------------------- +## 10. Min/Max grouped with simple sizes +## -------------------------------------- +statement ok +create table min_max_grouped_simple as (select + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from min_max_base_grouped_simple); + +## -------------------------------------- +## 11. Min over with window function +## -------------------------------------- +statement ok +create table min_window_simple as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_base_window_simple); + +## -------------------------------------- +## 12. Min over with window + NULLs +## -------------------------------------- +statement ok +create table min_window_with_null as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_base_window_with_null); + +## -------------------------------------- +## 13. Min over with ROWS BETWEEN clause +## -------------------------------------- +statement ok +create table min_window_rows_between as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_base_window_rows_between); + +## -------------------------------------- +## 14. 
Max over using different order column +## -------------------------------------- +statement ok +create table max_window_different_column as (select + arrow_cast(column1, 'LargeList(Int64)') as column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from max_base_window_different_column); + +include ./queries.slt.part diff --git a/datafusion/sqllogictest/test_files/min_max/list.slt b/datafusion/sqllogictest/test_files/min_max/list.slt new file mode 100644 index 000000000000..e63e8303c7d5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/list.slt @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include ./init_data.slt.part + +# -------------------------------------- +# 1. Min/Max over integers +# -------------------------------------- +statement ok +create table min_max_int as ( + select * from min_max_base_int ) +; + +# -------------------------------------- +# 2. Min/Max over strings +# -------------------------------------- +statement ok +create table min_max_string as ( + select * from min_max_base_string ) +; + +# -------------------------------------- +# 3. Min/Max over booleans +# -------------------------------------- +statement ok +create table min_max_bool as ( + select * from min_max_base_bool ) +; + +# -------------------------------------- +# 4. Min/Max over nullable integers +# -------------------------------------- +statement ok +create table min_max_nullable_int as ( + select * from min_max_base_nullable_int ) +; + +# -------------------------------------- +# 5. Min/Max with mixed lengths and nulls +# -------------------------------------- +statement ok +create table min_max_mixed_lengths_nulls as ( + select * from min_max_base_mixed_lengths_nulls ) +; + +# -------------------------------------- +# 6. Min/Max with only NULLs +# -------------------------------------- +statement ok +create table min_max_all_nulls as ( + select * from min_max_base_all_nulls ) +; + +# -------------------------------------- +# 7. Min/Max with partial NULLs +# -------------------------------------- +statement ok +create table min_max_null_variants as ( + select * from min_max_base_null_variants ) +; + +# -------------------------------------- +# 8. Min/Max grouped by key with NULLs and differing lengths +# -------------------------------------- +statement ok +create table min_max_grouped_nulls as ( + select * from min_max_base_grouped_nulls ) +; + +# -------------------------------------- +# 9. Min/Max grouped by key with only NULLs +# -------------------------------------- +statement ok +create table min_max_grouped_all_null as ( + select * from min_max_base_grouped_all_null ) +; + +# -------------------------------------- +# 10. 
Min/Max grouped with simple sizes +# -------------------------------------- +statement ok +create table min_max_grouped_simple as ( + select * from min_max_base_grouped_simple ) +; + +# -------------------------------------- +# 11. Min over with window function +# -------------------------------------- +statement ok +create table min_window_simple as ( + select * from min_base_window_simple ) +; + +# -------------------------------------- +# 12. Min over with window + NULLs +# -------------------------------------- +statement ok +create table min_window_with_null as ( + select * from min_base_window_with_null ) +; + +# -------------------------------------- +# 13. Min over with ROWS BETWEEN clause +# -------------------------------------- +statement ok +create table min_window_rows_between as ( + select * from min_base_window_rows_between ) +; + +# -------------------------------------- +# 14. Max over using different order column +# -------------------------------------- +statement ok +create table max_window_different_column as ( + select * from max_base_window_different_column ) +; + +include ./queries.slt.part diff --git a/datafusion/sqllogictest/test_files/min_max/queries.slt.part b/datafusion/sqllogictest/test_files/min_max/queries.slt.part new file mode 100644 index 000000000000..bc7fb840bf97 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/queries.slt.part @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +## 1. Min/Max List over integers +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_int; +---- +[1, 2] [1, 2, 3, 4] + +## 2. Min/Max List over strings +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_string; +---- +[a, b] [a, b, c] + +## 3. Min/Max List over booleans +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_bool; +---- +[true, false] [true, false, true] + +## 4. Min/Max List over nullable integers +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_nullable_int; +---- +[1, 2] [NULL, 1, 2] + +## 5. Min/Max List with mixed lengths and nulls +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_mixed_lengths_nulls; +---- +[1, 2] [1, NULL, 3] + +## 6. Min/Max List with only NULLs +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_all_nulls; +---- +[NULL] [NULL, NULL] + +## 7. Min/Max List with partial NULLs +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_null_variants; +---- +[1, 2, 3] [NULL, 2, 3] + +## 8. Min/Max List grouped by key with NULLs and differing lengths +query I?? +SELECT column1, MIN(column2), MAX(column2) FROM min_max_grouped_nulls GROUP BY column1 ORDER BY column1; +---- +0 [1, 2, 3, 4] [1, NULL, 3] +1 [] [NULL, 5] + +## 9. Min/Max List grouped by key with only NULLs +query I?? 
+SELECT column1, MIN(column2), MAX(column2) FROM min_max_grouped_all_null GROUP BY column1 ORDER BY column1; +---- +0 [NULL] [NULL, NULL] +1 [NULL] [NULL] + +## 10. Min/Max grouped List with simple sizes +query I?? +SELECT column1, MIN(column2), MAX(column2) FROM min_max_grouped_simple GROUP BY column1 ORDER BY column1; +---- +0 [] [1] +1 [] [5, 6] + +## 11. Min over List with window function +query ? +SELECT MIN(column1) OVER (ORDER BY column1) FROM min_window_simple; +---- +[1, 2, 3] +[1, 2, 3] +[1, 2, 3] + +## 12. Min over List with window + NULLs +query ? +SELECT MIN(column1) OVER (ORDER BY column1) FROM min_window_with_null; +---- +[2, 3] +[2, 3] +[2, 3] + +## 13. Min over List with ROWS BETWEEN clause +query ? +SELECT MIN(column1) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM min_window_rows_between; +---- +[2, 3] +[2, 3] +[4, 5] + +## 14. Max over List using different order column +query ? +SELECT MAX(column2) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM max_window_different_column; +---- +[4, 5] +[4, 5] +[2, 3] diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 4e8be56f3377..3fc90a6459f2 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -1040,12 +1040,12 @@ limit 5; ---- logical_plan 01)Sort: c_str ASC NULLS LAST, fetch=5 -02)--Projection: CAST(ordered_table.c AS Utf8) AS c_str +02)--Projection: CAST(ordered_table.c AS Utf8View) AS c_str 03)----TableScan: ordered_table projection=[c] physical_plan 01)SortPreservingMergeExec: [c_str@0 ASC NULLS LAST], fetch=5 02)--SortExec: TopK(fetch=5), expr=[c_str@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[CAST(c@0 AS Utf8) as c_str] +03)----ProjectionExec: expr=[CAST(c@0 AS Utf8View) as c_str] 04)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], file_type=csv, has_header=true @@ -1380,3 +1380,42 @@ physical_plan statement ok drop table table_with_ordered_not_null; + +# ORDER BY ALL +statement ok +set datafusion.sql_parser.dialect = 'DuckDB'; + +statement ok +CREATE OR REPLACE TABLE addresses AS + SELECT '123 Quack Blvd' AS address, 'DuckTown' AS city, '11111' AS zip + UNION ALL + SELECT '111 Duck Duck Goose Ln', 'DuckTown', '11111' + UNION ALL + SELECT '111 Duck Duck Goose Ln', 'Duck Town', '11111' + UNION ALL + SELECT '111 Duck Duck Goose Ln', 'Duck Town', '11111-0001'; + + +query TTT +SELECT * FROM addresses ORDER BY ALL; +---- +111 Duck Duck Goose Ln Duck Town 11111 +111 Duck Duck Goose Ln Duck Town 11111-0001 +111 Duck Duck Goose Ln DuckTown 11111 +123 Quack Blvd DuckTown 11111 + +query TTT +SELECT * FROM addresses ORDER BY ALL DESC; +---- +123 Quack Blvd DuckTown 11111 +111 Duck Duck Goose Ln DuckTown 11111 +111 Duck Duck Goose Ln Duck Town 11111-0001 +111 Duck Duck Goose Ln Duck Town 11111 + +query TT +SELECT address, zip FROM addresses ORDER BY ALL; +---- +111 Duck Duck Goose Ln 11111 +111 Duck Duck Goose Ln 11111 +111 Duck Duck Goose Ln 11111-0001 +123 Quack Blvd 11111 diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index 2970b2effb3e..abc6fdab3c8a 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -629,3 +629,78 @@ 
physical_plan statement ok drop table foo + + +# Tests for int96 timestamps written by Spark +# See https://github.com/apache/datafusion/issues/9981 + +statement ok +CREATE EXTERNAL TABLE int96_from_spark +STORED AS PARQUET +LOCATION '../../parquet-testing/data/int96_from_spark.parquet'; + +# by default the value is read as nanosecond precision +query TTT +describe int96_from_spark +---- +a Timestamp(Nanosecond, None) YES + +# Note that the values are read as nanosecond precision +query P +select * from int96_from_spark +---- +2024-01-01T20:34:56.123456 +2024-01-01T01:00:00 +1816-03-29T08:56:08.066277376 +2024-12-30T23:00:00 +NULL +1815-11-08T16:01:01.191053312 + +statement ok +drop table int96_from_spark; + +# Enable coercion of int96 to milliseconds +statement ok +set datafusion.execution.parquet.coerce_int96 = ms; + +statement ok +CREATE EXTERNAL TABLE int96_from_spark +STORED AS PARQUET +LOCATION '../../parquet-testing/data/int96_from_spark.parquet'; + +# Print schema +query TTT +describe int96_from_spark; +---- +a Timestamp(Millisecond, None) YES + +# Per https://github.com/apache/parquet-testing/blob/6e851ddd768d6af741c7b15dc594874399fc3cff/data/int96_from_spark.md?plain=1#L37 +# these values should be +# +# Some("2024-01-01T12:34:56.123456"), +# Some("2024-01-01T01:00:00Z"), +# Some("9999-12-31T01:00:00-02:00"), +# Some("2024-12-31T01:00:00+02:00"), +# None, +# Some("290000-12-31T01:00:00+02:00")) +# +# However, printing the large dates (9999-12-31 and 290000-12-31) is not supported by +# arrow yet +# +# See https://github.com/apache/arrow-rs/issues/7287 +query P +select * from int96_from_spark +---- +2024-01-01T20:34:56.123 +2024-01-01T01:00:00 +9999-12-31T03:00:00 +2024-12-30T23:00:00 +NULL +ERROR: Cast error: Failed to convert -9357363680509551 to datetime for Timestamp(Millisecond, None) + +# Cleanup / reset default setting +statement ok +drop table int96_from_spark; + +statement ok +set datafusion.execution.parquet.coerce_int96 = ns; diff --git a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt index 758113b70835..f4fb0e87c43b 100644 --- a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt +++ b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt @@ -54,7 +54,6 @@ LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_table/'; statement ok set datafusion.execution.parquet.pushdown_filters = true; -## Create table without pushdown statement ok CREATE EXTERNAL TABLE t_pushdown(a varchar, b int, c float) STORED AS PARQUET LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_table/'; @@ -76,17 +75,142 @@ NULL NULL NULL +query T +select a from t_pushdown where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query TT +EXPLAIN select a from t where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: t.a ASC NULLS LAST +02)--Projection: t.a +03)----Filter: t.b > Int32(2) +04)------TableScan: t projection=[a, b], partial_filters=[t.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------FilterExec: b@1 > 2, projection=[a@0] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 +06)----------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a, b], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] + +query TT +EXPLAIN select a from t_pushdown where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: t_pushdown.a ASC NULLS LAST +02)--Projection: t_pushdown.a +03)----Filter: t_pushdown.b > Int32(2) +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] + +# If we set the setting to `true`, it overrides the table's setting +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +query T +select a from t where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query T +select a from t_pushdown where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query TT +EXPLAIN select a from t where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: t.a ASC NULLS LAST +02)--Projection: t.a +03)----Filter: t.b > Int32(2) +04)------TableScan: t projection=[a, b], partial_filters=[t.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] + query TT EXPLAIN select a from t_pushdown where b > 2 ORDER BY a; ---- logical_plan 01)Sort: t_pushdown.a ASC NULLS LAST -02)--TableScan: t_pushdown projection=[a], full_filters=[t_pushdown.b > Int32(2)] +02)--Projection: t_pushdown.a +03)----Filter: t_pushdown.b > Int32(2) +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b > Int32(2)] physical_plan 01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] 02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] +# If we reset the default, the table created without pushdown goes back to disabling it +statement ok +set datafusion.execution.parquet.pushdown_filters = false; + +query T +select a from t where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query T +select a from t_pushdown where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query TT +EXPLAIN select a from t where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: 
t.a ASC NULLS LAST +02)--Projection: t.a +03)----Filter: t.b > Int32(2) +04)------TableScan: t projection=[a, b], partial_filters=[t.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------FilterExec: b@1 > 2, projection=[a@0] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 +06)----------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a, b], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] + +query TT +EXPLAIN select a from t_pushdown where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: t_pushdown.a ASC NULLS LAST +02)--Projection: t_pushdown.a +03)----Filter: t_pushdown.b > Int32(2) +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] # When filter pushdown *is* enabled, ParquetExec can filter exactly, # not just metadata, so we expect to see no FilterExec @@ -127,7 +251,9 @@ EXPLAIN select a from t_pushdown where b > 2 AND a IS NOT NULL order by a; ---- logical_plan 01)Sort: t_pushdown.a ASC NULLS LAST -02)--TableScan: t_pushdown projection=[a], full_filters=[t_pushdown.b > Int32(2), t_pushdown.a IS NOT NULL] +02)--Projection: t_pushdown.a +03)----Filter: t_pushdown.b > Int32(2) AND t_pushdown.a IS NOT NULL +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b > Int32(2), t_pushdown.a IS NOT NULL] physical_plan 01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] 02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] @@ -144,7 +270,9 @@ EXPLAIN select b from t_pushdown where a = 'bar' order by b; ---- logical_plan 01)Sort: t_pushdown.b ASC NULLS LAST -02)--TableScan: t_pushdown projection=[b], full_filters=[t_pushdown.a = Utf8("bar")] +02)--Projection: t_pushdown.b +03)----Filter: t_pushdown.a = Utf8View("bar") +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.a = Utf8View("bar")] physical_plan 01)SortPreservingMergeExec: [b@0 ASC NULLS LAST] 02)--SortExec: expr=[b@0 ASC NULLS LAST], preserve_partitioning=[true] @@ -156,3 +284,87 @@ DROP TABLE t; statement ok DROP TABLE t_pushdown; + +## Test filter pushdown with a predicate that references both a partition column and a file column +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +## Create table +statement ok +CREATE EXTERNAL TABLE t_pushdown(part text, val text) +STORED AS PARQUET +PARTITIONED BY (part) +LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_part_test/'; + +statement ok +COPY ( + SELECT arrow_cast('a', 'Utf8') AS val +) TO 
'test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet' +STORED AS PARQUET; + +statement ok +COPY ( + SELECT arrow_cast('b', 'Utf8') AS val +) TO 'test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet' +STORED AS PARQUET; + +statement ok +COPY ( + SELECT arrow_cast('xyz', 'Utf8') AS val +) TO 'test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet' +STORED AS PARQUET; + +query TT +select * from t_pushdown where part == val order by part, val; +---- +a a +b b + +query TT +select * from t_pushdown where part != val order by part, val; +---- +xyz c + +# If we reference both a file column and a partition column, the predicate cannot be pushed down +query TT +EXPLAIN select * from t_pushdown where part != val +---- +logical_plan +01)Filter: t_pushdown.val != t_pushdown.part +02)--TableScan: t_pushdown projection=[val, part], partial_filters=[t_pushdown.val != t_pushdown.part] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: val@0 != part@1 +03)----DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet]]}, projection=[val, part], file_type=parquet + +# If we reference only a partition column, it gets evaluated during the listing phase +query TT +EXPLAIN select * from t_pushdown where part != 'a'; +---- +logical_plan TableScan: t_pushdown projection=[val, part], full_filters=[t_pushdown.part != Utf8("a")] +physical_plan DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet]]}, projection=[val, part], file_type=parquet + +# And if we reference only a file column, it gets pushed down +query TT +EXPLAIN select * from t_pushdown where val != 'c'; +---- +logical_plan +01)Filter: t_pushdown.val != Utf8("c") +02)--TableScan: t_pushdown projection=[val, part], partial_filters=[t_pushdown.val != Utf8("c")] +physical_plan DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet]]}, projection=[val, part], file_type=parquet, predicate=val@0 != c, pruning_predicate=val_null_count@2 != row_count@3 AND (val_min@0 != c OR c != val_max@1), required_guarantees=[val not in (c)] + +# If we have a mix of filters: +# - The partition filters get evaluated during planning +# - The mixed filters end up in a FilterExec +# - The file filters get pushed down into the scan +query TT +EXPLAIN select * from t_pushdown where val != 'd' AND val != 'c' AND part = 'a' AND part != val; +---- +logical_plan +01)Filter: t_pushdown.val != Utf8("d") AND t_pushdown.val != Utf8("c") AND t_pushdown.val != t_pushdown.part +02)--TableScan: t_pushdown projection=[val, part], full_filters=[t_pushdown.part = Utf8("a")], 
partial_filters=[t_pushdown.val != Utf8("d"), t_pushdown.val != Utf8("c"), t_pushdown.val != t_pushdown.part] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: val@0 != part@1 +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet]]}, projection=[val, part], file_type=parquet, predicate=val@0 != d AND val@0 != c, pruning_predicate=val_null_count@2 != row_count@3 AND (val_min@0 != d OR d != val_max@1) AND val_null_count@2 != row_count@3 AND (val_min@0 != c OR c != val_max@1), required_guarantees=[val not in (c, d)] diff --git a/datafusion/sqllogictest/test_files/parquet_statistics.slt b/datafusion/sqllogictest/test_files/parquet_statistics.slt new file mode 100644 index 000000000000..efbe69bd856c --- /dev/null +++ b/datafusion/sqllogictest/test_files/parquet_statistics.slt @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for statistics in parquet files. 
+# Writes data into two files: +# * test_table/0.parquet +# * test_table/1.parquet +# +# And verifies statistics are correctly calculated for the table +# +# NOTE that statistics are ONLY gathered when the table is first created +# so the table must be recreated to see the effects of the setting + +query I +COPY (values (1), (2), (3)) +TO 'test_files/scratch/parquet_statistics/test_table/0.parquet' +STORED AS PARQUET; +---- +3 + +query I +COPY (values (3), (4)) +TO 'test_files/scratch/parquet_statistics/test_table/1.parquet' +STORED AS PARQUET; +---- +2 + +statement ok +set datafusion.explain.physical_plan_only = true; + +statement ok +set datafusion.explain.show_statistics = true; + +###### +# By default, the statistics are gathered +###### + +# Recreate the table to pick up the current setting +statement ok +CREATE EXTERNAL TABLE test_table +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_statistics/test_table'; + +query TT +EXPLAIN SELECT * FROM test_table WHERE column1 = 1; +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192, statistics=[Rows=Inexact(2), Bytes=Inexact(44), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +02)--FilterExec: column1@0 = 1, statistics=[Rows=Inexact(2), Bytes=Inexact(44), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2, statistics=[Rows=Inexact(5), Bytes=Inexact(173), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] +04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/1.parquet]]}, projection=[column1], file_type=parquet, predicate=column1@0 = 1, pruning_predicate=column1_null_count@2 != row_count@3 AND column1_min@0 <= 1 AND 1 <= column1_max@1, required_guarantees=[column1 in (1)] +05), statistics=[Rows=Inexact(5), Bytes=Inexact(173), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] + +# cleanup +statement ok +DROP TABLE test_table; + +###### +# When the setting is true, statistics are gathered +###### + +statement ok +set datafusion.execution.collect_statistics = true; + +# Recreate the table to pick up the current setting +statement ok +CREATE EXTERNAL TABLE test_table +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_statistics/test_table'; + +query TT +EXPLAIN SELECT * FROM test_table WHERE column1 = 1; +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192, statistics=[Rows=Inexact(2), Bytes=Inexact(44), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +02)--FilterExec: column1@0 = 1, statistics=[Rows=Inexact(2), Bytes=Inexact(44), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2, statistics=[Rows=Inexact(5), Bytes=Inexact(173), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] +04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/1.parquet]]}, projection=[column1], file_type=parquet, predicate=column1@0 = 1, pruning_predicate=column1_null_count@2 != row_count@3 AND column1_min@0 <= 1 AND 1 <= column1_max@1, required_guarantees=[column1 in 
(1)] +05), statistics=[Rows=Inexact(5), Bytes=Inexact(173), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] + +# cleanup +statement ok +DROP TABLE test_table; + + +###### +# When the setting is false, the statistics are NOT gathered +###### + +statement ok +set datafusion.execution.collect_statistics = false; + +# Recreate the table to pick up the current setting +statement ok +CREATE EXTERNAL TABLE test_table +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_statistics/test_table'; + +query TT +EXPLAIN SELECT * FROM test_table WHERE column1 = 1; +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]] +02)--FilterExec: column1@0 = 1, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)))]] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]] +04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/1.parquet]]}, projection=[column1], file_type=parquet, predicate=column1@0 = 1, pruning_predicate=column1_null_count@2 != row_count@3 AND column1_min@0 <= 1 AND 1 <= column1_max@1, required_guarantees=[column1 in (1)] +05), statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]] + +# cleanup +statement ok +DROP TABLE test_table; diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index b263e39f3b11..b4b31fa78a69 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -662,11 +662,11 @@ OR ---- logical_plan 01)Projection: lineitem.l_partkey -02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) +02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) 03)----Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) 04)------TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] -05)----Filter: (part.p_brand = Utf8("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) -06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_size <= Int32(15)] +05)----Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) +06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_partkey@0] @@ -755,8 +755,8 @@ logical_plan 05)--------Inner Join: lineitem.l_partkey = part.p_partkey 06)----------TableScan: lineitem projection=[l_partkey, l_extendedprice, l_discount] 07)----------Projection: part.p_partkey -08)------------Filter: part.p_brand = Utf8("Brand#12") OR part.p_brand = Utf8("Brand#23") -09)--------------TableScan: part projection=[p_partkey, p_brand], partial_filters=[part.p_brand = Utf8("Brand#12") OR part.p_brand = Utf8("Brand#23")] +08)------------Filter: part.p_brand = Utf8View("Brand#12") OR part.p_brand = Utf8View("Brand#23") +09)--------------TableScan: part projection=[p_partkey, p_brand], partial_filters=[part.p_brand = Utf8View("Brand#12") OR part.p_brand = Utf8View("Brand#23")] 10)------TableScan: partsupp projection=[ps_partkey, ps_suppkey] physical_plan 01)AggregateExec: mode=SinglePartitioned, gby=[p_partkey@2 as p_partkey], aggr=[sum(lineitem.l_extendedprice), avg(lineitem.l_discount), count(DISTINCT partsupp.ps_suppkey)] diff --git a/datafusion/sqllogictest/test_files/prepare.slt b/datafusion/sqllogictest/test_files/prepare.slt index 33df0d26f361..d61603ae6558 100644 --- a/datafusion/sqllogictest/test_files/prepare.slt +++ b/datafusion/sqllogictest/test_files/prepare.slt @@ -92,7 +92,7 @@ DEALLOCATE my_plan statement ok PREPARE my_plan AS SELECT * FROM person WHERE id < $1; -statement error No value found for placeholder with id \$1 +statement error Prepared statement 'my_plan' expects 1 parameters, but 0 provided EXECUTE my_plan statement ok diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 67965146e76b..ed948dd11439 100644 --- 
a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -18,7 +18,7 @@ # Test push down filter statement ok -set datafusion.explain.logical_plan_only = true; +set datafusion.explain.physical_plan_only = true; statement ok CREATE TABLE IF NOT EXISTS v AS VALUES(1,[1,2,3]),(2,[3,4,5]); @@ -35,12 +35,14 @@ select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = query TT explain select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2 -02)--Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -03)----Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -04)------Filter: v.column1 = Int64(2) -05)--------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2] +02)--UnnestExec +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------FilterExec: column1@0 = 2 +07)------------DataSourceExec: partitions=1, partition_sizes=[1] query I select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; @@ -52,13 +54,15 @@ select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; query TT explain select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2 -02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) -03)----Projection: __unnest_placeholder(v.column2,depth=1) -04)------Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -05)--------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -06)----------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as __unnest_placeholder(v.column2,depth=1)] +06)----------UnnestExec +07)------------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +08)--------------DataSourceExec: partitions=1, partition_sizes=[1] query II select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; @@ -70,13 +74,16 @@ select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where query TT explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 -02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) -03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -05)--------Filter: v.column1 = Int64(2) -06)----------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2, column1@1 as column1] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: 
__unnest_placeholder(v.column2,depth=1)@0 > 3 +04)------UnnestExec +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)----------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +07)------------CoalesceBatchesExec: target_batch_size=8192 +08)--------------FilterExec: column1@0 = 2 +09)----------------DataSourceExec: partitions=1, partition_sizes=[1] query II select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; @@ -89,12 +96,14 @@ select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where query TT explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 -02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) OR v.column1 = Int64(2) -03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -05)--------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2, column1@1 as column1] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 OR column1@1 = 2 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------UnnestExec +06)----------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] statement ok drop table v; @@ -111,12 +120,14 @@ select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; query TT explain select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; ---- -logical_plan -01)Projection: d.column1, __unnest_placeholder(d.column2,depth=1) AS o -02)--Filter: get_field(__unnest_placeholder(d.column2,depth=1), Utf8("a")) = Int64(1) -03)----Unnest: lists[__unnest_placeholder(d.column2)|depth=1] structs[] -04)------Projection: d.column1, d.column2 AS __unnest_placeholder(d.column2) -05)--------TableScan: d projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[column1@0 as column1, __unnest_placeholder(d.column2,depth=1)@1 as o] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: get_field(__unnest_placeholder(d.column2,depth=1)@1, a) = 1 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------UnnestExec +06)----------ProjectionExec: expr=[column1@0 as column1, column2@1 as __unnest_placeholder(d.column2)] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] @@ -179,9 +190,9 @@ LOCATION 'test_files/scratch/parquet/test_filter_with_limit/'; query TT explain select * from test_filter_with_limit where value = 2 limit 1; ---- -logical_plan -01)Limit: skip=0, fetch=1 -02)--TableScan: test_filter_with_limit projection=[part_key, value], full_filters=[test_filter_with_limit.value = Int32(2)], fetch=1 +physical_plan +01)CoalescePartitionsExec: fetch=1 +02)--DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_filter_with_limit/part-0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_filter_with_limit/part-1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_filter_with_limit/part-2.parquet]]}, 
projection=[part_key, value], limit=1, file_type=parquet, predicate=value@1 = 2, pruning_predicate=value_null_count@2 != row_count@3 AND value_min@0 <= 2 AND 2 <= value_max@1, required_guarantees=[value in (2)] query II select * from test_filter_with_limit where value = 2 limit 1; @@ -218,43 +229,43 @@ LOCATION 'test_files/scratch/push_down_filter/t.parquet'; query TT explain select a from t where a = '100'; ---- -logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 = 100, pruning_predicate=a_null_count@2 != row_count@3 AND a_min@0 <= 100 AND 100 <= a_max@1, required_guarantees=[a in (100)] # The predicate should not have a column cast when the value is a valid i32 query TT explain select a from t where a != '100'; ---- -logical_plan TableScan: t projection=[a], full_filters=[t.a != Int32(100)] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 != 100, pruning_predicate=a_null_count@2 != row_count@3 AND (a_min@0 != 100 OR 100 != a_max@1), required_guarantees=[a not in (100)] # The predicate should still have the column cast when the value is a NOT valid i32 query TT explain select a from t where a = '99999999999'; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99999999999")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99999999999 # The predicate should still have the column cast when the value is a NOT valid i32 query TT explain select a from t where a = '99.99'; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99.99")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99.99 # The predicate should still have the column cast when the value is a NOT valid i32 query TT explain select a from t where a = ''; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = # The predicate should not have a column cast when the operator is = or != and the literal can be round-trip casted without losing information. query TT explain select a from t where cast(a as string) = '100'; ---- -logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 = 100, pruning_predicate=a_null_count@2 != row_count@3 AND a_min@0 <= 100 AND 100 <= a_max@1, required_guarantees=[a in (100)] # The predicate should still have the column cast when the literal alters its string representation after round-trip casting (leading zero lost). 
query TT explain select a from t where CAST(a AS string) = '0123'; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("0123")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 0123 statement ok diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt deleted file mode 100644 index 44ba61e877d9..000000000000 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ /dev/null @@ -1,898 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -statement ok -CREATE TABLE t (str varchar, pattern varchar, start int, flags varchar) AS VALUES - ('abc', '^(a)', 1, 'i'), - ('ABC', '^(A).*', 1, 'i'), - ('aBc', '(b|d)', 1, 'i'), - ('AbC', '(B|D)', 2, null), - ('aBC', '^(b|c)', 3, null), - ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), - ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), - ('Düsseldorf','[\p{Letter}-]+', 3, null), - ('Москва', '[\p{L}-]+', 4, null), - ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), - ('إسرائيل', '^\p{Arabic}+$', 2, null); - -# -# regexp_like tests -# - -query B -SELECT regexp_like(str, pattern, flags) FROM t; ----- -true -true -true -false -false -false -true -true -true -true -true - -query B -SELECT str ~ NULL FROM t; ----- -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL - -query B -select str ~ right('foo', NULL) FROM t; ----- -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL - -query B -select right('foo', NULL) !~ str FROM t; ----- -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL - -query B -SELECT regexp_like('foobarbequebaz', ''); ----- -true - -query B -SELECT regexp_like('', ''); ----- -true - -query B -SELECT regexp_like('foobarbequebaz', '(bar)(beque)'); ----- -true - -query B -SELECT regexp_like('fooBarb -eQuebaz', '(bar).*(que)', 'is'); ----- -true - -query B -SELECT regexp_like('foobarbequebaz', '(ba3r)(bequ34e)'); ----- -false - -query B -SELECT regexp_like('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); ----- -true - -query B -SELECT regexp_like('aaa-0', '.*-(\d)'); ----- -true - -query B -SELECT regexp_like('bb-1', '.*-(\d)'); ----- -true - -query B -SELECT regexp_like('aa', '.*-(\d)'); ----- -false - -query B -SELECT regexp_like(NULL, '.*-(\d)'); ----- -NULL - -query B -SELECT regexp_like('aaa-0', NULL); ----- -NULL - -query B -SELECT regexp_like(null, '.*-(\d)'); ----- -NULL - -query error Error during planning: regexp_like\(\) does not support the "global" option -SELECT regexp_like('bb-1', '.*-(\d)', 'g'); - -query error Error during planning: regexp_like\(\) does not support the "global" option -SELECT 
regexp_like('bb-1', '.*-(\d)', 'g'); - -query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) -SELECT regexp_like('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); - -# look-around is not supported and will just return false -query B -SELECT regexp_like('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); ----- -false - -query B -select regexp_like('aaa-555', '.*-(\d*)'); ----- -true - -# -# regexp_match tests -# - -query ? -SELECT regexp_match(str, pattern, flags) FROM t; ----- -[a] -[A] -[B] -NULL -NULL -NULL -[010] -[Düsseldorf] -[Москва] -[Köln] -[إسرائيل] - -# test string view -statement ok -CREATE TABLE t_stringview AS -SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t; - -query ? -SELECT regexp_match(str, pattern, flags) FROM t_stringview; ----- -[a] -[A] -[B] -NULL -NULL -NULL -[010] -[Düsseldorf] -[Москва] -[Köln] -[إسرائيل] - -statement ok -DROP TABLE t_stringview; - -query ? -SELECT regexp_match('foobarbequebaz', ''); ----- -[] - -query ? -SELECT regexp_match('', ''); ----- -[] - -query ? -SELECT regexp_match('foobarbequebaz', '(bar)(beque)'); ----- -[bar, beque] - -query ? -SELECT regexp_match('fooBarb -eQuebaz', '(bar).*(que)', 'is'); ----- -[Bar, Que] - -query ? -SELECT regexp_match('foobarbequebaz', '(ba3r)(bequ34e)'); ----- -NULL - -query ? -SELECT regexp_match('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); ----- -[barbeque] - -query ? -SELECT regexp_match('aaa-0', '.*-(\d)'); ----- -[0] - -query ? -SELECT regexp_match('bb-1', '.*-(\d)'); ----- -[1] - -query ? -SELECT regexp_match('aa', '.*-(\d)'); ----- -NULL - -query ? -SELECT regexp_match(NULL, '.*-(\d)'); ----- -NULL - -query ? -SELECT regexp_match('aaa-0', NULL); ----- -NULL - -query ? -SELECT regexp_match(null, '.*-(\d)'); ----- -NULL - -query error Error during planning: regexp_match\(\) does not support the "global" option -SELECT regexp_match('bb-1', '.*-(\d)', 'g'); - -query error Error during planning: regexp_match\(\) does not support the "global" option -SELECT regexp_match('bb-1', '.*-(\d)', 'g'); - -query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) -SELECT regexp_match('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); - -# look-around is not supported and will just return null -query ? -SELECT regexp_match('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); ----- -NULL - -# ported test -query ? 
-SELECT regexp_match('aaa-555', '.*-(\d*)'); ----- -[555] - -query B -select 'abc' ~ null; ----- -NULL - -query B -select null ~ null; ----- -NULL - -query B -select null ~ 'abc'; ----- -NULL - -query B -select 'abc' ~* null; ----- -NULL - -query B -select null ~* null; ----- -NULL - -query B -select null ~* 'abc'; ----- -NULL - -query B -select 'abc' !~ null; ----- -NULL - -query B -select null !~ null; ----- -NULL - -query B -select null !~ 'abc'; ----- -NULL - -query B -select 'abc' !~* null; ----- -NULL - -query B -select null !~* null; ----- -NULL - -query B -select null !~* 'abc'; ----- -NULL - -# -# regexp_replace tests -# - -query T -SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t; ----- -Xbc -X -aXc -AbC -aBC -4000 -X -X -X -X -X - -# test string view -statement ok -CREATE TABLE t_stringview AS -SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t; - -query T -SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t_stringview; ----- -Xbc -X -aXc -AbC -aBC -4000 -X -X -X -X -X - -statement ok -DROP TABLE t_stringview; - -query T -SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'gi'); ----- -XXX - -query T -SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'i'); ----- -XabcABC - -query T -SELECT regexp_replace('foobarbaz', 'b..', 'X', 'g'); ----- -fooXX - -query T -SELECT regexp_replace('foobarbaz', 'b..', 'X'); ----- -fooXbaz - -query T -SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); ----- -fooXarYXazY - -query T -SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', NULL); ----- -NULL - -query T -SELECT regexp_replace('foobarbaz', 'b(..)', NULL, 'g'); ----- -NULL - -query T -SELECT regexp_replace('foobarbaz', NULL, 'X\\1Y', 'g'); ----- -NULL - -query T -SELECT regexp_replace('Thomas', '.[mN]a.', 'M'); ----- -ThM - -query T -SELECT regexp_replace(NULL, 'b(..)', 'X\\1Y', 'g'); ----- -NULL - -query T -SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') ----- -fooxx - -query T -SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') ----- -fooxx - -query TTT -select - regexp_replace(col, NULL, 'c'), - regexp_replace(col, 'a', NULL), - regexp_replace(col, 'a', 'c', NULL) -from (values ('a'), ('b')) as tbl(col); ----- -NULL NULL NULL -NULL NULL NULL - -# multiline string -query B -SELECT 'foo\nbar\nbaz' ~ 'bar'; ----- -true - -statement error -Error during planning: Cannot infer common argument type for regex operation List(Field { name: "item", data_type: Int64, nullable: true, dict_is_ordered: false, metadata -: {} }) ~ List(Field { name: "item", data_type: Int64, nullable: true, dict_is_ordered: false, metadata: {} }) -select [1,2] ~ [3]; - -query B -SELECT 'foo\nbar\nbaz' LIKE '%bar%'; ----- -true - -query B -SELECT NULL LIKE NULL; ----- -NULL - -query B -SELECT NULL iLIKE NULL; ----- -NULL - -query B -SELECT NULL not LIKE NULL; ----- -NULL - -query B -SELECT NULL not iLIKE NULL; ----- -NULL - -# regexp_count tests - -# regexp_count tests from postgresql -# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 - -query I -SELECT regexp_count('123123123123123', '(12)3'); ----- -5 - -query I -SELECT regexp_count('123123123123', '123', 1); ----- -4 - -query I -SELECT regexp_count('123123123123', '123', 3); ----- -3 - -query I -SELECT regexp_count('123123123123', '123', 33); ----- -0 - -query I -SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); ----- -0 - -query 
I -SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); ----- -4 - -statement error -External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based -SELECT regexp_count('123123123123', '123', 0); - -statement error -External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based -SELECT regexp_count('123123123123', '123', -3); - -statement error -External error: statement failed: DataFusion error: Arrow error: Compute error: regexp_count() does not support global flag -SELECT regexp_count('123123123123', '123', 1, 'g'); - -query I -SELECT regexp_count(str, '\w') from t; ----- -3 -3 -3 -3 -3 -4 -4 -10 -6 -4 -7 - -query I -SELECT regexp_count(str, '\w{2}', start) from t; ----- -1 -1 -1 -1 -0 -2 -1 -4 -1 -2 -3 - -query I -SELECT regexp_count(str, 'ab', 1, 'i') from t; ----- -1 -1 -1 -1 -1 -0 -0 -0 -0 -0 -0 - - -query I -SELECT regexp_count(str, pattern) from t; ----- -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start) from t; ----- -1 -1 -0 -0 -0 -0 -0 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start, flags) from t; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# test type coercion -query I -SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# test string views - -statement ok -CREATE TABLE t_stringview AS -SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t; - -query I -SELECT regexp_count(str, '\w') from t_stringview; ----- -3 -3 -3 -3 -3 -4 -4 -10 -6 -4 -7 - -query I -SELECT regexp_count(str, '\w{2}', start) from t_stringview; ----- -1 -1 -1 -1 -0 -2 -1 -4 -1 -2 -3 - -query I -SELECT regexp_count(str, 'ab', 1, 'i') from t_stringview; ----- -1 -1 -1 -1 -1 -0 -0 -0 -0 -0 -0 - - -query I -SELECT regexp_count(str, pattern) from t_stringview; ----- -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start) from t_stringview; ----- -1 -1 -0 -0 -0 -0 -0 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start, flags) from t_stringview; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# test type coercion -query I -SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t_stringview; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# NULL tests - -query I -SELECT regexp_count(NULL, NULL); ----- -0 - -query I -SELECT regexp_count(NULL, 'a'); ----- -0 - -query I -SELECT regexp_count('a', NULL); ----- -0 - -query I -SELECT regexp_count(NULL, NULL, NULL, NULL); ----- -0 - -statement ok -CREATE TABLE empty_table (str varchar, pattern varchar, start int, flags varchar); - -query I -SELECT regexp_count(str, pattern, start, flags) from empty_table; ----- - -statement ok -INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); - -query I -SELECT regexp_count(str, pattern, start, flags) from empty_table; ----- -0 -0 -0 -0 - -statement ok -drop table t; - -statement ok -create or replace table strings as values - ('FooBar'), - ('Foo'), - ('Foo'), - ('Bar'), - ('FooBar'), - ('Bar'), - ('Baz'); - -statement ok -create or replace table dict_table as -select arrow_cast(column1, 'Dictionary(Int32, Utf8)') as column1 -from strings; - -query T -select column1 from dict_table where column1 
LIKE '%oo%'; ----- -FooBar -Foo -Foo -FooBar - -query T -select column1 from dict_table where column1 NOT LIKE '%oo%'; ----- -Bar -Bar -Baz - -query T -select column1 from dict_table where column1 ILIKE '%oO%'; ----- -FooBar -Foo -Foo -FooBar - -query T -select column1 from dict_table where column1 NOT ILIKE '%oO%'; ----- -Bar -Bar -Baz - - -# plan should not cast the column, instead it should use the dictionary directly -query TT -explain select column1 from dict_table where column1 LIKE '%oo%'; ----- -logical_plan -01)Filter: dict_table.column1 LIKE Utf8("%oo%") -02)--TableScan: dict_table projection=[column1] -physical_plan -01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: column1@0 LIKE %oo% -03)----DataSourceExec: partitions=1, partition_sizes=[1] - -# Ensure casting / coercion works for all operators -# (there should be no casts to Utf8) -query TT -explain select - column1 LIKE '%oo%', - column1 NOT LIKE '%oo%', - column1 ILIKE '%oo%', - column1 NOT ILIKE '%oo%' -from dict_table; ----- -logical_plan -01)Projection: dict_table.column1 LIKE Utf8("%oo%"), dict_table.column1 NOT LIKE Utf8("%oo%"), dict_table.column1 ILIKE Utf8("%oo%"), dict_table.column1 NOT ILIKE Utf8("%oo%") -02)--TableScan: dict_table projection=[column1] -physical_plan -01)ProjectionExec: expr=[column1@0 LIKE %oo% as dict_table.column1 LIKE Utf8("%oo%"), column1@0 NOT LIKE %oo% as dict_table.column1 NOT LIKE Utf8("%oo%"), column1@0 ILIKE %oo% as dict_table.column1 ILIKE Utf8("%oo%"), column1@0 NOT ILIKE %oo% as dict_table.column1 NOT ILIKE Utf8("%oo%")] -02)--DataSourceExec: partitions=1, partition_sizes=[1] - -statement ok -drop table strings - -statement ok -drop table dict_table diff --git a/datafusion/sqllogictest/test_files/regexp/README.md b/datafusion/sqllogictest/test_files/regexp/README.md new file mode 100644 index 000000000000..7e5efc5b5ddf --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/README.md @@ -0,0 +1,59 @@ + + +# Regexp Test Files + +This directory contains test files for regular expression (regexp) functions in DataFusion. + +## Directory Structure + +``` +regexp/ + - init_data.slt.part // Shared test data for regexp functions + - regexp_like.slt // Tests for regexp_like function + - regexp_count.slt // Tests for regexp_count function + - regexp_match.slt // Tests for regexp_match function + - regexp_replace.slt // Tests for regexp_replace function +``` + +## Tested Functions + +1. `regexp_like`: Check if a string matches a regular expression +2. `regexp_count`: Count occurrences of a pattern in a string +3. `regexp_match`: Extract matching substrings +4. `regexp_replace`: Replace matched substrings + +## Test Data + +Test data is centralized in the `init_data.slt.part` file and imported into each test file using the `include` directive. This approach ensures: + +- Consistent test data across different regexp function tests +- Easy maintenance of test data +- Reduced duplication + +## Test Coverage + +Each test file covers: + +- Basic functionality +- Case-insensitive matching +- Null handling +- Start position tests +- Capture group handling +- Different string types (UTF-8, Unicode) diff --git a/datafusion/sqllogictest/test_files/regexp/init_data.slt.part b/datafusion/sqllogictest/test_files/regexp/init_data.slt.part new file mode 100644 index 000000000000..ed6fb0e872df --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/init_data.slt.part @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +create table regexp_test_data (str varchar, pattern varchar, start int, flags varchar) as values + (NULL, '^(a)', 1, 'i'), + ('abc', '^(a)', 1, 'i'), + ('ABC', '^(A).*', 1, 'i'), + ('aBc', '(b|d)', 1, 'i'), + ('AbC', '(B|D)', 2, null), + ('aBC', '^(b|c)', 3, null), + ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), + ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), + ('Düsseldorf','[\p{Letter}-]+', 3, null), + ('Москва', '[\p{L}-]+', 4, null), + ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), + ('إسرائيل', '^\p{Arabic}+$', 2, null); diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_count.slt b/datafusion/sqllogictest/test_files/regexp/regexp_count.slt new file mode 100644 index 000000000000..d842a1ee81df --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_count.slt @@ -0,0 +1,344 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Import common test data +include ./init_data.slt.part + +# regexp_count tests from postgresql +# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 + +query I +SELECT regexp_count('123123123123123', '(12)3'); +---- +5 + +query I +SELECT regexp_count('123123123123', '123', 1); +---- +4 + +query I +SELECT regexp_count('123123123123', '123', 3); +---- +3 + +query I +SELECT regexp_count('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); +---- +4 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', -3); + +statement error +External error: statement failed: DataFusion error: Arrow error: Compute error: regexp_count() does not support global flag +SELECT regexp_count('123123123123', '123', 1, 'g'); + +query I +SELECT regexp_count(str, '\w') from regexp_test_data; +---- +0 +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from regexp_test_data; +---- +0 +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from regexp_test_data; +---- +0 +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from regexp_test_data; +---- +0 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from regexp_test_data; +---- +0 +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from regexp_test_data; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from regexp_test_data; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test string views + +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM regexp_test_data; + +query I +SELECT regexp_count(str, '\w') from t_stringview; +---- +0 +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t_stringview; +---- +0 +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t_stringview; +---- +0 +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t_stringview; +---- +0 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t_stringview; +---- +0 +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t_stringview; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t_stringview; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# NULL tests + +query I +SELECT regexp_count(NULL, NULL); +---- +0 + +query I +SELECT regexp_count(NULL, 'a'); +---- +0 + +query I +SELECT regexp_count('a', NULL); +---- +0 + +query I +SELECT regexp_count(NULL, NULL, NULL, NULL); +---- +0 + +statement ok +CREATE TABLE 
empty_table (str varchar, pattern varchar, start int, flags varchar); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- + +statement ok +INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- +0 +0 +0 +0 + +statement ok +drop table t_stringview; + +statement ok +drop table empty_table; diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_like.slt b/datafusion/sqllogictest/test_files/regexp/regexp_like.slt new file mode 100644 index 000000000000..223ef22b9861 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_like.slt @@ -0,0 +1,279 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Import common test data +include ./init_data.slt.part + +query B +SELECT regexp_like(str, pattern, flags) FROM regexp_test_data; +---- +NULL +true +true +true +false +false +false +true +true +true +true +true + +query B +SELECT str ~ NULL FROM regexp_test_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query B +select str ~ right('foo', NULL) FROM regexp_test_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query B +select right('foo', NULL) !~ str FROM regexp_test_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query B +SELECT regexp_like('foobarbequebaz', ''); +---- +true + +query B +SELECT regexp_like('', ''); +---- +true + +query B +SELECT regexp_like('foobarbequebaz', '(bar)(beque)'); +---- +true + +query B +SELECT regexp_like('fooBarbeQuebaz', '(bar).*(que)', 'is'); +---- +true + +query B +SELECT regexp_like('foobarbequebaz', '(ba3r)(bequ34e)'); +---- +false + +query B +SELECT regexp_like('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); +---- +true + +query B +SELECT regexp_like('aaa-0', '.*-(\d)'); +---- +true + +query B +SELECT regexp_like('bb-1', '.*-(\d)'); +---- +true + +query B +SELECT regexp_like('aa', '.*-(\d)'); +---- +false + +query B +SELECT regexp_like(NULL, '.*-(\d)'); +---- +NULL + +query B +SELECT regexp_like('aaa-0', NULL); +---- +NULL + +query B +SELECT regexp_like(null, '.*-(\d)'); +---- +NULL + +query error Error during planning: regexp_like\(\) does not support the "global" option +SELECT regexp_like('bb-1', '.*-(\d)', 'g'); + +query error Error during planning: regexp_like\(\) does not support the "global" option +SELECT regexp_like('bb-1', '.*-(\d)', 'g'); + +query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) +SELECT regexp_like('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); + +# look-around is not supported and will just return false +query B +SELECT 
regexp_like('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); +---- +false + +query B +select regexp_like('aaa-555', '.*-(\d*)'); +---- +true + +# multiline string +query B +SELECT 'foo\nbar\nbaz' ~ 'bar'; +---- +true + +statement error +Error during planning: Cannot infer common argument type for regex operation List(Field { name: "item", data_type: Int64, nullable: true, metadata: {} }) ~ List(Field { name: "item", data_type: Int64, nullable: true, metadata: {} }) +select [1,2] ~ [3]; + +query B +SELECT 'foo\nbar\nbaz' LIKE '%bar%'; +---- +true + +query B +SELECT NULL LIKE NULL; +---- +NULL + +query B +SELECT NULL iLIKE NULL; +---- +NULL + +query B +SELECT NULL not LIKE NULL; +---- +NULL + +query B +SELECT NULL not iLIKE NULL; +---- +NULL + +statement ok +create or replace table strings as values + ('FooBar'), + ('Foo'), + ('Foo'), + ('Bar'), + ('FooBar'), + ('Bar'), + ('Baz'); + +statement ok +create or replace table dict_table as +select arrow_cast(column1, 'Dictionary(Int32, Utf8)') as column1 +from strings; + +query T +select column1 from dict_table where column1 LIKE '%oo%'; +---- +FooBar +Foo +Foo +FooBar + +query T +select column1 from dict_table where column1 NOT LIKE '%oo%'; +---- +Bar +Bar +Baz + +query T +select column1 from dict_table where column1 ILIKE '%oO%'; +---- +FooBar +Foo +Foo +FooBar + +query T +select column1 from dict_table where column1 NOT ILIKE '%oO%'; +---- +Bar +Bar +Baz + + +# plan should not cast the column, instead it should use the dictionary directly +query TT +explain select column1 from dict_table where column1 LIKE '%oo%'; +---- +logical_plan +01)Filter: dict_table.column1 LIKE Utf8("%oo%") +02)--TableScan: dict_table projection=[column1] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: column1@0 LIKE %oo% +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +# Ensure casting / coercion works for all operators +# (there should be no casts to Utf8) +query TT +explain select + column1 LIKE '%oo%', + column1 NOT LIKE '%oo%', + column1 ILIKE '%oo%', + column1 NOT ILIKE '%oo%' +from dict_table; +---- +logical_plan +01)Projection: dict_table.column1 LIKE Utf8("%oo%"), dict_table.column1 NOT LIKE Utf8("%oo%"), dict_table.column1 ILIKE Utf8("%oo%"), dict_table.column1 NOT ILIKE Utf8("%oo%") +02)--TableScan: dict_table projection=[column1] +physical_plan +01)ProjectionExec: expr=[column1@0 LIKE %oo% as dict_table.column1 LIKE Utf8("%oo%"), column1@0 NOT LIKE %oo% as dict_table.column1 NOT LIKE Utf8("%oo%"), column1@0 ILIKE %oo% as dict_table.column1 ILIKE Utf8("%oo%"), column1@0 NOT ILIKE %oo% as dict_table.column1 NOT ILIKE Utf8("%oo%")] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +drop table strings + +statement ok +drop table dict_table diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_match.slt b/datafusion/sqllogictest/test_files/regexp/regexp_match.slt new file mode 100644 index 000000000000..e79af4774aa2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_match.slt @@ -0,0 +1,201 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Import common test data +include ./init_data.slt.part + +query ? +SELECT regexp_match(str, pattern, flags) FROM regexp_test_data; +---- +NULL +[a] +[A] +[B] +NULL +NULL +NULL +[010] +[Düsseldorf] +[Москва] +[Köln] +[إسرائيل] + +# test string view +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM regexp_test_data; + +query ? +SELECT regexp_match(str, pattern, flags) FROM t_stringview; +---- +NULL +[a] +[A] +[B] +NULL +NULL +NULL +[010] +[Düsseldorf] +[Москва] +[Köln] +[إسرائيل] + +statement ok +DROP TABLE t_stringview; + +query ? +SELECT regexp_match('foobarbequebaz', ''); +---- +[] + +query ? +SELECT regexp_match('', ''); +---- +[] + +query ? +SELECT regexp_match('foobarbequebaz', '(bar)(beque)'); +---- +[bar, beque] + +query ? +SELECT regexp_match('fooBarb +eQuebaz', '(bar).*(que)', 'is'); +---- +[Bar, Que] + +query ? +SELECT regexp_match('foobarbequebaz', '(ba3r)(bequ34e)'); +---- +NULL + +query ? +SELECT regexp_match('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); +---- +[barbeque] + +query ? +SELECT regexp_match('aaa-0', '.*-(\d)'); +---- +[0] + +query ? +SELECT regexp_match('bb-1', '.*-(\d)'); +---- +[1] + +query ? +SELECT regexp_match('aa', '.*-(\d)'); +---- +NULL + +query ? +SELECT regexp_match(NULL, '.*-(\d)'); +---- +NULL + +query ? +SELECT regexp_match('aaa-0', NULL); +---- +NULL + +query ? +SELECT regexp_match(null, '.*-(\d)'); +---- +NULL + +query error Error during planning: regexp_match\(\) does not support the "global" option +SELECT regexp_match('bb-1', '.*-(\d)', 'g'); + +query error Error during planning: regexp_match\(\) does not support the "global" option +SELECT regexp_match('bb-1', '.*-(\d)', 'g'); + +query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) +SELECT regexp_match('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); + +# look-around is not supported and will just return null +query ? +SELECT regexp_match('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); +---- +NULL + +# ported test +query ? 
+SELECT regexp_match('aaa-555', '.*-(\d*)'); +---- +[555] + +query B +select 'abc' ~ null; +---- +NULL + +query B +select null ~ null; +---- +NULL + +query B +select null ~ 'abc'; +---- +NULL + +query B +select 'abc' ~* null; +---- +NULL + +query B +select null ~* null; +---- +NULL + +query B +select null ~* 'abc'; +---- +NULL + +query B +select 'abc' !~ null; +---- +NULL + +query B +select null !~ null; +---- +NULL + +query B +select null !~ 'abc'; +---- +NULL + +query B +select 'abc' !~* null; +---- +NULL + +query B +select null !~* null; +---- +NULL + +query B +select null !~* 'abc'; +---- +NULL diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt b/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt new file mode 100644 index 000000000000..a16801adcef7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Import common test data +include ./init_data.slt.part + +query T +SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM regexp_test_data; +---- +NULL +Xbc +X +aXc +AbC +aBC +4000 +X +X +X +X +X + +# test string view +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM regexp_test_data; + +query T +SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t_stringview; +---- +NULL +Xbc +X +aXc +AbC +aBC +4000 +X +X +X +X +X + +statement ok +DROP TABLE t_stringview; + +query T +SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'gi'); +---- +XXX + +query T +SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'i'); +---- +XabcABC + +query T +SELECT regexp_replace('foobarbaz', 'b..', 'X', 'g'); +---- +fooXX + +query T +SELECT regexp_replace('foobarbaz', 'b..', 'X'); +---- +fooXbaz + +query T +SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); +---- +fooXarYXazY + +query T +SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', NULL); +---- +NULL + +query T +SELECT regexp_replace('foobarbaz', 'b(..)', NULL, 'g'); +---- +NULL + +query T +SELECT regexp_replace('foobarbaz', NULL, 'X\\1Y', 'g'); +---- +NULL + +query T +SELECT regexp_replace('Thomas', '.[mN]a.', 'M'); +---- +ThM + +query T +SELECT regexp_replace(NULL, 'b(..)', 'X\\1Y', 'g'); +---- +NULL + +query T +SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') +---- +fooxx + +query TTT +select + regexp_replace(col, NULL, 'c'), + regexp_replace(col, 'a', NULL), + regexp_replace(col, 'a', 'c', NULL) +from (values ('a'), ('b')) as tbl(col); +---- +NULL NULL NULL +NULL NULL NULL diff --git 
a/datafusion/sqllogictest/test_files/repartition.slt b/datafusion/sqllogictest/test_files/repartition.slt index 70666346e2ca..29d20d10b671 100644 --- a/datafusion/sqllogictest/test_files/repartition.slt +++ b/datafusion/sqllogictest/test_files/repartition.slt @@ -46,8 +46,8 @@ physical_plan 01)AggregateExec: mode=FinalPartitioned, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----RepartitionExec: partitioning=Hash([column1@0], 4), input_partitions=4 -04)------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] -05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition/parquet_table/2.parquet]]}, projection=[column1, column2], file_type=parquet # disable round robin repartitioning diff --git a/datafusion/sqllogictest/test_files/simplify_expr.slt b/datafusion/sqllogictest/test_files/simplify_expr.slt index 9985ab49c2da..075ccafcfd2e 100644 --- a/datafusion/sqllogictest/test_files/simplify_expr.slt +++ b/datafusion/sqllogictest/test_files/simplify_expr.slt @@ -107,4 +107,3 @@ query B SELECT a / NULL::DECIMAL(4,3) > 1.2::decimal(2,1) FROM VALUES (1) AS t(a); ---- NULL - diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 162c9a17b61f..c17fe8dfc7e6 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -695,6 +695,144 @@ select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b ---- 51 54 +# RIGHTSEMI join tests + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b = t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b = t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 14 b + ) + select t2.* from t1 right semi join t2 on 
t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 14 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 12 b union all + select 11 a, 14 b + ), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + +# Test RIGHTSEMI with cross batch data distribution + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b union all + select 12 a, 14 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 14 b union all + select 12 a, 15 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 14 +12 15 + # return sql params back to default values statement ok set datafusion.optimizer.prefer_hash_join = true; diff --git a/datafusion/sqllogictest/test_files/spark/README.md b/datafusion/sqllogictest/test_files/spark/README.md new file mode 100644 index 000000000000..0a7bb92371b5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/README.md @@ -0,0 +1,57 @@ + + +# Spark Test Files + +This directory contains test files for the `spark` test suite. + +## Testing Guide + +When testing Spark functions: + +- Functions must be tested on both `Scalar` and `Array` inputs +- Test cases should only contain `SELECT` statements with the function being tested +- Add explicit casts to input values to ensure the correct data type is used (e.g., `0::INT`) + - Explicit casting is necessary because DataFusion and Spark do not infer data types in the same way + +### Finding Test Cases + +To verify and compare function behavior, refer to at least the following documentation sources: + +1. Databricks SQL Function Reference: + https://docs.databricks.com/aws/en/sql/language-manual/functions/NAME +2. Apache Spark SQL Function Reference: + https://spark.apache.org/docs/latest/api/sql/#NAME +3. PySpark SQL Function Reference: + https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.NAME.html + +**Note:** Replace `NAME` in each URL with the actual function name (e.g., for the `ASCII` function, use `ascii` instead +of `NAME`). + +### Scalar Example: + +```sql +SELECT expm1(0::INT); +``` + +### Array Example: + +```sql +SELECT expm1(a) FROM (VALUES (0::INT), (1::INT)) AS t(a); +``` diff --git a/datafusion/sqllogictest/test_files/spark/hash/sha2.slt b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt new file mode 100644 index 000000000000..e2341df164ba --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +query T +SELECT sha2('Spark', 0::INT); +---- +529BC3B07127ECB7E53A4DCF1991D9152C24537D919178022B2C42657F79A26B + +query T +SELECT sha2('Spark', 256::INT); +---- +529BC3B07127ECB7E53A4DCF1991D9152C24537D919178022B2C42657F79A26B + +query T +SELECT sha2('Spark', 224::INT); +---- +DBEAB94971678D36AF2195851C0F7485775A2A7C60073D62FC04549C + +query T +SELECT sha2('Spark', 384::INT); +---- +1E40B8D06C248A1CC32428C22582B6219D072283078FA140D9AD297ECADF2CABEFC341B857AD36226AA8D6D79F2AB67D + +query T +SELECT sha2('Spark', 512::INT); +---- +44844A586C54C9A212DA1DBFE05C5F1705DE1AF5FDA1F0D36297623249B279FD8F0CCEC03F888F4FB13BF7CD83FDAD58591C797F81121A23CFDD5E0897795238 + +query T +SELECT sha2(expr, 256::INT) FROM VALUES ('foo'), ('bar') AS t(expr); +---- +2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE +FCDE2B2EDBA56BF408601FB721FE9B5C338D10EE429EA04FAE5511B68FBF8FB9 + +query T +SELECT sha2('foo', bit_length) FROM VALUES (0::INT), (256::INT), (224::INT), (384::INT), (512::INT) AS t(bit_length); +---- +2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE +2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE +0808F64E60D58979FCB676C96EC938270DEA42445AEEFCD3A4E6F8DB +98C11FFDFDD540676B1A137CB1A22B2A70350C9A44171D6B1180C6BE5CBB2EE3F79D532C8A1DD9EF2E8E08E752A3BABB +F7FBBA6E0636F890E56FBBF3283E524C6FA3204AE298382D624741D0DC6638326E282C41BE5E4254D8820772C5518A2C5A8C0C7F7EDA19594A7EB539453E1ED7 + + +query T +SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('baz',384::INT), ('qux',512::INT) AS t(expr, bit_length); +---- +2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE +07DAF010DE7F7F0D8D76A76EB8D1EB40182C8D1E7A3877A6686C9BF0 +967004D25DE4ABC1BD6A7C9A216254A5AC0733E8AD96DC9F1EA0FAD9619DA7C32D654EC8AD8BA2F9B5728FED6633BD91 +8C6BE9ED448A34883A13A13F4EAD4AEFA036B67DCDA59020C01E57EA075EA8A4792D428F2C6FD0C09D1C49994D6C22789336E062188DF29572ED07E7F9779C52 diff --git a/datafusion/sqllogictest/test_files/spark/math/expm1.slt b/datafusion/sqllogictest/test_files/spark/math/expm1.slt new file mode 100644 index 000000000000..96d4abb0414b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/expm1.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +query R +SELECT expm1(0::INT); +---- +0 + +query R +SELECT expm1(1::INT); +---- +1.718281828459045 + +query R +SELECT expm1(a) FROM (VALUES (0::INT), (1::INT)) AS t(a); +---- +0 +1.718281828459045 diff --git a/datafusion/sqllogictest/test_files/spark/math/hex.slt b/datafusion/sqllogictest/test_files/spark/math/hex.slt new file mode 100644 index 000000000000..24db1a318358 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/hex.slt @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query T +SELECT hex('Spark SQL'); +---- +537061726B2053514C + +query T +SELECT hex(1234::INT); +---- +4D2 + +query T +SELECT hex(a) from VALUES (1234::INT), (NULL), (456::INT) AS t(a); +---- +4D2 +NULL +1C8 + +query T +SELECT hex(a) from VALUES ('foo'), (NULL), ('foobarbaz') AS t(a); +---- +666F6F +NULL +666F6F62617262617A diff --git a/datafusion/sqllogictest/test_files/spark/string/ascii.slt b/datafusion/sqllogictest/test_files/spark/string/ascii.slt new file mode 100644 index 000000000000..623154ffaa7b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/ascii.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +query I +SELECT ascii('234'); +---- +50 + +query I +SELECT ascii(''); +---- +0 + +query I +SELECT ascii('222'); +---- +50 + +query I +SELECT ascii('😀'); +---- +128512 + +query I +SELECT ascii(2::INT); +---- +50 + +query I +SELECT ascii(a) FROM (VALUES ('Spark'), ('PySpark'), ('Pandas API')) AS t(a); +---- +83 +80 +80 diff --git a/datafusion/sqllogictest/test_files/spark/string/char.slt b/datafusion/sqllogictest/test_files/spark/string/char.slt new file mode 100644 index 000000000000..d8fc11f6d512 Binary files /dev/null and b/datafusion/sqllogictest/test_files/spark/string/char.slt differ diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index bdba73876103..95eeffc31903 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -53,9 +53,9 @@ select * from struct_values; query TT select arrow_typeof(s1), arrow_typeof(s2) from struct_values; ---- -Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(c0 Int32) Struct(a Int32, b Utf8View) +Struct(c0 Int32) Struct(a Int32, b Utf8View) +Struct(c0 Int32) Struct(a Int32, b Utf8View) # struct[i] @@ -229,12 +229,12 @@ select named_struct('field_a', 1, 'field_b', 2); query T select arrow_typeof(named_struct('first', 1, 'second', 2, 'third', 3)); ---- -Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(first Int64, second Int64, third Int64) query T select arrow_typeof({'first': 1, 'second': 2, 'third': 3}); ---- -Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(first Int64, second Int64, third Int64) # test nested struct literal query ? 
@@ -271,12 +271,33 @@ select a from values where (a, c) = (1, 'a'); ---- 1 +query I +select a from values as v where (v.a, v.c) = (1, 'a'); +---- +1 + +query I +select a from values as v where (v.a, v.c) != (1, 'a'); +---- +2 +3 + +query I +select a from values as v where (v.a, v.c) = (1, 'b'); +---- + query I select a from values where (a, c) IN ((1, 'a'), (2, 'b')); ---- 1 2 +query I +select a from values as v where (v.a, v.c) IN ((1, 'a'), (2, 'b')); +---- +1 +2 + statement ok drop table values; @@ -392,7 +413,7 @@ create table t(a struct, b struct) as valu query T select arrow_typeof([a, b]) from t; ---- -List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) query ? select [a, b] from t; @@ -443,12 +464,12 @@ select * from t; query T select arrow_typeof(c1) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, b Int32) query T select arrow_typeof(c2) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, b Float32) statement ok drop table t; @@ -465,8 +486,8 @@ select * from t; query T select arrow_typeof(column1) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8, c Float64) +Struct(r Utf8, c Float64) statement ok drop table t; @@ -498,9 +519,9 @@ select coalesce(s1) from t; query T select arrow_typeof(coalesce(s1, s2)) from t; ---- -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) statement ok drop table t; @@ -525,9 +546,9 @@ select coalesce(s1, s2) from t; query T select arrow_typeof(coalesce(s1, 
s2)) from t; ---- -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) statement ok drop table t; @@ -562,7 +583,7 @@ create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as valu query T select arrow_typeof([a, b]) from t; ---- -List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) statement ok drop table t; @@ -585,13 +606,13 @@ create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, query T select arrow_typeof(a) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, c Int32, g Float32) # type of each column should not coerced but perserve as it is query T select arrow_typeof(b) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, c Float32, g Int32) statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index a0ac15b740d7..796570633f67 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -400,7 +400,7 @@ logical_plan 01)LeftSemi Join: 02)--TableScan: t1 projection=[t1_id, t1_name, t1_int] 03)--SubqueryAlias: __correlated_sq_1 -04)----Projection: +04)----Projection: 05)------Filter: t1.t1_int < t1.t1_id 06)--------TableScan: t1 projection=[t1_id, t1_int] @@ -1453,7 +1453,7 @@ logical_plan 01)LeftSemi Join: 02)--TableScan: t1 projection=[a] 03)--SubqueryAlias: __correlated_sq_1 -04)----Projection: +04)----Projection: 05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 06)--------TableScan: t2 projection=[] diff --git a/datafusion/sqllogictest/test_files/subquery_sort.slt 
b/datafusion/sqllogictest/test_files/subquery_sort.slt index 5d22bf92e7e6..d993515f4de9 100644 --- a/datafusion/sqllogictest/test_files/subquery_sort.slt +++ b/datafusion/sqllogictest/test_files/subquery_sort.slt @@ -100,7 +100,7 @@ physical_plan 01)ProjectionExec: expr=[c1@0 as c1, r@1 as r] 02)--SortExec: TopK(fetch=2), expr=[c1@0 ASC NULLS LAST, c3@2 ASC NULLS LAST, c9@3 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[c1@0 as c1, rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as r, c3@1 as c3, c9@2 as c9] -04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8View(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 05)--------SortExec: expr=[c1@0 DESC], preserve_partitioning=[false] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3, c9], file_type=csv, has_header=true @@ -127,9 +127,8 @@ physical_plan 02)--SortExec: TopK(fetch=2), expr=[c1@0 ASC NULLS LAST, c3@2 ASC NULLS LAST, c9@3 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[c1@0 as c1, rank() ORDER BY [sink_table_with_utf8view.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as r, c3@1 as c3, c9@2 as c9] 04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table_with_utf8view.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() ORDER BY [sink_table_with_utf8view.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8View(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------SortPreservingMergeExec: [c1@0 DESC] -06)----------SortExec: expr=[c1@0 DESC], preserve_partitioning=[true] -07)------------DataSourceExec: partitions=4, partition_sizes=[1, 0, 0, 0] +05)--------SortExec: expr=[c1@0 DESC], preserve_partitioning=[false] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] statement ok DROP TABLE sink_table_with_utf8view; diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part index fee496f92055..04de9153a047 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part @@ -65,8 +65,8 @@ logical_plan 12)--------------------Filter: orders.o_orderdate >= Date32("1993-10-01") AND orders.o_orderdate < Date32("1994-01-01") 13)----------------------TableScan: orders projection=[o_orderkey, o_custkey, 
o_orderdate], partial_filters=[orders.o_orderdate >= Date32("1993-10-01"), orders.o_orderdate < Date32("1994-01-01")] 14)--------------Projection: lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount -15)----------------Filter: lineitem.l_returnflag = Utf8("R") -16)------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], partial_filters=[lineitem.l_returnflag = Utf8("R")] +15)----------------Filter: lineitem.l_returnflag = Utf8View("R") +16)------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], partial_filters=[lineitem.l_returnflag = Utf8View("R")] 17)----------TableScan: nation projection=[n_nationkey, n_name] physical_plan 01)SortPreservingMergeExec: [revenue@2 DESC], fetch=10 diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part index 1dba8c053720..a6225daae436 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part @@ -58,8 +58,8 @@ logical_plan 09)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], partial_filters=[Boolean(true)] 10)----------------TableScan: supplier projection=[s_suppkey, s_nationkey] 11)------------Projection: nation.n_nationkey -12)--------------Filter: nation.n_name = Utf8("GERMANY") -13)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] +12)--------------Filter: nation.n_name = Utf8View("GERMANY") +13)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] 14)------SubqueryAlias: __scalar_sq_1 15)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) 16)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] @@ -70,8 +70,8 @@ logical_plan 21)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] 22)--------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 23)----------------Projection: nation.n_nationkey -24)------------------Filter: nation.n_name = Utf8("GERMANY") -25)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] +24)------------------Filter: nation.n_name = Utf8View("GERMANY") +25)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] physical_plan 01)SortExec: TopK(fetch=10), expr=[value@1 DESC], preserve_partitioning=[false] 02)--ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part index 3757fc48dba0..f7344daed8c7 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part @@ -51,12 +51,12 @@ order by logical_plan 01)Sort: lineitem.l_shipmode ASC NULLS LAST 02)--Projection: lineitem.l_shipmode, sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS high_line_count, sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND 
orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS low_line_count -03)----Aggregate: groupBy=[[lineitem.l_shipmode]], aggr=[[sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]] +03)----Aggregate: groupBy=[[lineitem.l_shipmode]], aggr=[[sum(CASE WHEN orders.o_orderpriority = Utf8View("1-URGENT") OR orders.o_orderpriority = Utf8View("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), sum(CASE WHEN orders.o_orderpriority != Utf8View("1-URGENT") AND orders.o_orderpriority != Utf8View("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]] 04)------Projection: lineitem.l_shipmode, orders.o_orderpriority 05)--------Inner Join: lineitem.l_orderkey = orders.o_orderkey 06)----------Projection: lineitem.l_orderkey, lineitem.l_shipmode -07)------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP")) AND lineitem.l_receiptdate > lineitem.l_commitdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("1994-01-01") AND lineitem.l_receiptdate < Date32("1995-01-01") -08)--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_receiptdate > lineitem.l_commitdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("1994-01-01"), lineitem.l_receiptdate < Date32("1995-01-01")] +07)------------Filter: (lineitem.l_shipmode = Utf8View("MAIL") OR lineitem.l_shipmode = Utf8View("SHIP")) AND lineitem.l_receiptdate > lineitem.l_commitdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("1994-01-01") AND lineitem.l_receiptdate < Date32("1995-01-01") +08)--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8View("MAIL") OR lineitem.l_shipmode = Utf8View("SHIP"), lineitem.l_receiptdate > lineitem.l_commitdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("1994-01-01"), lineitem.l_receiptdate < Date32("1995-01-01")] 09)----------TableScan: orders projection=[o_orderkey, o_orderpriority] physical_plan 01)SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part index e9d9cf141d10..96f3bd6edf32 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part @@ -50,8 +50,8 @@ logical_plan 08)--------------Left Join: customer.c_custkey = orders.o_custkey 09)----------------TableScan: customer projection=[c_custkey] 10)----------------Projection: orders.o_orderkey, orders.o_custkey -11)------------------Filter: orders.o_comment NOT LIKE Utf8("%special%requests%") -12)--------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8("%special%requests%")] 
+11)------------------Filter: orders.o_comment NOT LIKE Utf8View("%special%requests%") +12)--------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8View("%special%requests%")] physical_plan 01)SortPreservingMergeExec: [custdist@1 DESC, c_count@0 DESC], fetch=10 02)--SortExec: TopK(fetch=10), expr=[custdist@1 DESC, c_count@0 DESC], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part index 1104af2bdc64..8d8dd68c3d7b 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part @@ -33,7 +33,7 @@ where ---- logical_plan 01)Projection: Float64(100) * CAST(sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue -02)--Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN __common_expr_1 ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(__common_expr_1) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +02)--Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN part.p_type LIKE Utf8View("PROMO%") THEN __common_expr_1 ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(__common_expr_1) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] 03)----Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS __common_expr_1, part.p_type 04)------Inner Join: lineitem.l_partkey = part.p_partkey 05)--------Projection: lineitem.l_partkey, lineitem.l_extendedprice, lineitem.l_discount diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part index c648f164c809..cd2f407387ed 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part @@ -58,12 +58,12 @@ logical_plan 06)----------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size 07)------------Inner Join: partsupp.ps_partkey = part.p_partkey 08)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey] -09)--------------Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)]) -10)----------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8("Brand#45"), part.p_type NOT LIKE Utf8("MEDIUM POLISHED%"), part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])] +09)--------------Filter: part.p_brand != Utf8View("Brand#45") AND part.p_type NOT LIKE Utf8View("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)]) +10)----------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8View("Brand#45"), part.p_type NOT LIKE Utf8View("MEDIUM POLISHED%"), part.p_size IN 
([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])] 11)----------SubqueryAlias: __correlated_sq_1 12)------------Projection: supplier.s_suppkey -13)--------------Filter: supplier.s_comment LIKE Utf8("%Customer%Complaints%") -14)----------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")] +13)--------------Filter: supplier.s_comment LIKE Utf8View("%Customer%Complaints%") +14)----------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8View("%Customer%Complaints%")] physical_plan 01)SortPreservingMergeExec: [supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], fetch=10 02)--SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], preserve_partitioning=[true] @@ -88,7 +88,7 @@ physical_plan 21)----------------------------------CoalesceBatchesExec: target_batch_size=8192 22)------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 23)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }]) +24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49), field: Field { name: "49", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(14), field: Field { name: "14", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(23), field: Field { name: "23", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(45), field: Field { name: "45", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(19), field: Field { name: "19", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(3), field: Field { name: "3", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(36), field: Field { name: "36", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(9), field: Field { name: "9", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 25)------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 26)--------------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, has_header=false 27)--------------------------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part index 02553890bcf5..51a0d096428c 100644 --- 
a/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part @@ -44,8 +44,8 @@ logical_plan 06)----------Inner Join: lineitem.l_partkey = part.p_partkey 07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice] 08)------------Projection: part.p_partkey -09)--------------Filter: part.p_brand = Utf8("Brand#23") AND part.p_container = Utf8("MED BOX") -10)----------------TableScan: part projection=[p_partkey, p_brand, p_container], partial_filters=[part.p_brand = Utf8("Brand#23"), part.p_container = Utf8("MED BOX")] +09)--------------Filter: part.p_brand = Utf8View("Brand#23") AND part.p_container = Utf8View("MED BOX") +10)----------------TableScan: part projection=[p_partkey, p_brand, p_container], partial_filters=[part.p_brand = Utf8View("Brand#23"), part.p_container = Utf8View("MED BOX")] 11)--------SubqueryAlias: __scalar_sq_1 12)----------Projection: CAST(Float64(0.2) * CAST(avg(lineitem.l_quantity) AS Float64) AS Decimal128(30, 15)), lineitem.l_partkey 13)------------Aggregate: groupBy=[[lineitem.l_partkey]], aggr=[[avg(lineitem.l_quantity)]] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part index b0e5b2e904d0..ace2081eb18f 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part @@ -57,19 +57,19 @@ logical_plan 01)Projection: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue 02)--Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] 03)----Projection: lineitem.l_extendedprice, lineitem.l_discount -04)------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) +04)------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND 
lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) 05)--------Projection: lineitem.l_partkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount -06)----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") -07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG"), lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] -08)--------Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) -09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)] +06)----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG")) AND lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON") +07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG"), lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] +08)--------Filter: (part.p_brand = Utf8View("Brand#12") 
AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) +09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)] physical_plan 01)ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@0 as revenue] 02)--AggregateExec: mode=Final, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] +06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN ([Literal { value: Utf8View("SM CASE"), field: Field { name: "SM CASE", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM BOX"), field: Field { name: "SM BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PACK"), field: Field { name: "SM PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PKG"), field: Field { name: "SM PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN 
([Literal { value: Utf8View("MED BAG"), field: Field { name: "MED BAG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED BOX"), field: Field { name: "MED BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PKG"), field: Field { name: "MED PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PACK"), field: Field { name: "MED PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([Literal { value: Utf8View("LG CASE"), field: Field { name: "LG CASE", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG BOX"), field: Field { name: "LG BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PACK"), field: Field { name: "LG PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PKG"), field: Field { name: "LG PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 09)----------------CoalesceBatchesExec: target_batch_size=8192 @@ -78,6 +78,6 @@ physical_plan 12)------------CoalesceBatchesExec: target_batch_size=8192 13)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 14)----------------CoalesceBatchesExec: target_batch_size=8192 -15)------------------FilterExec: (p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND p_size@2 <= 15) AND p_size@2 >= 1 +15)------------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN ([Literal { value: Utf8View("SM CASE"), field: Field { name: "SM CASE", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM BOX"), field: Field { name: "SM BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PACK"), field: Field { name: "SM PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PKG"), field: Field { name: "SM PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([Literal { value: 
Utf8View("MED BAG"), field: Field { name: "MED BAG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED BOX"), field: Field { name: "MED BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PKG"), field: Field { name: "MED PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PACK"), field: Field { name: "MED PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([Literal { value: Utf8View("LG CASE"), field: Field { name: "LG CASE", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG BOX"), field: Field { name: "LG BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PACK"), field: Field { name: "LG PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PKG"), field: Field { name: "LG PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND p_size@2 <= 15) AND p_size@2 >= 1 16)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 17)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part index 2a8ee9f229b7..b2e0fb0cd1cc 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part @@ -75,14 +75,14 @@ logical_plan 10)------------------Projection: part.p_partkey, part.p_mfgr, partsupp.ps_suppkey, partsupp.ps_supplycost 11)--------------------Inner Join: part.p_partkey = partsupp.ps_partkey 12)----------------------Projection: part.p_partkey, part.p_mfgr -13)------------------------Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS") -14)--------------------------TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")] +13)------------------------Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8View("%BRASS") +14)--------------------------TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8View("%BRASS")] 15)----------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] 16)------------------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] 17)--------------TableScan: nation projection=[n_nationkey, n_name, n_regionkey] 18)----------Projection: region.r_regionkey -19)------------Filter: region.r_name = Utf8("EUROPE") -20)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] +19)------------Filter: region.r_name = Utf8View("EUROPE") +20)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("EUROPE")] 21)------SubqueryAlias: __scalar_sq_1 
22)--------Projection: min(partsupp.ps_supplycost), partsupp.ps_partkey 23)----------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[min(partsupp.ps_supplycost)]] @@ -96,8 +96,8 @@ logical_plan 31)------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 32)--------------------TableScan: nation projection=[n_nationkey, n_regionkey] 33)----------------Projection: region.r_regionkey -34)------------------Filter: region.r_name = Utf8("EUROPE") -35)--------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] +34)------------------Filter: region.r_name = Utf8View("EUROPE") +35)--------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("EUROPE")] physical_plan 01)SortPreservingMergeExec: [s_acctbal@0 DESC, n_name@2 ASC NULLS LAST, s_name@1 ASC NULLS LAST, p_partkey@3 ASC NULLS LAST], fetch=10 02)--SortExec: TopK(fetch=10), expr=[s_acctbal@0 DESC, n_name@2 ASC NULLS LAST, s_name@1 ASC NULLS LAST, p_partkey@3 ASC NULLS LAST], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part index 4844d5fae60b..0b994de411ea 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part @@ -63,8 +63,8 @@ logical_plan 05)--------Inner Join: supplier.s_nationkey = nation.n_nationkey 06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey] 07)----------Projection: nation.n_nationkey -08)------------Filter: nation.n_name = Utf8("CANADA") -09)--------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")] +08)------------Filter: nation.n_name = Utf8View("CANADA") +09)--------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("CANADA")] 10)------SubqueryAlias: __correlated_sq_2 11)--------Projection: partsupp.ps_suppkey 12)----------Inner Join: partsupp.ps_partkey = __scalar_sq_3.l_partkey, partsupp.ps_suppkey = __scalar_sq_3.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_3.Float64(0.5) * sum(lineitem.l_quantity) @@ -72,8 +72,8 @@ logical_plan 14)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] 15)--------------SubqueryAlias: __correlated_sq_1 16)----------------Projection: part.p_partkey -17)------------------Filter: part.p_name LIKE Utf8("forest%") -18)--------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] +17)------------------Filter: part.p_name LIKE Utf8View("forest%") +18)--------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8View("forest%")] 19)------------SubqueryAlias: __scalar_sq_3 20)--------------Projection: Float64(0.5) * CAST(sum(lineitem.l_quantity) AS Float64), lineitem.l_partkey, lineitem.l_suppkey 21)----------------Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[sum(lineitem.l_quantity)]] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part index bb3e884e27be..e52171524007 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part @@ -76,11 +76,11 @@ logical_plan 16)----------------------------Filter: 
lineitem.l_receiptdate > lineitem.l_commitdate 17)------------------------------TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], partial_filters=[lineitem.l_receiptdate > lineitem.l_commitdate] 18)--------------------Projection: orders.o_orderkey -19)----------------------Filter: orders.o_orderstatus = Utf8("F") -20)------------------------TableScan: orders projection=[o_orderkey, o_orderstatus], partial_filters=[orders.o_orderstatus = Utf8("F")] +19)----------------------Filter: orders.o_orderstatus = Utf8View("F") +20)------------------------TableScan: orders projection=[o_orderkey, o_orderstatus], partial_filters=[orders.o_orderstatus = Utf8View("F")] 21)----------------Projection: nation.n_nationkey -22)------------------Filter: nation.n_name = Utf8("SAUDI ARABIA") -23)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("SAUDI ARABIA")] +22)------------------Filter: nation.n_name = Utf8View("SAUDI ARABIA") +23)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("SAUDI ARABIA")] 24)------------SubqueryAlias: __correlated_sq_1 25)--------------SubqueryAlias: l2 26)----------------TableScan: lineitem projection=[l_orderkey, l_suppkey] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index 828bf967d8f4..6af91b4aaa42 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -90,7 +90,7 @@ physical_plan 14)--------------------------CoalesceBatchesExec: target_batch_size=8192 15)----------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 16)------------------------------CoalesceBatchesExec: target_batch_size=8192 -17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }]) +17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([Literal { value: Utf8View("13"), field: Field { name: "13", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("31"), field: Field { name: "31", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("23"), field: Field { name: "23", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("29"), field: Field { name: "29", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("30"), field: Field { name: "30", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("18"), field: Field { name: "18", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("17"), field: Field { name: "17", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 18)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 19)------------------------------------DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false 20)--------------------------CoalesceBatchesExec: target_batch_size=8192 @@ -100,6 +100,6 @@ physical_plan 24)----------------------CoalescePartitionsExec 25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] 26)--------------------------CoalesceBatchesExec: target_batch_size=8192 -27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }]), projection=[c_acctbal@1] +27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([Literal { value: Utf8View("13"), field: Field { name: "13", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("31"), field: Field { name: "31", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("23"), field: Field { name: "23", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("29"), field: Field { name: "29", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("30"), field: Field { name: "30", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("18"), field: Field { name: "18", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("17"), field: Field { name: "17", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]), projection=[c_acctbal@1] 28)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 29)--------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part index 2ad496ef26fd..d982ec32e954 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part @@ -50,8 +50,8 @@ logical_plan 06)----------Projection: orders.o_orderkey, orders.o_orderdate, orders.o_shippriority 07)------------Inner Join: customer.c_custkey = orders.o_custkey 08)--------------Projection: customer.c_custkey -09)----------------Filter: customer.c_mktsegment = Utf8("BUILDING") -10)------------------TableScan: customer projection=[c_custkey, c_mktsegment], partial_filters=[customer.c_mktsegment = Utf8("BUILDING")] +09)----------------Filter: customer.c_mktsegment = Utf8View("BUILDING") +10)------------------TableScan: customer projection=[c_custkey, c_mktsegment], partial_filters=[customer.c_mktsegment = Utf8View("BUILDING")] 11)--------------Filter: orders.o_orderdate < Date32("1995-03-15") 12)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate, o_shippriority], partial_filters=[orders.o_orderdate < Date32("1995-03-15")] 
13)----------Projection: lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part index f192f987b3ef..15636056b871 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part @@ -64,8 +64,8 @@ logical_plan 19)------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 20)--------------TableScan: nation projection=[n_nationkey, n_name, n_regionkey] 21)----------Projection: region.r_regionkey -22)------------Filter: region.r_name = Utf8("ASIA") -23)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("ASIA")] +22)------------Filter: region.r_name = Utf8View("ASIA") +23)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("ASIA")] physical_plan 01)SortPreservingMergeExec: [revenue@1 DESC] 02)--SortExec: expr=[revenue@1 DESC], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part index e03de9596fbe..291d56e43f2d 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part @@ -63,7 +63,7 @@ logical_plan 03)----Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[sum(shipping.volume)]] 04)------SubqueryAlias: shipping 05)--------Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, date_part(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume -06)----------Inner Join: customer.c_nationkey = n2.n_nationkey Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") +06)----------Inner Join: customer.c_nationkey = n2.n_nationkey Filter: n1.n_name = Utf8View("FRANCE") AND n2.n_name = Utf8View("GERMANY") OR n1.n_name = Utf8View("GERMANY") AND n2.n_name = Utf8View("FRANCE") 07)------------Projection: lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate, customer.c_nationkey, n1.n_name 08)--------------Inner Join: supplier.s_nationkey = n1.n_nationkey 09)----------------Projection: supplier.s_nationkey, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate, customer.c_nationkey @@ -78,11 +78,11 @@ logical_plan 18)------------------------TableScan: orders projection=[o_orderkey, o_custkey] 19)--------------------TableScan: customer projection=[c_custkey, c_nationkey] 20)----------------SubqueryAlias: n1 -21)------------------Filter: nation.n_name = Utf8("FRANCE") OR nation.n_name = Utf8("GERMANY") -22)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("FRANCE") OR nation.n_name = Utf8("GERMANY")] +21)------------------Filter: nation.n_name = Utf8View("FRANCE") OR nation.n_name = Utf8View("GERMANY") +22)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("FRANCE") OR nation.n_name = Utf8View("GERMANY")] 23)------------SubqueryAlias: n2 -24)--------------Filter: nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE") -25)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY") OR 
nation.n_name = Utf8("FRANCE")] +24)--------------Filter: nation.n_name = Utf8View("GERMANY") OR nation.n_name = Utf8View("FRANCE") +25)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY") OR nation.n_name = Utf8View("FRANCE")] physical_plan 01)SortPreservingMergeExec: [supp_nation@0 ASC NULLS LAST, cust_nation@1 ASC NULLS LAST, l_year@2 ASC NULLS LAST] 02)--SortExec: expr=[supp_nation@0 ASC NULLS LAST, cust_nation@1 ASC NULLS LAST, l_year@2 ASC NULLS LAST], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part index 88ceffd62ad3..50171c528db6 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part @@ -58,7 +58,7 @@ order by logical_plan 01)Sort: all_nations.o_year ASC NULLS LAST 02)--Projection: all_nations.o_year, CAST(CAST(sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END) AS Decimal128(12, 2)) / CAST(sum(all_nations.volume) AS Decimal128(12, 2)) AS Decimal128(15, 2)) AS mkt_share -03)----Aggregate: groupBy=[[all_nations.o_year]], aggr=[[sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), sum(all_nations.volume)]] +03)----Aggregate: groupBy=[[all_nations.o_year]], aggr=[[sum(CASE WHEN all_nations.nation = Utf8View("BRAZIL") THEN all_nations.volume ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), sum(all_nations.volume)]] 04)------SubqueryAlias: all_nations 05)--------Projection: date_part(Utf8("YEAR"), orders.o_orderdate) AS o_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume, n2.n_name AS nation 06)----------Inner Join: n1.n_regionkey = region.r_regionkey @@ -75,8 +75,8 @@ logical_plan 17)--------------------------------Projection: lineitem.l_orderkey, lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount 18)----------------------------------Inner Join: part.p_partkey = lineitem.l_partkey 19)------------------------------------Projection: part.p_partkey -20)--------------------------------------Filter: part.p_type = Utf8("ECONOMY ANODIZED STEEL") -21)----------------------------------------TableScan: part projection=[p_partkey, p_type], partial_filters=[part.p_type = Utf8("ECONOMY ANODIZED STEEL")] +20)--------------------------------------Filter: part.p_type = Utf8View("ECONOMY ANODIZED STEEL") +21)----------------------------------------TableScan: part projection=[p_partkey, p_type], partial_filters=[part.p_type = Utf8View("ECONOMY ANODIZED STEEL")] 22)------------------------------------TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount] 23)--------------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 24)----------------------------Filter: orders.o_orderdate >= Date32("1995-01-01") AND orders.o_orderdate <= Date32("1996-12-31") @@ -87,8 +87,8 @@ logical_plan 29)----------------SubqueryAlias: n2 30)------------------TableScan: nation projection=[n_nationkey, n_name] 31)------------Projection: region.r_regionkey -32)--------------Filter: region.r_name = Utf8("AMERICA") -33)----------------TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[region.r_name = Utf8("AMERICA")] +32)--------------Filter: region.r_name = Utf8View("AMERICA") +33)----------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("AMERICA")] physical_plan 01)SortPreservingMergeExec: [o_year@0 ASC NULLS LAST] 02)--SortExec: expr=[o_year@0 ASC NULLS LAST], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part index 8ccf967187d7..3b31c1bc2e8e 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part @@ -67,8 +67,8 @@ logical_plan 13)------------------------Projection: lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount 14)--------------------------Inner Join: part.p_partkey = lineitem.l_partkey 15)----------------------------Projection: part.p_partkey -16)------------------------------Filter: part.p_name LIKE Utf8("%green%") -17)--------------------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("%green%")] +16)------------------------------Filter: part.p_name LIKE Utf8View("%green%") +17)--------------------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8View("%green%")] 18)----------------------------TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount] 19)------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 20)--------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 356f1598bc0f..d549f555f9d8 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -489,11 +489,11 @@ logical_plan 04)------Limit: skip=0, fetch=3 05)--------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 06)----------SubqueryAlias: a -07)------------Projection: +07)------------Projection: 08)--------------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] 09)----------------Projection: aggregate_test_100.c1 -10)------------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") -11)--------------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] +10)------------------Filter: aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434") +11)--------------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] 12)----Projection: Int64(1) AS cnt 13)------Limit: skip=0, fetch=3 14)--------EmptyRelation @@ -829,10 +829,10 @@ ORDER BY c1 logical_plan 01)Sort: c1 ASC NULLS LAST 02)--Union -03)----Filter: aggregate_test_100.c1 = Utf8("a") -04)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8("a")] -05)----Filter: aggregate_test_100.c1 = Utf8("a") -06)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8("a")] +03)----Filter: aggregate_test_100.c1 = Utf8View("a") +04)------TableScan: aggregate_test_100 
projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8View("a")] +05)----Filter: aggregate_test_100.c1 = Utf8View("a") +06)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8View("a")] physical_plan 01)CoalescePartitionsExec 02)--UnionExec diff --git a/datafusion/sqllogictest/test_files/union_by_name.slt b/datafusion/sqllogictest/test_files/union_by_name.slt index 9572e6efc3e6..233885618f83 100644 --- a/datafusion/sqllogictest/test_files/union_by_name.slt +++ b/datafusion/sqllogictest/test_files/union_by_name.slt @@ -348,7 +348,7 @@ Schema { fields: [ Field { name: "x", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, @@ -356,7 +356,7 @@ Schema { }, Field { name: "y", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, @@ -364,7 +364,7 @@ Schema { }, Field { name: "z", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, @@ -387,7 +387,7 @@ Schema { fields: [ Field { name: "x", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, @@ -395,7 +395,7 @@ Schema { }, Field { name: "y", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, @@ -403,7 +403,7 @@ Schema { }, Field { name: "z", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt index 9c70b1011f58..74616490ab70 100644 --- a/datafusion/sqllogictest/test_files/union_function.slt +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. 
+# Note: union_table is registered via Rust code in the sqllogictest test harness +# because there is no way to create a union type in SQL today + ########## ## UNION DataType Tests ########## @@ -23,7 +26,8 @@ query ?I select union_column, union_extract(union_column, 'int') from union_table; ---- {int=1} 1 -{int=2} 2 +{string=bar} NULL +{int=3} 3 query error DataFusion error: Execution error: field bool not found on union select union_extract(union_column, 'bool') from union_table; @@ -45,3 +49,19 @@ select union_extract(union_column, 1) from union_table; query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 3 select union_extract(union_column, 'a', 'b') from union_table; + +query ?T +select union_column, union_tag(union_column) from union_table; +---- +{int=1} int +{string=bar} string +{int=3} int + +query error DataFusion error: Error during planning: 'union_tag' does not support zero arguments +select union_tag() from union_table; + +query error DataFusion error: Error during planning: The function 'union_tag' expected 1 arguments but received 2 +select union_tag(union_column, 'int') from union_table; + +query error DataFusion error: Execution error: union_tag only support unions, got Utf8 +select union_tag('int') from union_table; diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index b9c13582952a..92e6f9995ae3 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -91,12 +91,12 @@ select * from unnest(null); ## Unnest empty array in select list -query I +query ? select unnest([]); ---- ## Unnest empty array in from clause -query I +query ? select * from unnest([]); ---- @@ -243,7 +243,7 @@ query error DataFusion error: This feature is not implemented: unnest\(\) does n select unnest(null) from unnest_table; ## Multiple unnest functions in selection -query II +query ?I select unnest([]), unnest(NULL::int[]); ---- @@ -263,10 +263,10 @@ NULL 10 NULL NULL NULL 17 NULL NULL 18 -query IIIT -select - unnest(column1), unnest(column2) + 2, - column3 * 10, unnest(array_remove(column1, '4')) +query IIII +select + unnest(column1), unnest(column2) + 2, + column3 * 10, unnest(array_remove(column1, 4)) from unnest_table; ---- 1 9 10 1 @@ -316,7 +316,7 @@ select * from unnest( 2 b NULL NULL NULL c NULL NULL -query II +query ?I select * from unnest([], NULL::int[]); ---- diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index 908d2b34aea4..9f2c16b21106 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -31,7 +31,7 @@ explain update t1 set a=1, b=2, c=3.0, d=NULL; ---- logical_plan 01)Dml: op=[Update] table=[t1] -02)--Projection: CAST(Int64(1) AS Int32) AS a, CAST(Int64(2) AS Utf8) AS b, Float64(3) AS c, CAST(NULL AS Int32) AS d +02)--Projection: CAST(Int64(1) AS Int32) AS a, CAST(Int64(2) AS Utf8View) AS b, Float64(3) AS c, CAST(NULL AS Int32) AS d 03)----TableScan: t1 physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Update) @@ -40,7 +40,7 @@ explain update t1 set a=c+1, b=a, c=c+1.0, d=b; ---- logical_plan 01)Dml: op=[Update] table=[t1] -02)--Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8) AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d +02)--Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8View) 
AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d 03)----TableScan: t1 physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Update) @@ -69,7 +69,7 @@ explain update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1 logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d -03)----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) +03)----Filter: t1.a = t2.a AND t1.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) 04)------Cross Join: 05)--------TableScan: t1 06)--------TableScan: t2 @@ -89,7 +89,7 @@ explain update t1 as T set b = t2.b, c = t.a, d = 1 from t2 where t.a = t2.a and logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d -03)----Filter: t.a = t2.a AND t.b > Utf8("foo") AND t2.c > Float64(1) +03)----Filter: t.a = t2.a AND t.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) 04)------Cross Join: 05)--------SubqueryAlias: t 06)----------TableScan: t1 diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 52cc80eae1c8..c86921012f9b 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -1767,11 +1767,11 @@ logical_plan 01)Projection: count(Int64(1)) AS count(*) AS global_count 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----SubqueryAlias: a -04)------Projection: +04)------Projection: 05)--------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] 06)----------Projection: aggregate_test_100.c1 -07)------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") -08)--------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] +07)------------Filter: aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434") +08)--------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as global_count] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index d23e986914fc..e4ca7bc46c80 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -39,7 +39,7 @@ itertools = { workspace = true } object_store = { workspace = true } pbjson-types = { workspace = true } prost = { workspace = true } -substrait = { version = "0.55", features = ["serde"] } +substrait = { version = "0.56", features = ["serde"] } url = { workspace = true } tokio = { workspace = true, features = ["fs"] } diff --git a/datafusion/substrait/README.md b/datafusion/substrait/README.md index 92bb9abcc690..8e7f99b7df38 100644 --- a/datafusion/substrait/README.md +++ b/datafusion/substrait/README.md @@ -19,8 +19,9 @@ # Apache DataFusion Substrait -This crate contains a [Substrait] producer and consumer for Apache Arrow -[DataFusion] plans. See [API Docs] for details and examples. +This crate contains a [Substrait] producer and consumer for [Apache DataFusion] +plans. See [API Docs] for details and examples. 
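As a rough illustration of the producer and consumer mentioned in the README wording above (a hedged sketch, not part of this change set): the crate's top-level entry points can be used along these lines. The function names and signatures shown, `to_substrait_plan` and `from_substrait_plan` taking a `SessionState`, are assumptions based on the datafusion-substrait API docs and may differ between versions; check the published docs for the release in use.

```rust
// Hypothetical round-trip sketch (not part of this change). Function names and
// signatures are assumptions and may differ between datafusion-substrait versions.
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use datafusion_substrait::logical_plan::producer::to_substrait_plan;

async fn round_trip_sql(ctx: &SessionContext, sql: &str) -> Result<()> {
    // Plan a query with DataFusion and take its optimized logical plan.
    let logical_plan = ctx.sql(sql).await?.into_optimized_plan()?;
    let state = ctx.state();

    // Produce a substrait::proto::Plan from the DataFusion logical plan ...
    let substrait_plan = to_substrait_plan(&logical_plan, &state)?;
    // ... and consume it again to get a DataFusion logical plan back.
    let round_tripped = from_substrait_plan(&state, &substrait_plan).await?;

    println!("{}", round_tripped.display_indent());
    Ok(())
}
```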
[substrait]: https://substrait.io +[apache datafusion]: https://datafusion.apache.org [api docs]: https://docs.rs/datafusion-substrait/latest diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs deleted file mode 100644 index 1442267d3dbb..000000000000 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ /dev/null @@ -1,3452 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; -use arrow::buffer::OffsetBuffer; -use async_recursion::async_recursion; -use datafusion::arrow::array::MapArray; -use datafusion::arrow::datatypes::{ - DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, -}; -use datafusion::common::{ - not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, - substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, Spans, - TableReference, -}; -use datafusion::datasource::provider_as_source; -use datafusion::logical_expr::expr::{Exists, InSubquery, Sort, WindowFunctionParams}; - -use datafusion::logical_expr::{ - Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension, - LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values, -}; -use substrait::proto::aggregate_rel::Grouping; -use substrait::proto::expression as substrait_expression; -use substrait::proto::expression::subquery::set_predicate::PredicateOp; -use substrait::proto::expression_reference::ExprType; -use url::Url; - -use crate::extensions::Extensions; -use crate::variation_const::{ - DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, - DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, - VIEW_CONTAINER_TYPE_VARIATION_REF, -}; -#[allow(deprecated)] -use crate::variation_const::{ - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, - INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, - TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, - TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, -}; -use async_trait::async_trait; -use datafusion::arrow::array::{new_empty_array, AsArray}; -use datafusion::arrow::temporal_conversions::NANOSECONDS; -use datafusion::catalog::TableProvider; -use datafusion::common::scalar::ScalarStructBuilder; -use datafusion::execution::{FunctionRegistry, SessionState}; -use datafusion::logical_expr::builder::project; -use datafusion::logical_expr::expr::InList; -use datafusion::logical_expr::{ - col, expr, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition, - WindowFrameBound, WindowFrameUnits, 
WindowFunctionDefinition, -}; -use datafusion::prelude::{lit, JoinType}; -use datafusion::{ - arrow, error::Result, logical_expr::utils::split_conjunction, - logical_expr::utils::split_conjunction_owned, prelude::Column, scalar::ScalarValue, -}; -use std::collections::HashSet; -use std::sync::Arc; -use substrait::proto; -use substrait::proto::exchange_rel::ExchangeKind; -use substrait::proto::expression::cast::FailureBehavior::ReturnNull; -use substrait::proto::expression::literal::user_defined::Val; -use substrait::proto::expression::literal::{ - interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, -}; -use substrait::proto::expression::subquery::SubqueryType; -use substrait::proto::expression::{ - Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction, - SingularOrList, SwitchExpression, WindowFunction, -}; -use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; -use substrait::proto::rel_common::{Emit, EmitKind}; -use substrait::proto::set_rel::SetOp; -use substrait::proto::{ - aggregate_function::AggregationInvocation, - expression::{ - field_reference::ReferenceType::DirectReference, literal::LiteralType, - reference_segment::ReferenceType::StructField, - window_function::bound as SubstraitBound, - window_function::bound::Kind as BoundKind, window_function::Bound, - window_function::BoundsType, MaskExpression, RexType, - }, - fetch_rel, - function_argument::ArgType, - join_rel, plan_rel, r#type, - read_rel::ReadType, - rel::RelType, - rel_common, - sort_field::{SortDirection, SortKind::*}, - AggregateFunction, AggregateRel, ConsistentPartitionWindowRel, CrossRel, - DynamicParameter, ExchangeRel, Expression, ExtendedExpression, ExtensionLeafRel, - ExtensionMultiRel, ExtensionSingleRel, FetchRel, FilterRel, FunctionArgument, - JoinRel, NamedStruct, Plan, ProjectRel, ReadRel, Rel, RelCommon, SetRel, SortField, - SortRel, Type, -}; - -#[async_trait] -/// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. -/// It can be implemented by users to allow for custom handling of relations, expressions, etc. -/// -/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully -/// customizable Substrait serde. 
-/// -/// # Example Usage -/// -/// ``` -/// # use async_trait::async_trait; -/// # use datafusion::catalog::TableProvider; -/// # use datafusion::common::{not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference}; -/// # use datafusion::error::Result; -/// # use datafusion::execution::{FunctionRegistry, SessionState}; -/// # use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; -/// # use std::sync::Arc; -/// # use substrait::proto; -/// # use substrait::proto::{ExtensionLeafRel, FilterRel, ProjectRel}; -/// # use datafusion::arrow::datatypes::DataType; -/// # use datafusion::logical_expr::expr::ScalarFunction; -/// # use datafusion_substrait::extensions::Extensions; -/// # use datafusion_substrait::logical_plan::consumer::{ -/// # from_project_rel, from_substrait_rel, from_substrait_rex, SubstraitConsumer -/// # }; -/// -/// struct CustomSubstraitConsumer { -/// extensions: Arc, -/// state: Arc, -/// } -/// -/// #[async_trait] -/// impl SubstraitConsumer for CustomSubstraitConsumer { -/// async fn resolve_table_ref( -/// &self, -/// table_ref: &TableReference, -/// ) -> Result>> { -/// let table = table_ref.table().to_string(); -/// let schema = self.state.schema_for_ref(table_ref.clone())?; -/// let table_provider = schema.table(&table).await?; -/// Ok(table_provider) -/// } -/// -/// fn get_extensions(&self) -> &Extensions { -/// self.extensions.as_ref() -/// } -/// -/// fn get_function_registry(&self) -> &impl FunctionRegistry { -/// self.state.as_ref() -/// } -/// -/// // You can reuse existing consumer code to assist in handling advanced extensions -/// async fn consume_project(&self, rel: &ProjectRel) -> Result { -/// let df_plan = from_project_rel(self, rel).await?; -/// if let Some(advanced_extension) = rel.advanced_extension.as_ref() { -/// not_impl_err!( -/// "decode and handle an advanced extension: {:?}", -/// advanced_extension -/// ) -/// } else { -/// Ok(df_plan) -/// } -/// } -/// -/// // You can implement a fully custom consumer method if you need special handling -/// async fn consume_filter(&self, rel: &FilterRel) -> Result { -/// let input = self.consume_rel(rel.input.as_ref().unwrap()).await?; -/// let expression = -/// self.consume_expression(rel.condition.as_ref().unwrap(), input.schema()) -/// .await?; -/// // though this one is quite boring -/// LogicalPlanBuilder::from(input).filter(expression)?.build() -/// } -/// -/// // You can add handlers for extension relations -/// async fn consume_extension_leaf( -/// &self, -/// rel: &ExtensionLeafRel, -/// ) -> Result { -/// not_impl_err!( -/// "handle protobuf Any {} as you need", -/// rel.detail.as_ref().unwrap().type_url -/// ) -/// } -/// -/// // and handlers for user-define types -/// fn consume_user_defined_type(&self, typ: &proto::r#type::UserDefined) -> Result { -/// let type_string = self.extensions.types.get(&typ.type_reference).unwrap(); -/// match type_string.as_str() { -/// "u!foo" => not_impl_err!("handle foo conversion"), -/// "u!bar" => not_impl_err!("handle bar conversion"), -/// _ => substrait_err!("unexpected type") -/// } -/// } -/// -/// // and user-defined literals -/// fn consume_user_defined_literal(&self, literal: &proto::expression::literal::UserDefined) -> Result { -/// let type_string = self.extensions.types.get(&literal.type_reference).unwrap(); -/// match type_string.as_str() { -/// "u!foo" => not_impl_err!("handle foo conversion"), -/// "u!bar" => not_impl_err!("handle bar conversion"), -/// _ => substrait_err!("unexpected type") -/// } -/// } -/// } -/// ``` 
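The doc example above stops at defining the consumer. The fragment below is a hedged sketch of how such a consumer might then be driven through `from_substrait_plan_with_consumer` (defined later in this file); the `extensions`, `state`, and `substrait_plan` bindings are assumed to already exist and `std::sync::Arc` to be imported.

```rust
// Sketch only: drive the CustomSubstraitConsumer from the doc example above.
// `extensions`, `state`, and `substrait_plan` are assumed to be in scope.
let consumer = CustomSubstraitConsumer {
    extensions: Arc::new(extensions),
    state: Arc::new(state),
};
let logical_plan =
    from_substrait_plan_with_consumer(&consumer, &substrait_plan).await?;
println!("{}", logical_plan.display_indent());
```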
-/// -pub trait SubstraitConsumer: Send + Sync + Sized { - async fn resolve_table_ref( - &self, - table_ref: &TableReference, - ) -> Result>>; - - // TODO: Remove these two methods - // Ideally, the abstract consumer should not place any constraints on implementations. - // The functionality for which the Extensions and FunctionRegistry is needed should be abstracted - // out into methods on the trait. As an example, resolve_table_reference is such a method. - // See: https://github.com/apache/datafusion/issues/13863 - fn get_extensions(&self) -> &Extensions; - fn get_function_registry(&self) -> &impl FunctionRegistry; - - // Relation Methods - // There is one method per Substrait relation to allow for easy overriding of consumer behaviour. - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. - - /// All [Rel]s to be converted pass through this method. - /// You can provide your own implementation if you wish to customize the conversion behaviour. - async fn consume_rel(&self, rel: &Rel) -> Result { - from_substrait_rel(self, rel).await - } - - async fn consume_read(&self, rel: &ReadRel) -> Result { - from_read_rel(self, rel).await - } - - async fn consume_filter(&self, rel: &FilterRel) -> Result { - from_filter_rel(self, rel).await - } - - async fn consume_fetch(&self, rel: &FetchRel) -> Result { - from_fetch_rel(self, rel).await - } - - async fn consume_aggregate(&self, rel: &AggregateRel) -> Result { - from_aggregate_rel(self, rel).await - } - - async fn consume_sort(&self, rel: &SortRel) -> Result { - from_sort_rel(self, rel).await - } - - async fn consume_join(&self, rel: &JoinRel) -> Result { - from_join_rel(self, rel).await - } - - async fn consume_project(&self, rel: &ProjectRel) -> Result { - from_project_rel(self, rel).await - } - - async fn consume_set(&self, rel: &SetRel) -> Result { - from_set_rel(self, rel).await - } - - async fn consume_cross(&self, rel: &CrossRel) -> Result { - from_cross_rel(self, rel).await - } - - async fn consume_consistent_partition_window( - &self, - _rel: &ConsistentPartitionWindowRel, - ) -> Result { - not_impl_err!("Consistent Partition Window Rel not supported") - } - - async fn consume_exchange(&self, rel: &ExchangeRel) -> Result { - from_exchange_rel(self, rel).await - } - - // Expression Methods - // There is one method per Substrait expression to allow for easy overriding of consumer behaviour - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. - - /// All [Expression]s to be converted pass through this method. - /// You can provide your own implementation if you wish to customize the conversion behaviour. 
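As a concrete, hypothetical illustration of that hook: a custom consumer could override `consume_expression` to observe every expression before falling back to the shared `from_substrait_rex` logic. This is a sketch of just the overridden method, not a complete `impl` block, and the logging is only for demonstration.

```rust
// Hypothetical override inside an `impl SubstraitConsumer for MyConsumer` block:
// inspect each expression, then delegate to the default conversion path.
async fn consume_expression(
    &self,
    expr: &Expression,
    input_schema: &DFSchema,
) -> Result<Expr> {
    println!("consuming substrait expression: {expr:?}");
    from_substrait_rex(self, expr, input_schema).await
}
```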
- async fn consume_expression( - &self, - expr: &Expression, - input_schema: &DFSchema, - ) -> Result { - from_substrait_rex(self, expr, input_schema).await - } - - async fn consume_literal(&self, expr: &Literal) -> Result { - from_literal(self, expr).await - } - - async fn consume_field_reference( - &self, - expr: &FieldReference, - input_schema: &DFSchema, - ) -> Result { - from_field_reference(self, expr, input_schema).await - } - - async fn consume_scalar_function( - &self, - expr: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - from_scalar_function(self, expr, input_schema).await - } - - async fn consume_window_function( - &self, - expr: &WindowFunction, - input_schema: &DFSchema, - ) -> Result { - from_window_function(self, expr, input_schema).await - } - - async fn consume_if_then( - &self, - expr: &IfThen, - input_schema: &DFSchema, - ) -> Result { - from_if_then(self, expr, input_schema).await - } - - async fn consume_switch( - &self, - _expr: &SwitchExpression, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Switch expression not supported") - } - - async fn consume_singular_or_list( - &self, - expr: &SingularOrList, - input_schema: &DFSchema, - ) -> Result { - from_singular_or_list(self, expr, input_schema).await - } - - async fn consume_multi_or_list( - &self, - _expr: &MultiOrList, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Multi Or List expression not supported") - } - - async fn consume_cast( - &self, - expr: &substrait_expression::Cast, - input_schema: &DFSchema, - ) -> Result { - from_cast(self, expr, input_schema).await - } - - async fn consume_subquery( - &self, - expr: &substrait_expression::Subquery, - input_schema: &DFSchema, - ) -> Result { - from_subquery(self, expr, input_schema).await - } - - async fn consume_nested( - &self, - _expr: &Nested, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Nested expression not supported") - } - - async fn consume_enum(&self, _expr: &Enum, _input_schema: &DFSchema) -> Result { - not_impl_err!("Enum expression not supported") - } - - async fn consume_dynamic_parameter( - &self, - _expr: &DynamicParameter, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Dynamic Parameter expression not supported") - } - - // User-Defined Functionality - - // The details of extension relations, and how to handle them, are fully up to users to specify. 
- // The following methods allow users to customize the consumer behaviour - - async fn consume_extension_leaf( - &self, - rel: &ExtensionLeafRel, - ) -> Result { - if let Some(detail) = rel.detail.as_ref() { - return substrait_err!( - "Missing handler for ExtensionLeafRel: {}", - detail.type_url - ); - } - substrait_err!("Missing handler for ExtensionLeafRel") - } - - async fn consume_extension_single( - &self, - rel: &ExtensionSingleRel, - ) -> Result { - if let Some(detail) = rel.detail.as_ref() { - return substrait_err!( - "Missing handler for ExtensionSingleRel: {}", - detail.type_url - ); - } - substrait_err!("Missing handler for ExtensionSingleRel") - } - - async fn consume_extension_multi( - &self, - rel: &ExtensionMultiRel, - ) -> Result { - if let Some(detail) = rel.detail.as_ref() { - return substrait_err!( - "Missing handler for ExtensionMultiRel: {}", - detail.type_url - ); - } - substrait_err!("Missing handler for ExtensionMultiRel") - } - - // Users can bring their own types to Substrait which require custom handling - - fn consume_user_defined_type( - &self, - user_defined_type: &r#type::UserDefined, - ) -> Result { - substrait_err!( - "Missing handler for user-defined type: {}", - user_defined_type.type_reference - ) - } - - fn consume_user_defined_literal( - &self, - user_defined_literal: &proto::expression::literal::UserDefined, - ) -> Result { - substrait_err!( - "Missing handler for user-defined literals {}", - user_defined_literal.type_reference - ) - } -} - -/// Convert Substrait Rel to DataFusion DataFrame -#[async_recursion] -pub async fn from_substrait_rel( - consumer: &impl SubstraitConsumer, - relation: &Rel, -) -> Result { - let plan: Result = match &relation.rel_type { - Some(rel_type) => match rel_type { - RelType::Read(rel) => consumer.consume_read(rel).await, - RelType::Filter(rel) => consumer.consume_filter(rel).await, - RelType::Fetch(rel) => consumer.consume_fetch(rel).await, - RelType::Aggregate(rel) => consumer.consume_aggregate(rel).await, - RelType::Sort(rel) => consumer.consume_sort(rel).await, - RelType::Join(rel) => consumer.consume_join(rel).await, - RelType::Project(rel) => consumer.consume_project(rel).await, - RelType::Set(rel) => consumer.consume_set(rel).await, - RelType::ExtensionSingle(rel) => consumer.consume_extension_single(rel).await, - RelType::ExtensionMulti(rel) => consumer.consume_extension_multi(rel).await, - RelType::ExtensionLeaf(rel) => consumer.consume_extension_leaf(rel).await, - RelType::Cross(rel) => consumer.consume_cross(rel).await, - RelType::Window(rel) => { - consumer.consume_consistent_partition_window(rel).await - } - RelType::Exchange(rel) => consumer.consume_exchange(rel).await, - rt => not_impl_err!("{rt:?} rel not supported yet"), - }, - None => return substrait_err!("rel must set rel_type"), - }; - apply_emit_kind(retrieve_rel_common(relation), plan?) -} - -/// Default SubstraitConsumer for converting standard Substrait without user-defined extensions. 
-/// -/// Used as the consumer in [from_substrait_plan] -pub struct DefaultSubstraitConsumer<'a> { - extensions: &'a Extensions, - state: &'a SessionState, -} - -impl<'a> DefaultSubstraitConsumer<'a> { - pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self { - DefaultSubstraitConsumer { extensions, state } - } -} - -#[async_trait] -impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { - async fn resolve_table_ref( - &self, - table_ref: &TableReference, - ) -> Result>> { - let table = table_ref.table().to_string(); - let schema = self.state.schema_for_ref(table_ref.clone())?; - let table_provider = schema.table(&table).await?; - Ok(table_provider) - } - - fn get_extensions(&self) -> &Extensions { - self.extensions - } - - fn get_function_registry(&self) -> &impl FunctionRegistry { - self.state - } - - async fn consume_extension_leaf( - &self, - rel: &ExtensionLeafRel, - ) -> Result { - let Some(ext_detail) = &rel.detail else { - return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); - }; - let plan = self - .state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) - } - - async fn consume_extension_single( - &self, - rel: &ExtensionSingleRel, - ) -> Result { - let Some(ext_detail) = &rel.detail else { - return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); - }; - let plan = self - .state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - let Some(input_rel) = &rel.input else { - return substrait_err!( - "ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead" - ); - }; - let input_plan = self.consume_rel(input_rel).await?; - let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; - Ok(LogicalPlan::Extension(Extension { node: plan })) - } - - async fn consume_extension_multi( - &self, - rel: &ExtensionMultiRel, - ) -> Result { - let Some(ext_detail) = &rel.detail else { - return substrait_err!("Unexpected empty detail in ExtensionMultiRel"); - }; - let plan = self - .state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - let mut inputs = Vec::with_capacity(rel.inputs.len()); - for input in &rel.inputs { - let input_plan = self.consume_rel(input).await?; - inputs.push(input_plan); - } - let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) - } -} - -// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which -// is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone -// results in correct points on the timeline, and we pick UTC as a reasonable default. -// However, DF uses the timezone also for some arithmetic and display purposes (see e.g. -// https://github.com/apache/arrow-rs/blob/ee5694078c86c8201549654246900a4232d531a9/arrow-cast/src/cast/mod.rs#L1749). 
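To make the timezone comment above concrete, the fragment below shows, as an illustration rather than part of this file, what the mapped Arrow/DataFusion type looks like when a timezone-aware Substrait timestamp is given the default "UTC" timezone.

```rust
// Illustration: a timezone-aware Substrait timestamp becomes an Arrow
// timestamp type carrying an explicit timezone string ("UTC" by default here).
use arrow::datatypes::{DataType, TimeUnit};

fn main() {
    let mapped = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
    // A timestamp without a timezone is a distinct, naive type.
    let naive = DataType::Timestamp(TimeUnit::Microsecond, None);
    assert_ne!(mapped, naive);
    println!("{mapped:?} vs {naive:?}");
}
```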
-const DEFAULT_TIMEZONE: &str = "UTC"; - -pub fn name_to_op(name: &str) -> Option { - match name { - "equal" => Some(Operator::Eq), - "not_equal" => Some(Operator::NotEq), - "lt" => Some(Operator::Lt), - "lte" => Some(Operator::LtEq), - "gt" => Some(Operator::Gt), - "gte" => Some(Operator::GtEq), - "add" => Some(Operator::Plus), - "subtract" => Some(Operator::Minus), - "multiply" => Some(Operator::Multiply), - "divide" => Some(Operator::Divide), - "mod" => Some(Operator::Modulo), - "modulus" => Some(Operator::Modulo), - "and" => Some(Operator::And), - "or" => Some(Operator::Or), - "is_distinct_from" => Some(Operator::IsDistinctFrom), - "is_not_distinct_from" => Some(Operator::IsNotDistinctFrom), - "regex_match" => Some(Operator::RegexMatch), - "regex_imatch" => Some(Operator::RegexIMatch), - "regex_not_match" => Some(Operator::RegexNotMatch), - "regex_not_imatch" => Some(Operator::RegexNotIMatch), - "bitwise_and" => Some(Operator::BitwiseAnd), - "bitwise_or" => Some(Operator::BitwiseOr), - "str_concat" => Some(Operator::StringConcat), - "at_arrow" => Some(Operator::AtArrow), - "arrow_at" => Some(Operator::ArrowAt), - "bitwise_xor" => Some(Operator::BitwiseXor), - "bitwise_shift_right" => Some(Operator::BitwiseShiftRight), - "bitwise_shift_left" => Some(Operator::BitwiseShiftLeft), - _ => None, - } -} - -pub fn substrait_fun_name(name: &str) -> &str { - let name = match name.rsplit_once(':') { - // Since 0.32.0, Substrait requires the function names to be in a compound format - // https://substrait.io/extensions/#function-signature-compound-names - // for example, `add:i8_i8`. - // On the consumer side, we don't really care about the signature though, just the name. - Some((name, _)) => name, - None => name, - }; - name -} - -fn split_eq_and_noneq_join_predicate_with_nulls_equality( - filter: &Expr, -) -> (Vec<(Column, Column)>, bool, Option) { - let exprs = split_conjunction(filter); - - let mut accum_join_keys: Vec<(Column, Column)> = vec![]; - let mut accum_filters: Vec = vec![]; - let mut nulls_equal_nulls = false; - - for expr in exprs { - #[allow(clippy::collapsible_match)] - match expr { - Expr::BinaryExpr(binary_expr) => match binary_expr { - x @ (BinaryExpr { - left, - op: Operator::Eq, - right, - } - | BinaryExpr { - left, - op: Operator::IsNotDistinctFrom, - right, - }) => { - nulls_equal_nulls = match x.op { - Operator::Eq => false, - Operator::IsNotDistinctFrom => true, - _ => unreachable!(), - }; - - match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => { - accum_join_keys.push((l.clone(), r.clone())); - } - _ => accum_filters.push(expr.clone()), - } - } - _ => accum_filters.push(expr.clone()), - }, - _ => accum_filters.push(expr.clone()), - } - } - - let join_filter = accum_filters.into_iter().reduce(Expr::and); - (accum_join_keys, nulls_equal_nulls, join_filter) -} - -async fn union_rels( - consumer: &impl SubstraitConsumer, - rels: &[Rel], - is_all: bool, -) -> Result { - let mut union_builder = Ok(LogicalPlanBuilder::from( - consumer.consume_rel(&rels[0]).await?, - )); - for input in &rels[1..] { - let rel_plan = consumer.consume_rel(input).await?; - - union_builder = if is_all { - union_builder?.union(rel_plan) - } else { - union_builder?.union_distinct(rel_plan) - }; - } - union_builder?.build() -} - -async fn intersect_rels( - consumer: &impl SubstraitConsumer, - rels: &[Rel], - is_all: bool, -) -> Result { - let mut rel = consumer.consume_rel(&rels[0]).await?; - - for input in &rels[1..] 
{ - rel = LogicalPlanBuilder::intersect( - rel, - consumer.consume_rel(input).await?, - is_all, - )? - } - - Ok(rel) -} - -async fn except_rels( - consumer: &impl SubstraitConsumer, - rels: &[Rel], - is_all: bool, -) -> Result { - let mut rel = consumer.consume_rel(&rels[0]).await?; - - for input in &rels[1..] { - rel = LogicalPlanBuilder::except(rel, consumer.consume_rel(input).await?, is_all)? - } - - Ok(rel) -} - -/// Convert Substrait Plan to DataFusion LogicalPlan -pub async fn from_substrait_plan( - state: &SessionState, - plan: &Plan, -) -> Result { - // Register function extension - let extensions = Extensions::try_from(&plan.extensions)?; - if !extensions.type_variations.is_empty() { - return not_impl_err!("Type variation extensions are not supported"); - } - - let consumer = DefaultSubstraitConsumer { - extensions: &extensions, - state, - }; - from_substrait_plan_with_consumer(&consumer, plan).await -} - -/// Convert Substrait Plan to DataFusion LogicalPlan using the given consumer -pub async fn from_substrait_plan_with_consumer( - consumer: &impl SubstraitConsumer, - plan: &Plan, -) -> Result { - match plan.relations.len() { - 1 => { - match plan.relations[0].rel_type.as_ref() { - Some(rt) => match rt { - plan_rel::RelType::Rel(rel) => Ok(consumer.consume_rel(rel).await?), - plan_rel::RelType::Root(root) => { - let plan = consumer.consume_rel(root.input.as_ref().unwrap()).await?; - if root.names.is_empty() { - // Backwards compatibility for plans missing names - return Ok(plan); - } - let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?; - if renamed_schema.has_equivalent_names_and_types(plan.schema()).is_ok() { - // Nothing to do if the schema is already equivalent - return Ok(plan); - } - match plan { - // If the last node of the plan produces expressions, bake the renames into those expressions. - // This isn't necessary for correctness, but helps with roundtrip tests. - LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)), - LogicalPlan::Aggregate(a) => { - let (group_fields, expr_fields) = renamed_schema.fields().split_at(a.group_expr.len()); - let new_group_exprs = rename_expressions(a.group_expr, a.input.schema(), group_fields)?; - let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), expr_fields)?; - Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?)) - }, - // There are probably more plans where we could bake things in, can add them later as needed. - // Otherwise, add a new Project to handle the renaming. - _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?)) - } - } - }, - None => plan_err!("Cannot parse plan relation: None") - } - }, - _ => not_impl_err!( - "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}", - plan.relations.len() - ) - } -} - -/// An ExprContainer is a container for a collection of expressions with a common input schema -/// -/// In addition, each expression is associated with a field, which defines the -/// expression's output. The data type and nullability of the field are calculated from the -/// expression and the input schema. However the names of the field (and its nested fields) are -/// derived from the Substrait message. 
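As a usage sketch of the container described above (assuming a `SessionState` named `state` and a decoded `ExtendedExpression` named `extended_expr` are already available), the container could be obtained and inspected roughly as follows via the `from_substrait_extended_expr` function defined below.

```rust
// Sketch only: decode an ExtendedExpression into an ExprContainer and list
// each expression together with the output name Substrait assigned to it.
let container = from_substrait_extended_expr(&state, &extended_expr).await?;
println!("input schema: {}", container.input_schema);
for (expr, field) in &container.exprs {
    println!("{}: {expr}", field.name());
}
```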
-pub struct ExprContainer { - /// The input schema for the expressions - pub input_schema: DFSchemaRef, - /// The expressions - /// - /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output - pub exprs: Vec<(Expr, Field)>, -} - -/// Convert Substrait ExtendedExpression to ExprContainer -/// -/// A Substrait ExtendedExpression message contains one or more expressions, -/// with names for the outputs, and an input schema. These pieces are all included -/// in the ExprContainer. -/// -/// This is a top-level message and can be used to send expressions (not plans) -/// between systems. This is often useful for scenarios like pushdown where filter -/// expressions need to be sent to remote systems. -pub async fn from_substrait_extended_expr( - state: &SessionState, - extended_expr: &ExtendedExpression, -) -> Result { - // Register function extension - let extensions = Extensions::try_from(&extended_expr.extensions)?; - if !extensions.type_variations.is_empty() { - return not_impl_err!("Type variation extensions are not supported"); - } - - let consumer = DefaultSubstraitConsumer { - extensions: &extensions, - state, - }; - - let input_schema = DFSchemaRef::new(match &extended_expr.base_schema { - Some(base_schema) => from_substrait_named_struct(&consumer, base_schema), - None => { - plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message") - } - }?); - - // Parse expressions - let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len()); - for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() { - let scalar_expr = match &substrait_expr.expr_type { - Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr), - Some(ExprType::Measure(_)) => { - not_impl_err!("Measure expressions are not yet supported") - } - None => { - plan_err!("required property `expr_type` missing from Substrait ExpressionReference message") - } - }?; - let expr = consumer - .consume_expression(scalar_expr, &input_schema) - .await?; - let (output_type, expected_nullability) = - expr.data_type_and_nullable(&input_schema)?; - let output_field = Field::new("", output_type, expected_nullability); - let mut names_idx = 0; - let output_field = rename_field( - &output_field, - &substrait_expr.output_names, - expr_idx, - &mut names_idx, - /*rename_self=*/ true, - )?; - exprs.push((expr, output_field)); - } - - Ok(ExprContainer { - input_schema, - exprs, - }) -} - -pub fn apply_masking( - schema: DFSchema, - mask_expression: &::core::option::Option, -) -> Result { - match mask_expression { - Some(MaskExpression { select, .. }) => match &select.as_ref() { - Some(projection) => { - let column_indices: Vec = projection - .struct_items - .iter() - .map(|item| item.field as usize) - .collect(); - - let fields = column_indices - .iter() - .map(|i| schema.qualified_field(*i)) - .map(|(qualifier, field)| { - (qualifier.cloned(), Arc::new(field.clone())) - }) - .collect(); - - Ok(DFSchema::new_with_metadata( - fields, - schema.metadata().clone(), - )?) - } - None => Ok(schema), - }, - None => Ok(schema), - } -} - -/// Ensure the expressions have the right name(s) according to the new schema. -/// This includes the top-level (column) name, which will be renamed through aliasing if needed, -/// as well as nested names (if the expression produces any struct types), which will be renamed -/// through casting if needed. 
-fn rename_expressions( - exprs: impl IntoIterator, - input_schema: &DFSchema, - new_schema_fields: &[Arc], -) -> Result> { - exprs - .into_iter() - .zip(new_schema_fields) - .map(|(old_expr, new_field)| { - // Check if type (i.e. nested struct field names) match, use Cast to rename if needed - let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() { - Expr::Cast(Cast::new( - Box::new(old_expr), - new_field.data_type().to_owned(), - )) - } else { - old_expr - }; - // Alias column if needed to fix the top-level name - match &new_expr { - // If expr is a column reference, alias_if_changed would cause an aliasing if the old expr has a qualifier - Expr::Column(c) if &c.name == new_field.name() => Ok(new_expr), - _ => new_expr.alias_if_changed(new_field.name().to_owned()), - } - }) - .collect() -} - -fn rename_field( - field: &Field, - dfs_names: &Vec, - unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}" - name_idx: &mut usize, // Index into dfs_names - rename_self: bool, // Some fields (e.g. list items) don't have names in Substrait and this will be false to keep old name -) -> Result { - let name = if rename_self { - next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)? - } else { - field.name().to_string() - }; - match field.data_type() { - DataType::Struct(children) => { - let children = children - .iter() - .enumerate() - .map(|(child_idx, f)| { - rename_field( - f.as_ref(), - dfs_names, - child_idx, - name_idx, - /*rename_self=*/ true, - ) - }) - .collect::>()?; - Ok(field - .to_owned() - .with_name(name) - .with_data_type(DataType::Struct(children))) - } - DataType::List(inner) => { - let renamed_inner = rename_field( - inner.as_ref(), - dfs_names, - 0, - name_idx, - /*rename_self=*/ false, - )?; - Ok(field - .to_owned() - .with_data_type(DataType::List(FieldRef::new(renamed_inner))) - .with_name(name)) - } - DataType::LargeList(inner) => { - let renamed_inner = rename_field( - inner.as_ref(), - dfs_names, - 0, - name_idx, - /*rename_self= */ false, - )?; - Ok(field - .to_owned() - .with_data_type(DataType::LargeList(FieldRef::new(renamed_inner))) - .with_name(name)) - } - _ => Ok(field.to_owned().with_name(name)), - } -} - -/// Produce a version of the given schema with names matching the given list of names. -/// Substrait doesn't deal with column (incl. nested struct field) names within the schema, -/// but it does give us the list of expected names at the end of the plan, so we use this -/// to rename the schema to match the expected names. -fn make_renamed_schema( - schema: &DFSchemaRef, - dfs_names: &Vec, -) -> Result { - let mut name_idx = 0; - - let (qualifiers, fields): (_, Vec) = schema - .iter() - .enumerate() - .map(|(field_idx, (q, f))| { - let renamed_f = rename_field( - f.as_ref(), - dfs_names, - field_idx, - &mut name_idx, - /*rename_self=*/ true, - )?; - Ok((q.cloned(), renamed_f)) - }) - .collect::>>()? 
- .into_iter() - .unzip(); - - if name_idx != dfs_names.len() { - return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - dfs_names.len()); - } - - DFSchema::from_field_specific_qualified_schema( - qualifiers, - &Arc::new(Schema::new(fields)), - ) -} - -#[async_recursion] -pub async fn from_project_rel( - consumer: &impl SubstraitConsumer, - p: &ProjectRel, -) -> Result { - if let Some(input) = p.input.as_ref() { - let input = consumer.consume_rel(input).await?; - let original_schema = Arc::clone(input.schema()); - - // Ensure that all expressions have a unique display name, so that - // validate_unique_names does not fail when constructing the project. - let mut name_tracker = NameTracker::new(); - - // By default, a Substrait Project emits all inputs fields followed by all expressions. - // We build the explicit expressions first, and then the input expressions to avoid - // adding aliases to the explicit expressions (as part of ensuring unique names). - // - // This is helpful for plan visualization and tests, because when DataFusion produces - // Substrait Projects it adds an output mapping that excludes all input columns - // leaving only explicit expressions. - - let mut explicit_exprs: Vec = vec![]; - // For WindowFunctions, we need to wrap them in a Window relation. If there are duplicates, - // we can do the window'ing only once, then the project will duplicate the result. - // Order here doesn't matter since LPB::window_plan sorts the expressions. - let mut window_exprs: HashSet = HashSet::new(); - for expr in &p.expressions { - let e = consumer - .consume_expression(expr, input.clone().schema()) - .await?; - // if the expression is WindowFunction, wrap in a Window relation - if let Expr::WindowFunction(_) = &e { - // Adding the same expression here and in the project below - // works because the project's builder uses columnize_expr(..) - // to transform it into a column reference - window_exprs.insert(e.clone()); - } - explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); - } - - let input = if !window_exprs.is_empty() { - LogicalPlanBuilder::window_plan(input, window_exprs)? 
- } else { - input - }; - - let mut final_exprs: Vec = vec![]; - for index in 0..original_schema.fields().len() { - let e = Expr::Column(Column::from(original_schema.qualified_field(index))); - final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); - } - final_exprs.append(&mut explicit_exprs); - project(input, final_exprs) - } else { - not_impl_err!("Projection without an input is not supported") - } -} - -#[async_recursion] -pub async fn from_filter_rel( - consumer: &impl SubstraitConsumer, - filter: &FilterRel, -) -> Result { - if let Some(input) = filter.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - if let Some(condition) = filter.condition.as_ref() { - let expr = consumer - .consume_expression(condition, input.schema()) - .await?; - input.filter(expr)?.build() - } else { - not_impl_err!("Filter without an condition is not valid") - } - } else { - not_impl_err!("Filter without an input is not valid") - } -} - -#[async_recursion] -pub async fn from_fetch_rel( - consumer: &impl SubstraitConsumer, - fetch: &FetchRel, -) -> Result { - if let Some(input) = fetch.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - let empty_schema = DFSchemaRef::new(DFSchema::empty()); - let offset = match &fetch.offset_mode { - Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), - Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => { - Some(consumer.consume_expression(expr, &empty_schema).await?) - } - None => None, - }; - let count = match &fetch.count_mode { - Some(fetch_rel::CountMode::Count(count)) => { - // -1 means that ALL records should be returned, equivalent to None - (*count != -1).then(|| lit(*count)) - } - Some(fetch_rel::CountMode::CountExpr(expr)) => { - Some(consumer.consume_expression(expr, &empty_schema).await?) - } - None => None, - }; - input.limit_by_expr(offset, count)?.build() - } else { - not_impl_err!("Fetch without an input is not valid") - } -} - -pub async fn from_sort_rel( - consumer: &impl SubstraitConsumer, - sort: &SortRel, -) -> Result { - if let Some(input) = sort.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - let sorts = from_substrait_sorts(consumer, &sort.sorts, input.schema()).await?; - input.sort(sorts)?.build() - } else { - not_impl_err!("Sort without an input is not valid") - } -} - -pub async fn from_aggregate_rel( - consumer: &impl SubstraitConsumer, - agg: &AggregateRel, -) -> Result { - if let Some(input) = agg.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - let mut ref_group_exprs = vec![]; - - for e in &agg.grouping_expressions { - let x = consumer.consume_expression(e, input.schema()).await?; - ref_group_exprs.push(x); - } - - let mut group_exprs = vec![]; - let mut aggr_exprs = vec![]; - - match agg.groupings.len() { - 1 => { - group_exprs.extend_from_slice( - &from_substrait_grouping( - consumer, - &agg.groupings[0], - &ref_group_exprs, - input.schema(), - ) - .await?, - ); - } - _ => { - let mut grouping_sets = vec![]; - for grouping in &agg.groupings { - let grouping_set = from_substrait_grouping( - consumer, - grouping, - &ref_group_exprs, - input.schema(), - ) - .await?; - grouping_sets.push(grouping_set); - } - // Single-element grouping expression of type Expr::GroupingSet. 
- // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when - // parsed by the producer and consumer, since Substrait does not have a type dedicated - // to ROLLUP. Only vector of Groupings (grouping sets) is available. - group_exprs - .push(Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets))); - } - }; - - for m in &agg.measures { - let filter = match &m.filter { - Some(fil) => Some(Box::new( - consumer.consume_expression(fil, input.schema()).await?, - )), - None => None, - }; - let agg_func = match &m.measure { - Some(f) => { - let distinct = match f.invocation { - _ if f.invocation == AggregationInvocation::Distinct as i32 => { - true - } - _ if f.invocation == AggregationInvocation::All as i32 => false, - _ => false, - }; - let order_by = if !f.sorts.is_empty() { - Some( - from_substrait_sorts(consumer, &f.sorts, input.schema()) - .await?, - ) - } else { - None - }; - - from_substrait_agg_func( - consumer, - f, - input.schema(), - filter, - order_by, - distinct, - ) - .await - } - None => { - not_impl_err!("Aggregate without aggregate function is not supported") - } - }; - aggr_exprs.push(agg_func?.as_ref().clone()); - } - input.aggregate(group_exprs, aggr_exprs)?.build() - } else { - not_impl_err!("Aggregate without an input is not valid") - } -} - -pub async fn from_join_rel( - consumer: &impl SubstraitConsumer, - join: &JoinRel, -) -> Result { - if join.post_join_filter.is_some() { - return not_impl_err!("JoinRel with post_join_filter is not yet supported"); - } - - let left: LogicalPlanBuilder = LogicalPlanBuilder::from( - consumer.consume_rel(join.left.as_ref().unwrap()).await?, - ); - let right = LogicalPlanBuilder::from( - consumer.consume_rel(join.right.as_ref().unwrap()).await?, - ); - let (left, right) = requalify_sides_if_needed(left, right)?; - - let join_type = from_substrait_jointype(join.r#type)?; - // The join condition expression needs full input schema and not the output schema from join since we lose columns from - // certain join types such as semi and anti joins - let in_join_schema = left.schema().join(right.schema())?; - - // If join expression exists, parse the `on` condition expression, build join and return - // Otherwise, build join with only the filter, without join keys - match &join.expression.as_ref() { - Some(expr) => { - let on = consumer.consume_expression(expr, &in_join_schema).await?; - // The join expression can contain both equal and non-equal ops. - // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. - // So we extract each part as follows: - // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector - // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) - let (join_ons, nulls_equal_nulls, join_filter) = - split_eq_and_noneq_join_predicate_with_nulls_equality(&on); - let (left_cols, right_cols): (Vec<_>, Vec<_>) = - itertools::multiunzip(join_ons); - left.join_detailed( - right.build()?, - join_type, - (left_cols, right_cols), - join_filter, - nulls_equal_nulls, - )? - .build() - } - None => { - let on: Vec = vec![]; - left.join_detailed(right.build()?, join_type, (on.clone(), on), None, false)? 
- .build() - } - } -} - -pub async fn from_cross_rel( - consumer: &impl SubstraitConsumer, - cross: &CrossRel, -) -> Result { - let left = LogicalPlanBuilder::from( - consumer.consume_rel(cross.left.as_ref().unwrap()).await?, - ); - let right = LogicalPlanBuilder::from( - consumer.consume_rel(cross.right.as_ref().unwrap()).await?, - ); - let (left, right) = requalify_sides_if_needed(left, right)?; - left.cross_join(right.build()?)?.build() -} - -#[allow(deprecated)] -pub async fn from_read_rel( - consumer: &impl SubstraitConsumer, - read: &ReadRel, -) -> Result { - async fn read_with_schema( - consumer: &impl SubstraitConsumer, - table_ref: TableReference, - schema: DFSchema, - projection: &Option, - filter: &Option>, - ) -> Result { - let schema = schema.replace_qualifier(table_ref.clone()); - - let filters = if let Some(f) = filter { - let filter_expr = consumer.consume_expression(f, &schema).await?; - split_conjunction_owned(filter_expr) - } else { - vec![] - }; - - let plan = { - let provider = match consumer.resolve_table_ref(&table_ref).await? { - Some(ref provider) => Arc::clone(provider), - _ => return plan_err!("No table named '{table_ref}'"), - }; - - LogicalPlanBuilder::scan_with_filters( - table_ref, - provider_as_source(Arc::clone(&provider)), - None, - filters, - )? - .build()? - }; - - ensure_schema_compatibility(plan.schema(), schema.clone())?; - - let schema = apply_masking(schema, projection)?; - - apply_projection(plan, schema) - } - - let named_struct = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Read Relation") - })?; - - let substrait_schema = from_substrait_named_struct(consumer, named_struct)?; - - match &read.read_type { - Some(ReadType::NamedTable(nt)) => { - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, - }; - - read_with_schema( - consumer, - table_reference, - substrait_schema, - &read.projection, - &read.filter, - ) - .await - } - Some(ReadType::VirtualTable(vt)) => { - if vt.values.is_empty() { - return Ok(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: DFSchemaRef::new(substrait_schema), - })); - } - - let values = vt - .values - .iter() - .map(|row| { - let mut name_idx = 0; - let lits = row - .fields - .iter() - .map(|lit| { - name_idx += 1; // top-level names are provided through schema - Ok(Expr::Literal(from_substrait_literal( - consumer, - lit, - &named_struct.names, - &mut name_idx, - )?)) - }) - .collect::>()?; - if name_idx != named_struct.names.len() { - return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - named_struct.names.len() - ); - } - Ok(lits) - }) - .collect::>()?; - - Ok(LogicalPlan::Values(Values { - schema: DFSchemaRef::new(substrait_schema), - values, - })) - } - Some(ReadType::LocalFiles(lf)) => { - fn extract_filename(name: &str) -> Option { - let corrected_url = - if name.starts_with("file://") && !name.starts_with("file:///") { - name.replacen("file://", "file:///", 1) - } else { - name.to_string() - }; - - Url::parse(&corrected_url).ok().and_then(|url| { - let path = url.path(); - 
std::path::Path::new(path) - .file_name() - .map(|filename| filename.to_string_lossy().to_string()) - }) - } - - // we could use the file name to check the original table provider - // TODO: currently does not support multiple local files - let filename: Option = - lf.items.first().and_then(|x| match x.path_type.as_ref() { - Some(UriFile(name)) => extract_filename(name), - _ => None, - }); - - if lf.items.len() > 1 || filename.is_none() { - return not_impl_err!("Only single file reads are supported"); - } - let name = filename.unwrap(); - // directly use unwrap here since we could determine it is a valid one - let table_reference = TableReference::Bare { table: name.into() }; - - read_with_schema( - consumer, - table_reference, - substrait_schema, - &read.projection, - &read.filter, - ) - .await - } - _ => { - not_impl_err!("Unsupported ReadType: {:?}", read.read_type) - } - } -} - -pub async fn from_set_rel( - consumer: &impl SubstraitConsumer, - set: &SetRel, -) -> Result { - if set.inputs.len() < 2 { - substrait_err!("Set operation requires at least two inputs") - } else { - match set.op() { - SetOp::UnionAll => union_rels(consumer, &set.inputs, true).await, - SetOp::UnionDistinct => union_rels(consumer, &set.inputs, false).await, - SetOp::IntersectionPrimary => LogicalPlanBuilder::intersect( - consumer.consume_rel(&set.inputs[0]).await?, - union_rels(consumer, &set.inputs[1..], true).await?, - false, - ), - SetOp::IntersectionMultiset => { - intersect_rels(consumer, &set.inputs, false).await - } - SetOp::IntersectionMultisetAll => { - intersect_rels(consumer, &set.inputs, true).await - } - SetOp::MinusPrimary => except_rels(consumer, &set.inputs, false).await, - SetOp::MinusPrimaryAll => except_rels(consumer, &set.inputs, true).await, - set_op => not_impl_err!("Unsupported set operator: {set_op:?}"), - } - } -} - -pub async fn from_exchange_rel( - consumer: &impl SubstraitConsumer, - exchange: &ExchangeRel, -) -> Result { - let Some(input) = exchange.input.as_ref() else { - return substrait_err!("Unexpected empty input in ExchangeRel"); - }; - let input = Arc::new(consumer.consume_rel(input).await?); - - let Some(exchange_kind) = &exchange.exchange_kind else { - return substrait_err!("Unexpected empty input in ExchangeRel"); - }; - - // ref: https://substrait.io/relations/physical_relations/#exchange-types - let partitioning_scheme = match exchange_kind { - ExchangeKind::ScatterByFields(scatter_fields) => { - let mut partition_columns = vec![]; - let input_schema = input.schema(); - for field_ref in &scatter_fields.fields { - let column = from_substrait_field_reference(field_ref, input_schema)?; - partition_columns.push(column); - } - Partitioning::Hash(partition_columns, exchange.partition_count as usize) - } - ExchangeKind::RoundRobin(_) => { - Partitioning::RoundRobinBatch(exchange.partition_count as usize) - } - ExchangeKind::SingleTarget(_) - | ExchangeKind::MultiTarget(_) - | ExchangeKind::Broadcast(_) => { - return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); - } - }; - Ok(LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - })) -} - -fn retrieve_rel_common(rel: &Rel) -> Option<&RelCommon> { - match rel.rel_type.as_ref() { - None => None, - Some(rt) => match rt { - RelType::Read(r) => r.common.as_ref(), - RelType::Filter(f) => f.common.as_ref(), - RelType::Fetch(f) => f.common.as_ref(), - RelType::Aggregate(a) => a.common.as_ref(), - RelType::Sort(s) => s.common.as_ref(), - RelType::Join(j) => j.common.as_ref(), - RelType::Project(p) => 
p.common.as_ref(), - RelType::Set(s) => s.common.as_ref(), - RelType::ExtensionSingle(e) => e.common.as_ref(), - RelType::ExtensionMulti(e) => e.common.as_ref(), - RelType::ExtensionLeaf(e) => e.common.as_ref(), - RelType::Cross(c) => c.common.as_ref(), - RelType::Reference(_) => None, - RelType::Write(w) => w.common.as_ref(), - RelType::Ddl(d) => d.common.as_ref(), - RelType::HashJoin(j) => j.common.as_ref(), - RelType::MergeJoin(j) => j.common.as_ref(), - RelType::NestedLoopJoin(j) => j.common.as_ref(), - RelType::Window(w) => w.common.as_ref(), - RelType::Exchange(e) => e.common.as_ref(), - RelType::Expand(e) => e.common.as_ref(), - RelType::Update(_) => None, - }, - } -} - -fn retrieve_emit_kind(rel_common: Option<&RelCommon>) -> EmitKind { - // the default EmitKind is Direct if it is not set explicitly - let default = EmitKind::Direct(rel_common::Direct {}); - rel_common - .and_then(|rc| rc.emit_kind.as_ref()) - .map_or(default, |ek| ek.clone()) -} - -fn contains_volatile_expr(proj: &Projection) -> bool { - proj.expr.iter().any(|e| e.is_volatile()) -} - -fn apply_emit_kind( - rel_common: Option<&RelCommon>, - plan: LogicalPlan, -) -> Result { - match retrieve_emit_kind(rel_common) { - EmitKind::Direct(_) => Ok(plan), - EmitKind::Emit(Emit { output_mapping }) => { - // It is valid to reference the same field multiple times in the Emit - // In this case, we need to provide unique names to avoid collisions - let mut name_tracker = NameTracker::new(); - match plan { - // To avoid adding a projection on top of a projection, we apply special case - // handling to flatten Substrait Emits. This is only applicable if none of the - // expressions in the projection are volatile. This is to avoid issues like - // converting a single call of the random() function into multiple calls due to - // duplicate fields in the output_mapping. - LogicalPlan::Projection(proj) if !contains_volatile_expr(&proj) => { - let mut exprs: Vec = vec![]; - for field in output_mapping { - let expr = proj.expr - .get(field as usize) - .ok_or_else(|| substrait_datafusion_err!( - "Emit output field {} cannot be resolved in input schema {}", - field, proj.input.schema() - ))?; - exprs.push(name_tracker.get_uniquely_named_expr(expr.clone())?); - } - - let input = Arc::unwrap_or_clone(proj.input); - project(input, exprs) - } - // Otherwise we just handle the output_mapping as a projection - _ => { - let input_schema = plan.schema(); - - let mut exprs: Vec = vec![]; - for index in output_mapping.into_iter() { - let column = Expr::Column(Column::from( - input_schema.qualified_field(index as usize), - )); - let expr = name_tracker.get_uniquely_named_expr(column)?; - exprs.push(expr); - } - - project(plan, exprs) - } - } - } - } -} - -struct NameTracker { - seen_names: HashSet, -} - -enum NameTrackerStatus { - NeverSeen, - SeenBefore, -} - -impl NameTracker { - fn new() -> Self { - NameTracker { - seen_names: HashSet::default(), - } - } - fn get_unique_name(&mut self, name: String) -> (String, NameTrackerStatus) { - match self.seen_names.insert(name.clone()) { - true => (name, NameTrackerStatus::NeverSeen), - false => { - let mut counter = 0; - loop { - let candidate_name = format!("{}__temp__{}", name, counter); - if self.seen_names.insert(candidate_name.clone()) { - return (candidate_name, NameTrackerStatus::SeenBefore); - } - counter += 1; - } - } - } - } - - fn get_uniquely_named_expr(&mut self, expr: Expr) -> Result { - match self.get_unique_name(expr.name_for_alias()?) 
{ - (_, NameTrackerStatus::NeverSeen) => Ok(expr), - (name, NameTrackerStatus::SeenBefore) => Ok(expr.alias(name)), - } - } -} - -/// Ensures that the given Substrait schema is compatible with the schema as given by DataFusion -/// -/// This means: -/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The -/// DataFusion schema may have MORE fields, but not the other way around. -/// 2. All fields are compatible. See [`ensure_field_compatibility`] for details -fn ensure_schema_compatibility( - table_schema: &DFSchema, - substrait_schema: DFSchema, -) -> Result<()> { - substrait_schema - .strip_qualifiers() - .fields() - .iter() - .try_for_each(|substrait_field| { - let df_field = - table_schema.field_with_unqualified_name(substrait_field.name())?; - ensure_field_compatibility(df_field, substrait_field) - }) -} - -/// This function returns a DataFrame with fields adjusted if necessary in the event that the -/// Substrait schema is a subset of the DataFusion schema. -fn apply_projection( - plan: LogicalPlan, - substrait_schema: DFSchema, -) -> Result { - let df_schema = plan.schema(); - - if df_schema.logically_equivalent_names_and_types(&substrait_schema) { - return Ok(plan); - } - - let df_schema = df_schema.to_owned(); - - match plan { - LogicalPlan::TableScan(mut scan) => { - let column_indices: Vec = substrait_schema - .strip_qualifiers() - .fields() - .iter() - .map(|substrait_field| { - Ok(df_schema - .index_of_column_by_name(None, substrait_field.name().as_str()) - .unwrap()) - }) - .collect::>()?; - - let fields = column_indices - .iter() - .map(|i| df_schema.qualified_field(*i)) - .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) - .collect(); - - scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata( - fields, - df_schema.metadata().clone(), - )?); - scan.projection = Some(column_indices); - - Ok(LogicalPlan::TableScan(scan)) - } - _ => plan_err!("DataFrame passed to apply_projection must be a TableScan"), - } -} - -/// Ensures that the given Substrait field is compatible with the given DataFusion field -/// -/// A field is compatible between Substrait and DataFusion if: -/// 1. They have logically equivalent types. -/// 2. They have the same nullability OR the Substrait field is nullable and the DataFusion fields -/// is not nullable. -/// -/// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not -/// nullable. As such if DataFusion has that field as nullable the plan should be rejected. -fn ensure_field_compatibility( - datafusion_field: &Field, - substrait_field: &Field, -) -> Result<()> { - if !DFSchema::datatype_is_logically_equal( - datafusion_field.data_type(), - substrait_field.data_type(), - ) { - return substrait_err!( - "Field '{}' in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).", - substrait_field.name(), - substrait_field.data_type(), - datafusion_field.data_type() - ); - } - - if !compatible_nullabilities( - datafusion_field.is_nullable(), - substrait_field.is_nullable(), - ) { - // TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now. 
- return substrait_err!( - "Field '{}' is nullable in the DataFusion schema but not nullable in the Substrait schema.", - substrait_field.name() - ); - } - Ok(()) -} - -/// Returns true if the DataFusion and Substrait nullabilities are compatible, false otherwise -fn compatible_nullabilities( - datafusion_nullability: bool, - substrait_nullability: bool, -) -> bool { - // DataFusion and Substrait have the same nullability - (datafusion_nullability == substrait_nullability) - // DataFusion is not nullable and Substrait is nullable - || (!datafusion_nullability && substrait_nullability) -} - -/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise -/// conflict with the columns from the other. -/// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For -/// Substrait the names don't matter since it only refers to columns by indices, however DataFusion -/// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names). -fn requalify_sides_if_needed( - left: LogicalPlanBuilder, - right: LogicalPlanBuilder, -) -> Result<(LogicalPlanBuilder, LogicalPlanBuilder)> { - let left_cols = left.schema().columns(); - let right_cols = right.schema().columns(); - if left_cols.iter().any(|l| { - right_cols.iter().any(|r| { - l == r || (l.name == r.name && (l.relation.is_none() || r.relation.is_none())) - }) - }) { - // These names have no connection to the original plan, but they'll make the columns - // (mostly) unique. - Ok(( - left.alias(TableReference::bare("left"))?, - right.alias(TableReference::bare("right"))?, - )) - } else { - Ok((left, right)) - } -} - -fn from_substrait_jointype(join_type: i32) -> Result { - if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) { - match substrait_join_type { - join_rel::JoinType::Inner => Ok(JoinType::Inner), - join_rel::JoinType::Left => Ok(JoinType::Left), - join_rel::JoinType::Right => Ok(JoinType::Right), - join_rel::JoinType::Outer => Ok(JoinType::Full), - join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti), - join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi), - join_rel::JoinType::LeftMark => Ok(JoinType::LeftMark), - _ => plan_err!("unsupported join type {substrait_join_type:?}"), - } - } else { - plan_err!("invalid join type variant {join_type:?}") - } -} - -/// Convert Substrait Sorts to DataFusion Exprs -pub async fn from_substrait_sorts( - consumer: &impl SubstraitConsumer, - substrait_sorts: &Vec, - input_schema: &DFSchema, -) -> Result> { - let mut sorts: Vec = vec![]; - for s in substrait_sorts { - let expr = consumer - .consume_expression(s.expr.as_ref().unwrap(), input_schema) - .await?; - let asc_nullfirst = match &s.sort_kind { - Some(k) => match k { - Direction(d) => { - let Ok(direction) = SortDirection::try_from(*d) else { - return not_impl_err!( - "Unsupported Substrait SortDirection value {d}" - ); - }; - - match direction { - SortDirection::AscNullsFirst => Ok((true, true)), - SortDirection::AscNullsLast => Ok((true, false)), - SortDirection::DescNullsFirst => Ok((false, true)), - SortDirection::DescNullsLast => Ok((false, false)), - SortDirection::Clustered => not_impl_err!( - "Sort with direction clustered is not yet supported" - ), - SortDirection::Unspecified => { - not_impl_err!("Unspecified sort direction is invalid") - } - } - } - ComparisonFunctionReference(_) => not_impl_err!( - "Sort using comparison function reference is not supported" - ), - }, - None => not_impl_err!("Sort without sort 
kind is invalid"), - }; - let (asc, nulls_first) = asc_nullfirst.unwrap(); - sorts.push(Sort { - expr, - asc, - nulls_first, - }); - } - Ok(sorts) -} - -/// Convert Substrait Expressions to DataFusion Exprs -pub async fn from_substrait_rex_vec( - consumer: &impl SubstraitConsumer, - exprs: &Vec, - input_schema: &DFSchema, -) -> Result> { - let mut expressions: Vec = vec![]; - for expr in exprs { - let expression = consumer.consume_expression(expr, input_schema).await?; - expressions.push(expression); - } - Ok(expressions) -} - -/// Convert Substrait FunctionArguments to DataFusion Exprs -pub async fn from_substrait_func_args( - consumer: &impl SubstraitConsumer, - arguments: &Vec, - input_schema: &DFSchema, -) -> Result> { - let mut args: Vec = vec![]; - for arg in arguments { - let arg_expr = match &arg.arg_type { - Some(ArgType::Value(e)) => consumer.consume_expression(e, input_schema).await, - _ => not_impl_err!("Function argument non-Value type not supported"), - }; - args.push(arg_expr?); - } - Ok(args) -} - -/// Convert Substrait AggregateFunction to DataFusion Expr -pub async fn from_substrait_agg_func( - consumer: &impl SubstraitConsumer, - f: &AggregateFunction, - input_schema: &DFSchema, - filter: Option>, - order_by: Option>, - distinct: bool, -) -> Result> { - let Some(fn_signature) = consumer - .get_extensions() - .functions - .get(&f.function_reference) - else { - return plan_err!( - "Aggregate function not registered: function anchor = {:?}", - f.function_reference - ); - }; - - let fn_name = substrait_fun_name(fn_signature); - let udaf = consumer.get_function_registry().udaf(fn_name); - let udaf = udaf.map_err(|_| { - not_impl_datafusion_err!( - "Aggregate function {} is not supported: function anchor = {:?}", - fn_signature, - f.function_reference - ) - })?; - - let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; - - // Datafusion does not support aggregate functions with no arguments, so - // we inject a dummy argument that does not affect the query, but allows - // us to bypass this limitation. 
- let args = if udaf.name() == "count" && args.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] - } else { - args - }; - - Ok(Arc::new(Expr::AggregateFunction( - expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None), - ))) -} - -/// Convert Substrait Rex to DataFusion Expr -pub async fn from_substrait_rex( - consumer: &impl SubstraitConsumer, - expression: &Expression, - input_schema: &DFSchema, -) -> Result { - match &expression.rex_type { - Some(t) => match t { - RexType::Literal(expr) => consumer.consume_literal(expr).await, - RexType::Selection(expr) => { - consumer.consume_field_reference(expr, input_schema).await - } - RexType::ScalarFunction(expr) => { - consumer.consume_scalar_function(expr, input_schema).await - } - RexType::WindowFunction(expr) => { - consumer.consume_window_function(expr, input_schema).await - } - RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await, - RexType::SwitchExpression(expr) => { - consumer.consume_switch(expr, input_schema).await - } - RexType::SingularOrList(expr) => { - consumer.consume_singular_or_list(expr, input_schema).await - } - - RexType::MultiOrList(expr) => { - consumer.consume_multi_or_list(expr, input_schema).await - } - - RexType::Cast(expr) => { - consumer.consume_cast(expr.as_ref(), input_schema).await - } - - RexType::Subquery(expr) => { - consumer.consume_subquery(expr.as_ref(), input_schema).await - } - RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await, - RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await, - RexType::DynamicParameter(expr) => { - consumer.consume_dynamic_parameter(expr, input_schema).await - } - }, - None => substrait_err!("Expression must set rex_type: {:?}", expression), - } -} - -pub async fn from_singular_or_list( - consumer: &impl SubstraitConsumer, - expr: &SingularOrList, - input_schema: &DFSchema, -) -> Result { - let substrait_expr = expr.value.as_ref().unwrap(); - let substrait_list = expr.options.as_ref(); - Ok(Expr::InList(InList { - expr: Box::new( - consumer - .consume_expression(substrait_expr, input_schema) - .await?, - ), - list: from_substrait_rex_vec(consumer, substrait_list, input_schema).await?, - negated: false, - })) -} - -pub async fn from_field_reference( - _consumer: &impl SubstraitConsumer, - field_ref: &FieldReference, - input_schema: &DFSchema, -) -> Result { - from_substrait_field_reference(field_ref, input_schema) -} - -pub async fn from_if_then( - consumer: &impl SubstraitConsumer, - if_then: &IfThen, - input_schema: &DFSchema, -) -> Result { - // Parse `ifs` - // If the first element does not have a `then` part, then we can assume it's a base expression - let mut when_then_expr: Vec<(Box, Box)> = vec![]; - let mut expr = None; - for (i, if_expr) in if_then.ifs.iter().enumerate() { - if i == 0 { - // Check if the first element is type base expression - if if_expr.then.is_none() { - expr = Some(Box::new( - consumer - .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) - .await?, - )); - continue; - } - } - when_then_expr.push(( - Box::new( - consumer - .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) - .await?, - ), - Box::new( - consumer - .consume_expression(if_expr.then.as_ref().unwrap(), input_schema) - .await?, - ), - )); - } - // Parse `else` - let else_expr = match &if_then.r#else { - Some(e) => Some(Box::new( - consumer.consume_expression(e, input_schema).await?, - )), - None => None, - }; - Ok(Expr::Case(Case { - expr, - when_then_expr, - 
else_expr, - })) -} - -pub async fn from_scalar_function( - consumer: &impl SubstraitConsumer, - f: &ScalarFunction, - input_schema: &DFSchema, -) -> Result { - let Some(fn_signature) = consumer - .get_extensions() - .functions - .get(&f.function_reference) - else { - return plan_err!( - "Scalar function not found: function reference = {:?}", - f.function_reference - ); - }; - let fn_name = substrait_fun_name(fn_signature); - let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; - - // try to first match the requested function into registered udfs, then built-in ops - // and finally built-in expressions - if let Ok(func) = consumer.get_function_registry().udf(fn_name) { - Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( - func.to_owned(), - args, - ))) - } else if let Some(op) = name_to_op(fn_name) { - if f.arguments.len() < 2 { - return not_impl_err!( - "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", - f.arguments.len() - ); - } - // Some expressions are binary in DataFusion but take in a variadic number of args in Substrait. - // In those cases we iterate through all the arguments, applying the binary expression against them all - let combined_expr = args - .into_iter() - .fold(None, |combined_expr: Option, arg: Expr| { - Some(match combined_expr { - Some(expr) => Expr::BinaryExpr(BinaryExpr { - left: Box::new(expr), - op, - right: Box::new(arg), - }), - None => arg, - }) - }) - .unwrap(); - - Ok(combined_expr) - } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { - builder.build(consumer, f, input_schema).await - } else { - not_impl_err!("Unsupported function name: {fn_name:?}") - } -} - -pub async fn from_literal( - consumer: &impl SubstraitConsumer, - expr: &Literal, -) -> Result { - let scalar_value = from_substrait_literal_without_names(consumer, expr)?; - Ok(Expr::Literal(scalar_value)) -} - -pub async fn from_cast( - consumer: &impl SubstraitConsumer, - cast: &substrait_expression::Cast, - input_schema: &DFSchema, -) -> Result { - match cast.r#type.as_ref() { - Some(output_type) => { - let input_expr = Box::new( - consumer - .consume_expression( - cast.input.as_ref().unwrap().as_ref(), - input_schema, - ) - .await?, - ); - let data_type = from_substrait_type_without_names(consumer, output_type)?; - if cast.failure_behavior() == ReturnNull { - Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) - } else { - Ok(Expr::Cast(Cast::new(input_expr, data_type))) - } - } - None => substrait_err!("Cast expression without output type is not allowed"), - } -} - -pub async fn from_window_function( - consumer: &impl SubstraitConsumer, - window: &WindowFunction, - input_schema: &DFSchema, -) -> Result { - let Some(fn_signature) = consumer - .get_extensions() - .functions - .get(&window.function_reference) - else { - return plan_err!( - "Window function not found: function reference = {:?}", - window.function_reference - ); - }; - let fn_name = substrait_fun_name(fn_signature); - - // check udwf first, then udaf, then built-in window and aggregate functions - let fun = if let Ok(udwf) = consumer.get_function_registry().udwf(fn_name) { - Ok(WindowFunctionDefinition::WindowUDF(udwf)) - } else if let Ok(udaf) = consumer.get_function_registry().udaf(fn_name) { - Ok(WindowFunctionDefinition::AggregateUDF(udaf)) - } else { - not_impl_err!( - "Window function {} is not supported: function anchor = {:?}", - fn_name, - window.function_reference - ) - }?; - - let mut order_by = - 
from_substrait_sorts(consumer, &window.sorts, input_schema).await?; - - let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| { - plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) - })? { - BoundsType::Rows => WindowFrameUnits::Rows, - BoundsType::Range => WindowFrameUnits::Range, - BoundsType::Unspecified => { - // If the plan does not specify the bounds type, then we use a simple logic to determine the units - // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary - // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row - if order_by.is_empty() { - WindowFrameUnits::Rows - } else { - WindowFrameUnits::Range - } - } - }; - let window_frame = datafusion::logical_expr::WindowFrame::new_bounds( - bound_units, - from_substrait_bound(&window.lower_bound, true)?, - from_substrait_bound(&window.upper_bound, false)?, - ); - - window_frame.regularize_order_bys(&mut order_by)?; - - // Datafusion does not support aggregate functions with no arguments, so - // we inject a dummy argument that does not affect the query, but allows - // us to bypass this limitation. - let args = if fun.name() == "count" && window.arguments.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] - } else { - from_substrait_func_args(consumer, &window.arguments, input_schema).await? - }; - - Ok(Expr::WindowFunction(expr::WindowFunction { - fun, - params: WindowFunctionParams { - args, - partition_by: from_substrait_rex_vec( - consumer, - &window.partitions, - input_schema, - ) - .await?, - order_by, - window_frame, - null_treatment: None, - }, - })) -} - -pub async fn from_subquery( - consumer: &impl SubstraitConsumer, - subquery: &substrait_expression::Subquery, - input_schema: &DFSchema, -) -> Result { - match &subquery.subquery_type { - Some(subquery_type) => match subquery_type { - SubqueryType::InPredicate(in_predicate) => { - if in_predicate.needles.len() != 1 { - substrait_err!("InPredicate Subquery type must have exactly one Needle expression") - } else { - let needle_expr = &in_predicate.needles[0]; - let haystack_expr = &in_predicate.haystack; - if let Some(haystack_expr) = haystack_expr { - let haystack_expr = consumer.consume_rel(haystack_expr).await?; - let outer_refs = haystack_expr.all_out_ref_exprs(); - Ok(Expr::InSubquery(InSubquery { - expr: Box::new( - consumer - .consume_expression(needle_expr, input_schema) - .await?, - ), - subquery: Subquery { - subquery: Arc::new(haystack_expr), - outer_ref_columns: outer_refs, - spans: Spans::new(), - }, - negated: false, - })) - } else { - substrait_err!( - "InPredicate Subquery type must have a Haystack expression" - ) - } - } - } - SubqueryType::Scalar(query) => { - let plan = consumer - .consume_rel(&(query.input.clone()).unwrap_or_default()) - .await?; - let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(plan), - outer_ref_columns, - spans: Spans::new(), - })) - } - SubqueryType::SetPredicate(predicate) => { - match predicate.predicate_op() { - // exist - PredicateOp::Exists => { - let relation = &predicate.tuples; - let plan = consumer - .consume_rel(&relation.clone().unwrap_or_default()) - .await?; - let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::Exists(Exists::new( - Subquery { - subquery: Arc::new(plan), - outer_ref_columns, - spans: Spans::new(), - }, - false, - ))) - } - other_type => substrait_err!( - "unimplemented 
type {:?} for set predicate", - other_type - ), - } - } - other_type => { - substrait_err!("Subquery type {:?} not implemented", other_type) - } - }, - None => { - substrait_err!("Subquery expression without SubqueryType is not allowed") - } - } -} - -pub(crate) fn from_substrait_type_without_names( - consumer: &impl SubstraitConsumer, - dt: &Type, -) -> Result { - from_substrait_type(consumer, dt, &[], &mut 0) -} - -fn from_substrait_type( - consumer: &impl SubstraitConsumer, - dt: &Type, - dfs_names: &[String], - name_idx: &mut usize, -) -> Result { - match &dt.kind { - Some(s_kind) => match s_kind { - r#type::Kind::Bool(_) => Ok(DataType::Boolean), - r#type::Kind::I8(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int8), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt8), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::I16(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int16), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt16), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::I32(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt32), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::I64(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int64), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt64), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::Fp32(_) => Ok(DataType::Float32), - r#type::Kind::Fp64(_) => Ok(DataType::Float64), - r#type::Kind::Timestamp(ts) => { - // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead - #[allow(deprecated)] - match ts.type_variation_reference { - TIMESTAMP_SECOND_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Second, None)) - } - TIMESTAMP_MILLI_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) - } - TIMESTAMP_MICRO_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) - } - TIMESTAMP_NANO_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - } - } - r#type::Kind::PrecisionTimestamp(pts) => { - let unit = match pts.precision { - 0 => Ok(TimeUnit::Second), - 3 => Ok(TimeUnit::Millisecond), - 6 => Ok(TimeUnit::Microsecond), - 9 => Ok(TimeUnit::Nanosecond), - p => not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestamp" - ), - }?; - Ok(DataType::Timestamp(unit, None)) - } - r#type::Kind::PrecisionTimestampTz(pts) => { - let unit = match pts.precision { - 0 => Ok(TimeUnit::Second), - 3 => Ok(TimeUnit::Millisecond), - 6 => Ok(TimeUnit::Microsecond), - 9 => Ok(TimeUnit::Nanosecond), - p => not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestampTz" - ), - }?; - Ok(DataType::Timestamp(unit, Some(DEFAULT_TIMEZONE.into()))) - } - r#type::Kind::Date(date) => match date.type_variation_reference { - DATE_32_TYPE_VARIATION_REF => Ok(DataType::Date32), - DATE_64_TYPE_VARIATION_REF => Ok(DataType::Date64), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of 
type {s_kind:?}" - ), - }, - r#type::Kind::Binary(binary) => match binary.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary), - VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::BinaryView), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::FixedBinary(fixed) => { - Ok(DataType::FixedSizeBinary(fixed.length)) - } - r#type::Kind::String(string) => match string.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeUtf8), - VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8View), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::List(list) => { - let inner_type = list.r#type.as_ref().ok_or_else(|| { - substrait_datafusion_err!("List type must have inner type") - })?; - let field = Arc::new(Field::new_list_field( - from_substrait_type(consumer, inner_type, dfs_names, name_idx)?, - // We ignore Substrait's nullability here to match to_substrait_literal - // which always creates nullable lists - true, - )); - match list.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeList(field)), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - )?, - } - } - r#type::Kind::Map(map) => { - let key_type = map.key.as_ref().ok_or_else(|| { - substrait_datafusion_err!("Map type must have key type") - })?; - let value_type = map.value.as_ref().ok_or_else(|| { - substrait_datafusion_err!("Map type must have value type") - })?; - let key_field = Arc::new(Field::new( - "key", - from_substrait_type(consumer, key_type, dfs_names, name_idx)?, - false, - )); - let value_field = Arc::new(Field::new( - "value", - from_substrait_type(consumer, value_type, dfs_names, name_idx)?, - true, - )); - Ok(DataType::Map( - Arc::new(Field::new_struct( - "entries", - [key_field, value_field], - false, // The inner map field is always non-nullable (Arrow #1697), - )), - false, // whether keys are sorted - )) - } - r#type::Kind::Decimal(d) => match d.type_variation_reference { - DECIMAL_128_TYPE_VARIATION_REF => { - Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) - } - DECIMAL_256_TYPE_VARIATION_REF => { - Ok(DataType::Decimal256(d.precision as u8, d.scale as i8)) - } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::IntervalYear(_) => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)), - r#type::Kind::IntervalCompound(_) => { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } - r#type::Kind::UserDefined(u) => { - if let Ok(data_type) = consumer.consume_user_defined_type(u) { - return Ok(data_type); - } - - // TODO: remove the code below once the producer has been updated - if let Some(name) = consumer.get_extensions().types.get(&u.type_reference) - { - #[allow(deprecated)] - match name.as_ref() { - // Kept for backwards compatibility, producers should use IntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), - _ => not_impl_err!( - "Unsupported Substrait user defined type with ref {} and variation {}", - u.type_reference, - u.type_variation_reference - ), 
- } - } else { - #[allow(deprecated)] - match u.type_reference { - // Kept for backwards compatibility, producers should use IntervalYear instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - // Kept for backwards compatibility, producers should use IntervalDay instead - INTERVAL_DAY_TIME_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::DayTime)) - } - // Kept for backwards compatibility, producers should use IntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } - _ => not_impl_err!( - "Unsupported Substrait user defined type with ref {} and variation {}", - u.type_reference, - u.type_variation_reference - ), - } - } - } - r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( - consumer, s, dfs_names, name_idx, - )?)), - r#type::Kind::Varchar(_) => Ok(DataType::Utf8), - r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), - _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"), - }, - _ => not_impl_err!("`None` Substrait kind is not supported"), - } -} - -fn from_substrait_struct_type( - consumer: &impl SubstraitConsumer, - s: &r#type::Struct, - dfs_names: &[String], - name_idx: &mut usize, -) -> Result { - let mut fields = vec![]; - for (i, f) in s.types.iter().enumerate() { - let field = Field::new( - next_struct_field_name(i, dfs_names, name_idx)?, - from_substrait_type(consumer, f, dfs_names, name_idx)?, - true, // We assume everything to be nullable since that's easier than ensuring it matches - ); - fields.push(field); - } - Ok(fields.into()) -} - -fn next_struct_field_name( - column_idx: usize, - dfs_names: &[String], - name_idx: &mut usize, -) -> Result { - if dfs_names.is_empty() { - // If names are not given, create dummy names - // c0, c1, ... align with e.g. 
SqlToRel::create_named_struct - Ok(format!("c{column_idx}")) - } else { - let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| { - substrait_datafusion_err!("Named schema must contain names for all fields") - })?; - *name_idx += 1; - Ok(name) - } -} - -/// Convert Substrait NamedStruct to DataFusion DFSchemaRef -pub fn from_substrait_named_struct( - consumer: &impl SubstraitConsumer, - base_schema: &NamedStruct, -) -> Result { - let mut name_idx = 0; - let fields = from_substrait_struct_type( - consumer, - base_schema.r#struct.as_ref().ok_or_else(|| { - substrait_datafusion_err!("Named struct must contain a struct") - })?, - &base_schema.names, - &mut name_idx, - ); - if name_idx != base_schema.names.len() { - return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - base_schema.names.len() - ); - } - DFSchema::try_from(Schema::new(fields?)) -} - -fn from_substrait_bound( - bound: &Option, - is_lower: bool, -) -> Result { - match bound { - Some(b) => match &b.kind { - Some(k) => match k { - BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => { - Ok(WindowFrameBound::CurrentRow) - } - BoundKind::Preceding(SubstraitBound::Preceding { offset }) => { - if *offset <= 0 { - return plan_err!("Preceding bound must be positive"); - } - Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some( - *offset as u64, - )))) - } - BoundKind::Following(SubstraitBound::Following { offset }) => { - if *offset <= 0 { - return plan_err!("Following bound must be positive"); - } - Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some( - *offset as u64, - )))) - } - BoundKind::Unbounded(SubstraitBound::Unbounded {}) => { - if is_lower { - Ok(WindowFrameBound::Preceding(ScalarValue::Null)) - } else { - Ok(WindowFrameBound::Following(ScalarValue::Null)) - } - } - }, - None => substrait_err!("WindowFunction missing Substrait Bound kind"), - }, - None => { - if is_lower { - Ok(WindowFrameBound::Preceding(ScalarValue::Null)) - } else { - Ok(WindowFrameBound::Following(ScalarValue::Null)) - } - } - } -} - -pub(crate) fn from_substrait_literal_without_names( - consumer: &impl SubstraitConsumer, - lit: &Literal, -) -> Result { - from_substrait_literal(consumer, lit, &vec![], &mut 0) -} - -fn from_substrait_literal( - consumer: &impl SubstraitConsumer, - lit: &Literal, - dfs_names: &Vec, - name_idx: &mut usize, -) -> Result { - let scalar_value = match &lit.literal_type { - Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)), - Some(LiteralType::I8(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int8(Some(*n as i8)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt8(Some(*n as u8)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::I16(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int16(Some(*n as i16)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt16(Some(*n as u16)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::I32(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int32(Some(*n)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt32(Some(*n as u32)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::I64(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => 
ScalarValue::Int64(Some(*n)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt64(Some(*n as u64)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), - Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)), - Some(LiteralType::Timestamp(t)) => { - // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead - #[allow(deprecated)] - match lit.type_variation_reference { - TIMESTAMP_SECOND_TYPE_VARIATION_REF => { - ScalarValue::TimestampSecond(Some(*t), None) - } - TIMESTAMP_MILLI_TYPE_VARIATION_REF => { - ScalarValue::TimestampMillisecond(Some(*t), None) - } - TIMESTAMP_MICRO_TYPE_VARIATION_REF => { - ScalarValue::TimestampMicrosecond(Some(*t), None) - } - TIMESTAMP_NANO_TYPE_VARIATION_REF => { - ScalarValue::TimestampNanosecond(Some(*t), None) - } - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - } - } - Some(LiteralType::PrecisionTimestamp(pt)) => match pt.precision { - 0 => ScalarValue::TimestampSecond(Some(pt.value), None), - 3 => ScalarValue::TimestampMillisecond(Some(pt.value), None), - 6 => ScalarValue::TimestampMicrosecond(Some(pt.value), None), - 9 => ScalarValue::TimestampNanosecond(Some(pt.value), None), - p => { - return not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestamp" - ); - } - }, - Some(LiteralType::PrecisionTimestampTz(pt)) => match pt.precision { - 0 => ScalarValue::TimestampSecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - 3 => ScalarValue::TimestampMillisecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - 6 => ScalarValue::TimestampMicrosecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - 9 => ScalarValue::TimestampNanosecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - p => { - return not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestamp" - ); - } - }, - Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), - Some(LiteralType::String(s)) => match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8(Some(s.clone())), - LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeUtf8(Some(s.clone())), - VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8View(Some(s.clone())), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::Binary(b)) => match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Binary(Some(b.clone())), - LARGE_CONTAINER_TYPE_VARIATION_REF => { - ScalarValue::LargeBinary(Some(b.clone())) - } - VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::BinaryView(Some(b.clone())), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::FixedBinary(b)) => { - ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())) - } - Some(LiteralType::Decimal(d)) => { - let value: [u8; 16] = d - .value - .clone() - .try_into() - .or(substrait_err!("Failed to parse decimal value"))?; - let p = d.precision.try_into().map_err(|e| { - substrait_datafusion_err!("Failed to parse decimal precision: {e}") - })?; - let s = d.scale.try_into().map_err(|e| { - substrait_datafusion_err!("Failed to parse decimal scale: {e}") - })?; - ScalarValue::Decimal128(Some(i128::from_le_bytes(value)), p, s) - } - Some(LiteralType::List(l)) => { - // Each element should start the name index from the same value, then 
we increase it - // once at the end - let mut element_name_idx = *name_idx; - let elements = l - .values - .iter() - .map(|el| { - element_name_idx = *name_idx; - from_substrait_literal(consumer, el, dfs_names, &mut element_name_idx) - }) - .collect::>>()?; - *name_idx = element_name_idx; - if elements.is_empty() { - return substrait_err!( - "Empty list must be encoded as EmptyList literal type, not List" - ); - } - let element_type = elements[0].data_type(); - match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::List( - ScalarValue::new_list_nullable(elements.as_slice(), &element_type), - ), - LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( - ScalarValue::new_large_list(elements.as_slice(), &element_type), - ), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - } - } - Some(LiteralType::EmptyList(l)) => { - let element_type = from_substrait_type( - consumer, - l.r#type.clone().unwrap().as_ref(), - dfs_names, - name_idx, - )?; - match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => { - ScalarValue::List(ScalarValue::new_list_nullable(&[], &element_type)) - } - LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( - ScalarValue::new_large_list(&[], &element_type), - ), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - } - } - Some(LiteralType::Map(m)) => { - // Each entry should start the name index from the same value, then we increase it - // once at the end - let mut entry_name_idx = *name_idx; - let entries = m - .key_values - .iter() - .map(|kv| { - entry_name_idx = *name_idx; - let key_sv = from_substrait_literal( - consumer, - kv.key.as_ref().unwrap(), - dfs_names, - &mut entry_name_idx, - )?; - let value_sv = from_substrait_literal( - consumer, - kv.value.as_ref().unwrap(), - dfs_names, - &mut entry_name_idx, - )?; - ScalarStructBuilder::new() - .with_scalar(Field::new("key", key_sv.data_type(), false), key_sv) - .with_scalar( - Field::new("value", value_sv.data_type(), true), - value_sv, - ) - .build() - }) - .collect::>>()?; - *name_idx = entry_name_idx; - - if entries.is_empty() { - return substrait_err!( - "Empty map must be encoded as EmptyMap literal type, not Map" - ); - } - - ScalarValue::Map(Arc::new(MapArray::new( - Arc::new(Field::new("entries", entries[0].data_type(), false)), - OffsetBuffer::new(vec![0, entries.len() as i32].into()), - ScalarValue::iter_to_array(entries)?.as_struct().to_owned(), - None, - false, - ))) - } - Some(LiteralType::EmptyMap(m)) => { - let key = match &m.key { - Some(k) => Ok(k), - _ => plan_err!("Missing key type for empty map"), - }?; - let value = match &m.value { - Some(v) => Ok(v), - _ => plan_err!("Missing value type for empty map"), - }?; - let key_type = from_substrait_type(consumer, key, dfs_names, name_idx)?; - let value_type = from_substrait_type(consumer, value, dfs_names, name_idx)?; - - // new_empty_array on a MapType creates a too empty array - // We want it to contain an empty struct array to align with an empty MapBuilder one - let entries = Field::new_struct( - "entries", - vec![ - Field::new("key", key_type, false), - Field::new("value", value_type, true), - ], - false, - ); - let struct_array = - new_empty_array(entries.data_type()).as_struct().to_owned(); - ScalarValue::Map(Arc::new(MapArray::new( - Arc::new(entries), - OffsetBuffer::new(vec![0, 0].into()), - struct_array, - None, - false, - ))) - } - Some(LiteralType::Struct(s)) => { - let mut builder = 
ScalarStructBuilder::new(); - for (i, field) in s.fields.iter().enumerate() { - let name = next_struct_field_name(i, dfs_names, name_idx)?; - let sv = from_substrait_literal(consumer, field, dfs_names, name_idx)?; - // We assume everything to be nullable, since Arrow's strict about things matching - // and it's hard to match otherwise. - builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); - } - builder.build()? - } - Some(LiteralType::Null(null_type)) => { - let data_type = - from_substrait_type(consumer, null_type, dfs_names, name_idx)?; - ScalarValue::try_from(&data_type)? - } - Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { - days, - seconds, - subseconds, - precision_mode, - })) => { - use interval_day_to_second::PrecisionMode; - // DF only supports millisecond precision, so for any more granular type we lose precision - let milliseconds = match precision_mode { - Some(PrecisionMode::Microseconds(ms)) => ms / 1000, - None => - if *subseconds != 0 { - return substrait_err!("Cannot set subseconds field of IntervalDayToSecond without setting precision"); - } else { - 0_i32 - } - Some(PrecisionMode::Precision(0)) => *subseconds as i32 * 1000, - Some(PrecisionMode::Precision(3)) => *subseconds as i32, - Some(PrecisionMode::Precision(6)) => (subseconds / 1000) as i32, - Some(PrecisionMode::Precision(9)) => (subseconds / 1000 / 1000) as i32, - _ => { - return not_impl_err!( - "Unsupported Substrait interval day to second precision mode: {precision_mode:?}") - } - }; - - ScalarValue::new_interval_dt(*days, (seconds * 1000) + milliseconds) - } - Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { - ScalarValue::new_interval_ym(*years, *months) - } - Some(LiteralType::IntervalCompound(IntervalCompound { - interval_year_to_month, - interval_day_to_second, - })) => match (interval_year_to_month, interval_day_to_second) { - ( - Some(IntervalYearToMonth { years, months }), - Some(IntervalDayToSecond { - days, - seconds, - subseconds, - precision_mode: - Some(interval_day_to_second::PrecisionMode::Precision(p)), - }), - ) => { - if *p < 0 || *p > 9 { - return plan_err!( - "Unsupported Substrait interval day to second precision: {}", - p - ); - } - let nanos = *subseconds * i64::pow(10, (9 - p) as u32); - ScalarValue::new_interval_mdn( - *years * 12 + months, - *days, - *seconds as i64 * NANOSECONDS + nanos, - ) - } - _ => return plan_err!("Substrait compound interval missing components"), - }, - Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), - Some(LiteralType::UserDefined(user_defined)) => { - if let Ok(value) = consumer.consume_user_defined_literal(user_defined) { - return Ok(value); - } - - // TODO: remove the code below once the producer has been updated - - // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed - let interval_month_day_nano = - |user_defined: &proto::expression::literal::UserDefined| -> Result { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval month day nano value is empty"); - }; - let value_slice: [u8; 16] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval month day nano value" - ) - })?; - let months = - i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); - let days = i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); - let nanoseconds = - i64::from_le_bytes(value_slice[8..16].try_into().unwrap()); - 
Ok(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano { - months, - days, - nanoseconds, - }, - ))) - }; - - if let Some(name) = consumer - .get_extensions() - .types - .get(&user_defined.type_reference) - { - match name.as_ref() { - // Kept for backwards compatibility - producers should use IntervalCompound instead - #[allow(deprecated)] - INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { - interval_month_day_nano(user_defined)? - } - _ => { - return not_impl_err!( - "Unsupported Substrait user defined type with ref {} and name {}", - user_defined.type_reference, - name - ) - } - } - } else { - #[allow(deprecated)] - match user_defined.type_reference { - // Kept for backwards compatibility, producers should useIntervalYearToMonth instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval year month value is empty"); - }; - let value_slice: [u8; 4] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval year month value" - ) - })?; - ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes( - value_slice, - ))) - } - // Kept for backwards compatibility, producers should useIntervalDayToSecond instead - INTERVAL_DAY_TIME_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval day time value is empty"); - }; - let value_slice: [u8; 8] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval day time value" - ) - })?; - let days = - i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); - let milliseconds = - i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); - ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days, - milliseconds, - })) - } - // Kept for backwards compatibility, producers should useIntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { - interval_month_day_nano(user_defined)? - } - _ => { - return not_impl_err!( - "Unsupported Substrait user defined type literal with ref {}", - user_defined.type_reference - ) - } - } - } - } - _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), - }; - - Ok(scalar_value) -} - -#[allow(deprecated)] -async fn from_substrait_grouping( - consumer: &impl SubstraitConsumer, - grouping: &Grouping, - expressions: &[Expr], - input_schema: &DFSchemaRef, -) -> Result> { - let mut group_exprs = vec![]; - if !grouping.grouping_expressions.is_empty() { - for e in &grouping.grouping_expressions { - let expr = consumer.consume_expression(e, input_schema).await?; - group_exprs.push(expr); - } - return Ok(group_exprs); - } - for idx in &grouping.expression_references { - let e = &expressions[*idx as usize]; - group_exprs.push(e.clone()); - } - Ok(group_exprs) -} - -fn from_substrait_field_reference( - field_ref: &FieldReference, - input_schema: &DFSchema, -) -> Result { - match &field_ref.reference_type { - Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { - Some(StructField(x)) => match &x.child.as_ref() { - Some(_) => not_impl_err!( - "Direct reference StructField with child is not supported" - ), - None => Ok(Expr::Column(Column::from( - input_schema.qualified_field(x.field as usize), - ))), - }, - _ => not_impl_err!( - "Direct reference with types other than StructField is not supported" - ), - }, - _ => not_impl_err!("unsupported field ref type"), - } -} - -/// Build [`Expr`] from its name and required inputs. 
-struct BuiltinExprBuilder { - expr_name: String, -} - -impl BuiltinExprBuilder { - pub fn try_from_name(name: &str) -> Option { - match name { - "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" - | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" - | "is_not_unknown" | "negative" | "negate" => Some(Self { - expr_name: name.to_string(), - }), - _ => None, - } - } - - pub async fn build( - self, - consumer: &impl SubstraitConsumer, - f: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - match self.expr_name.as_str() { - "like" => Self::build_like_expr(consumer, false, f, input_schema).await, - "ilike" => Self::build_like_expr(consumer, true, f, input_schema).await, - "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" - | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" - | "is_not_unknown" => { - Self::build_unary_expr(consumer, &self.expr_name, f, input_schema).await - } - _ => { - not_impl_err!("Unsupported builtin expression: {}", self.expr_name) - } - } - } - - async fn build_unary_expr( - consumer: &impl SubstraitConsumer, - fn_name: &str, - f: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - if f.arguments.len() != 1 { - return substrait_err!("Expect one argument for {fn_name} expr"); - } - let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return substrait_err!("Invalid arguments type for {fn_name} expr"); - }; - let arg = consumer - .consume_expression(expr_substrait, input_schema) - .await?; - let arg = Box::new(arg); - - let expr = match fn_name { - "not" => Expr::Not(arg), - "negative" | "negate" => Expr::Negative(arg), - "is_null" => Expr::IsNull(arg), - "is_not_null" => Expr::IsNotNull(arg), - "is_true" => Expr::IsTrue(arg), - "is_false" => Expr::IsFalse(arg), - "is_not_true" => Expr::IsNotTrue(arg), - "is_not_false" => Expr::IsNotFalse(arg), - "is_unknown" => Expr::IsUnknown(arg), - "is_not_unknown" => Expr::IsNotUnknown(arg), - _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), - }; - - Ok(expr) - } - - async fn build_like_expr( - consumer: &impl SubstraitConsumer, - case_insensitive: bool, - f: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; - if f.arguments.len() != 2 && f.arguments.len() != 3 { - return substrait_err!("Expect two or three arguments for `{fn_name}` expr"); - } - - let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let expr = consumer - .consume_expression(expr_substrait, input_schema) - .await?; - let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let pattern = consumer - .consume_expression(pattern_substrait, input_schema) - .await?; - - // Default case: escape character is Literal(Utf8(None)) - let escape_char = if f.arguments.len() == 3 { - let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type - else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - - let escape_char_expr = consumer - .consume_expression(escape_char_substrait, input_schema) - .await?; - - match escape_char_expr { - Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { - // Convert Option to Option - escape_char_string.and_then(|s| s.chars().next()) - } - _ => { - return substrait_err!( - "Expect Utf8 literal for escape char, 
but found {escape_char_expr:?}" - ) - } - } - } else { - None - }; - - Ok(Expr::Like(Like { - negated: false, - expr: Box::new(expr), - pattern: Box::new(pattern), - escape_char, - case_insensitive, - })) - } -} - -#[cfg(test)] -mod test { - use crate::extensions::Extensions; - use crate::logical_plan::consumer::{ - from_substrait_literal_without_names, from_substrait_rex, - DefaultSubstraitConsumer, - }; - use arrow::array::types::IntervalMonthDayNano; - use datafusion::arrow; - use datafusion::common::DFSchema; - use datafusion::error::Result; - use datafusion::execution::SessionState; - use datafusion::prelude::{Expr, SessionContext}; - use datafusion::scalar::ScalarValue; - use std::sync::LazyLock; - use substrait::proto::expression::literal::{ - interval_day_to_second, IntervalCompound, IntervalDayToSecond, - IntervalYearToMonth, LiteralType, - }; - use substrait::proto::expression::window_function::BoundsType; - use substrait::proto::expression::Literal; - - static TEST_SESSION_STATE: LazyLock = - LazyLock::new(|| SessionContext::default().state()); - static TEST_EXTENSIONS: LazyLock = LazyLock::new(Extensions::default); - fn test_consumer() -> DefaultSubstraitConsumer<'static> { - let extensions = &TEST_EXTENSIONS; - let state = &TEST_SESSION_STATE; - DefaultSubstraitConsumer::new(extensions, state) - } - - #[test] - fn interval_compound_different_precision() -> Result<()> { - // DF producer (and thus roundtrip) always uses precision = 9, - // this test exists to test with some other value. - let substrait = Literal { - nullable: false, - type_variation_reference: 0, - literal_type: Some(LiteralType::IntervalCompound(IntervalCompound { - interval_year_to_month: Some(IntervalYearToMonth { - years: 1, - months: 2, - }), - interval_day_to_second: Some(IntervalDayToSecond { - days: 3, - seconds: 4, - subseconds: 5, - precision_mode: Some( - interval_day_to_second::PrecisionMode::Precision(6), - ), - }), - })), - }; - - let consumer = test_consumer(); - assert_eq!( - from_substrait_literal_without_names(&consumer, &substrait)?, - ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { - months: 14, - days: 3, - nanoseconds: 4_000_005_000 - })) - ); - - Ok(()) - } - - #[tokio::test] - async fn window_function_with_range_unit_and_no_order_by() -> Result<()> { - let substrait = substrait::proto::Expression { - rex_type: Some(substrait::proto::expression::RexType::WindowFunction( - substrait::proto::expression::WindowFunction { - function_reference: 0, - bounds_type: BoundsType::Range as i32, - sorts: vec![], - ..Default::default() - }, - )), - }; - - let mut consumer = test_consumer(); - - // Just registering a single function (index 0) so that the plan - // does not throw a "function not found" error. - let mut extensions = Extensions::default(); - extensions.register_function("count".to_string()); - consumer.extensions = &extensions; - - match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? 
{ - Expr::WindowFunction(window_function) => { - assert_eq!(window_function.params.order_by.len(), 1) - } - _ => panic!("expr was not a WindowFunction"), - }; - - Ok(()) - } - - #[tokio::test] - async fn window_function_with_count() -> Result<()> { - let substrait = substrait::proto::Expression { - rex_type: Some(substrait::proto::expression::RexType::WindowFunction( - substrait::proto::expression::WindowFunction { - function_reference: 0, - ..Default::default() - }, - )), - }; - - let mut consumer = test_consumer(); - - let mut extensions = Extensions::default(); - extensions.register_function("count".to_string()); - consumer.extensions = &extensions; - - match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { - Expr::WindowFunction(window_function) => { - assert_eq!(window_function.params.args.len(), 1) - } - _ => panic!("expr was not a WindowFunction"), - }; - - Ok(()) - } -} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs new file mode 100644 index 000000000000..114fe1e7aecd --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{ + from_substrait_func_args, substrait_fun_name, SubstraitConsumer, +}; +use datafusion::common::{not_impl_datafusion_err, plan_err, DFSchema, ScalarValue}; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::{expr, Expr, SortExpr}; +use std::sync::Arc; +use substrait::proto::AggregateFunction; + +/// Convert Substrait AggregateFunction to DataFusion Expr +pub async fn from_substrait_agg_func( + consumer: &impl SubstraitConsumer, + f: &AggregateFunction, + input_schema: &DFSchema, + filter: Option>, + order_by: Option>, + distinct: bool, +) -> datafusion::common::Result> { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&f.function_reference) + else { + return plan_err!( + "Aggregate function not registered: function anchor = {:?}", + f.function_reference + ); + }; + + let fn_name = substrait_fun_name(fn_signature); + let udaf = consumer.get_function_registry().udaf(fn_name); + let udaf = udaf.map_err(|_| { + not_impl_datafusion_err!( + "Aggregate function {} is not supported: function anchor = {:?}", + fn_signature, + f.function_reference + ) + })?; + + let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; + + // Datafusion does not support aggregate functions with no arguments, so + // we inject a dummy argument that does not affect the query, but allows + // us to bypass this limitation. 
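+    // For example, a zero-argument `count()` coming from Substrait is evaluated
+    // as the equivalent `count(1)`.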
+ let args = if udaf.name() == "count" && args.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)] + } else { + args + }; + + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None), + ))) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs new file mode 100644 index 000000000000..5e8d3d93065f --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::types::from_substrait_type_without_names; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{substrait_err, DFSchema}; +use datafusion::logical_expr::{Cast, Expr, TryCast}; +use substrait::proto::expression as substrait_expression; +use substrait::proto::expression::cast::FailureBehavior::ReturnNull; + +pub async fn from_cast( + consumer: &impl SubstraitConsumer, + cast: &substrait_expression::Cast, + input_schema: &DFSchema, +) -> datafusion::common::Result { + match cast.r#type.as_ref() { + Some(output_type) => { + let input_expr = Box::new( + consumer + .consume_expression( + cast.input.as_ref().unwrap().as_ref(), + input_schema, + ) + .await?, + ); + let data_type = from_substrait_type_without_names(consumer, output_type)?; + if cast.failure_behavior() == ReturnNull { + Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) + } else { + Ok(Expr::Cast(Cast::new(input_expr, data_type))) + } + } + None => substrait_err!("Cast expression without output type is not allowed"), + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs b/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs new file mode 100644 index 000000000000..90b5b6418149 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, Column, DFSchema}; +use datafusion::logical_expr::Expr; +use substrait::proto::expression::field_reference::ReferenceType::DirectReference; +use substrait::proto::expression::reference_segment::ReferenceType::StructField; +use substrait::proto::expression::FieldReference; + +pub async fn from_field_reference( + _consumer: &impl SubstraitConsumer, + field_ref: &FieldReference, + input_schema: &DFSchema, +) -> datafusion::common::Result { + from_substrait_field_reference(field_ref, input_schema) +} + +pub(crate) fn from_substrait_field_reference( + field_ref: &FieldReference, + input_schema: &DFSchema, +) -> datafusion::common::Result { + match &field_ref.reference_type { + Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { + Some(StructField(x)) => match &x.child.as_ref() { + Some(_) => not_impl_err!( + "Direct reference StructField with child is not supported" + ), + None => Ok(Expr::Column(Column::from( + input_schema.qualified_field(x.field as usize), + ))), + }, + _ => not_impl_err!( + "Direct reference with types other than StructField is not supported" + ), + }, + _ => not_impl_err!("unsupported field ref type"), + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/function_arguments.rs b/datafusion/substrait/src/logical_plan/consumer/expr/function_arguments.rs new file mode 100644 index 000000000000..0b610b61b1de --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/function_arguments.rs @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
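+
+//! Conversion of Substrait `FunctionArgument` lists into DataFusion `Expr`s.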
+ +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, DFSchema}; +use datafusion::logical_expr::Expr; +use substrait::proto::function_argument::ArgType; +use substrait::proto::FunctionArgument; + +/// Convert Substrait FunctionArguments to DataFusion Exprs +pub async fn from_substrait_func_args( + consumer: &impl SubstraitConsumer, + arguments: &Vec, + input_schema: &DFSchema, +) -> datafusion::common::Result> { + let mut args: Vec = vec![]; + for arg in arguments { + let arg_expr = match &arg.arg_type { + Some(ArgType::Value(e)) => consumer.consume_expression(e, input_schema).await, + _ => not_impl_err!("Function argument non-Value type not supported"), + }; + args.push(arg_expr?); + } + Ok(args) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/if_then.rs b/datafusion/substrait/src/logical_plan/consumer/expr/if_then.rs new file mode 100644 index 000000000000..c4cc6c2fcd24 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/if_then.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
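+
+//! Conversion of Substrait `IfThen` expressions into DataFusion `CASE` (`Expr::Case`) expressions.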
+ +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::DFSchema; +use datafusion::logical_expr::{Case, Expr}; +use substrait::proto::expression::IfThen; + +pub async fn from_if_then( + consumer: &impl SubstraitConsumer, + if_then: &IfThen, + input_schema: &DFSchema, +) -> datafusion::common::Result { + // Parse `ifs` + // If the first element does not have a `then` part, then we can assume it's a base expression + let mut when_then_expr: Vec<(Box, Box)> = vec![]; + let mut expr = None; + for (i, if_expr) in if_then.ifs.iter().enumerate() { + if i == 0 { + // Check if the first element is type base expression + if if_expr.then.is_none() { + expr = Some(Box::new( + consumer + .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) + .await?, + )); + continue; + } + } + when_then_expr.push(( + Box::new( + consumer + .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) + .await?, + ), + Box::new( + consumer + .consume_expression(if_expr.then.as_ref().unwrap(), input_schema) + .await?, + ), + )); + } + // Parse `else` + let else_expr = match &if_then.r#else { + Some(e) => Some(Box::new( + consumer.consume_expression(e, input_schema).await?, + )), + None => None, + }; + Ok(Expr::Case(Case { + expr, + when_then_expr, + else_expr, + })) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs new file mode 100644 index 000000000000..d054e5267554 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs @@ -0,0 +1,547 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
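+
+//! Conversion of Substrait `Literal`s into DataFusion `ScalarValue`s and literal `Expr`s.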
+ +use crate::logical_plan::consumer::types::from_substrait_type; +use crate::logical_plan::consumer::utils::{next_struct_field_name, DEFAULT_TIMEZONE}; +use crate::logical_plan::consumer::SubstraitConsumer; +#[allow(deprecated)] +use crate::variation_const::{ + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, + INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, + TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, + TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::array::{new_empty_array, AsArray, MapArray}; +use datafusion::arrow::buffer::OffsetBuffer; +use datafusion::arrow::datatypes::{Field, IntervalDayTime, IntervalMonthDayNano}; +use datafusion::arrow::temporal_conversions::NANOSECONDS; +use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::common::{ + not_impl_err, plan_err, substrait_datafusion_err, substrait_err, ScalarValue, +}; +use datafusion::logical_expr::Expr; +use std::sync::Arc; +use substrait::proto; +use substrait::proto::expression::literal::user_defined::Val; +use substrait::proto::expression::literal::{ + interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, + LiteralType, +}; +use substrait::proto::expression::Literal; + +pub async fn from_literal( + consumer: &impl SubstraitConsumer, + expr: &Literal, +) -> datafusion::common::Result { + let scalar_value = from_substrait_literal_without_names(consumer, expr)?; + Ok(Expr::Literal(scalar_value, None)) +} + +pub(crate) fn from_substrait_literal_without_names( + consumer: &impl SubstraitConsumer, + lit: &Literal, +) -> datafusion::common::Result { + from_substrait_literal(consumer, lit, &vec![], &mut 0) +} + +pub(crate) fn from_substrait_literal( + consumer: &impl SubstraitConsumer, + lit: &Literal, + dfs_names: &Vec, + name_idx: &mut usize, +) -> datafusion::common::Result { + let scalar_value = match &lit.literal_type { + Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)), + Some(LiteralType::I8(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int8(Some(*n as i8)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt8(Some(*n as u8)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::I16(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int16(Some(*n as i16)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt16(Some(*n as u16)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::I32(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int32(Some(*n)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt32(Some(*n as u32)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::I64(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int64(Some(*n)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt64(Some(*n as u64)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), + Some(LiteralType::Fp64(f)) => 
ScalarValue::Float64(Some(*f)), + Some(LiteralType::Timestamp(t)) => { + // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead + #[allow(deprecated)] + match lit.type_variation_reference { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + ScalarValue::TimestampSecond(Some(*t), None) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + ScalarValue::TimestampMillisecond(Some(*t), None) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { + ScalarValue::TimestampMicrosecond(Some(*t), None) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + ScalarValue::TimestampNanosecond(Some(*t), None) + } + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::PrecisionTimestamp(pt)) => match pt.precision { + 0 => ScalarValue::TimestampSecond(Some(pt.value), None), + 3 => ScalarValue::TimestampMillisecond(Some(pt.value), None), + 6 => ScalarValue::TimestampMicrosecond(Some(pt.value), None), + 9 => ScalarValue::TimestampNanosecond(Some(pt.value), None), + p => { + return not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ); + } + }, + Some(LiteralType::PrecisionTimestampTz(pt)) => match pt.precision { + 0 => ScalarValue::TimestampSecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 3 => ScalarValue::TimestampMillisecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 6 => ScalarValue::TimestampMicrosecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 9 => ScalarValue::TimestampNanosecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + p => { + return not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ); + } + }, + Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), + Some(LiteralType::String(s)) => match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8(Some(s.clone())), + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeUtf8(Some(s.clone())), + VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8View(Some(s.clone())), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::Binary(b)) => match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Binary(Some(b.clone())), + LARGE_CONTAINER_TYPE_VARIATION_REF => { + ScalarValue::LargeBinary(Some(b.clone())) + } + VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::BinaryView(Some(b.clone())), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::FixedBinary(b)) => { + ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())) + } + Some(LiteralType::Decimal(d)) => { + let value: [u8; 16] = d + .value + .clone() + .try_into() + .or(substrait_err!("Failed to parse decimal value"))?; + let p = d.precision.try_into().map_err(|e| { + substrait_datafusion_err!("Failed to parse decimal precision: {e}") + })?; + let s = d.scale.try_into().map_err(|e| { + substrait_datafusion_err!("Failed to parse decimal scale: {e}") + })?; + ScalarValue::Decimal128(Some(i128::from_le_bytes(value)), p, s) + } + Some(LiteralType::List(l)) => { + // Each element should start the name index from the same value, then we increase it + // once at the end + let mut element_name_idx = *name_idx; + let elements = l + .values + .iter() + .map(|el| { + element_name_idx = *name_idx; + from_substrait_literal(consumer, el, dfs_names, &mut element_name_idx) + }) + .collect::>>()?; + *name_idx = element_name_idx; + if 
elements.is_empty() { + return substrait_err!( + "Empty list must be encoded as EmptyList literal type, not List" + ); + } + let element_type = elements[0].data_type(); + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::List( + ScalarValue::new_list_nullable(elements.as_slice(), &element_type), + ), + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(elements.as_slice(), &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::EmptyList(l)) => { + let element_type = from_substrait_type( + consumer, + l.r#type.clone().unwrap().as_ref(), + dfs_names, + name_idx, + )?; + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => { + ScalarValue::List(ScalarValue::new_list_nullable(&[], &element_type)) + } + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(&[], &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::Map(m)) => { + // Each entry should start the name index from the same value, then we increase it + // once at the end + let mut entry_name_idx = *name_idx; + let entries = m + .key_values + .iter() + .map(|kv| { + entry_name_idx = *name_idx; + let key_sv = from_substrait_literal( + consumer, + kv.key.as_ref().unwrap(), + dfs_names, + &mut entry_name_idx, + )?; + let value_sv = from_substrait_literal( + consumer, + kv.value.as_ref().unwrap(), + dfs_names, + &mut entry_name_idx, + )?; + ScalarStructBuilder::new() + .with_scalar(Field::new("key", key_sv.data_type(), false), key_sv) + .with_scalar( + Field::new("value", value_sv.data_type(), true), + value_sv, + ) + .build() + }) + .collect::>>()?; + *name_idx = entry_name_idx; + + if entries.is_empty() { + return substrait_err!( + "Empty map must be encoded as EmptyMap literal type, not Map" + ); + } + + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(Field::new("entries", entries[0].data_type(), false)), + OffsetBuffer::new(vec![0, entries.len() as i32].into()), + ScalarValue::iter_to_array(entries)?.as_struct().to_owned(), + None, + false, + ))) + } + Some(LiteralType::EmptyMap(m)) => { + let key = match &m.key { + Some(k) => Ok(k), + _ => plan_err!("Missing key type for empty map"), + }?; + let value = match &m.value { + Some(v) => Ok(v), + _ => plan_err!("Missing value type for empty map"), + }?; + let key_type = from_substrait_type(consumer, key, dfs_names, name_idx)?; + let value_type = from_substrait_type(consumer, value, dfs_names, name_idx)?; + + // new_empty_array on a MapType creates a too empty array + // We want it to contain an empty struct array to align with an empty MapBuilder one + let entries = Field::new_struct( + "entries", + vec![ + Field::new("key", key_type, false), + Field::new("value", value_type, true), + ], + false, + ); + let struct_array = + new_empty_array(entries.data_type()).as_struct().to_owned(); + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(entries), + OffsetBuffer::new(vec![0, 0].into()), + struct_array, + None, + false, + ))) + } + Some(LiteralType::Struct(s)) => { + let mut builder = ScalarStructBuilder::new(); + for (i, field) in s.fields.iter().enumerate() { + let name = next_struct_field_name(i, dfs_names, name_idx)?; + let sv = from_substrait_literal(consumer, field, dfs_names, name_idx)?; + // We assume everything to be nullable, since Arrow's strict about things matching + 
// and it's hard to match otherwise. + builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); + } + builder.build()? + } + Some(LiteralType::Null(null_type)) => { + let data_type = + from_substrait_type(consumer, null_type, dfs_names, name_idx)?; + ScalarValue::try_from(&data_type)? + } + Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { + days, + seconds, + subseconds, + precision_mode, + })) => { + use interval_day_to_second::PrecisionMode; + // DF only supports millisecond precision, so for any more granular type we lose precision + let milliseconds = match precision_mode { + Some(PrecisionMode::Microseconds(ms)) => ms / 1000, + None => + if *subseconds != 0 { + return substrait_err!("Cannot set subseconds field of IntervalDayToSecond without setting precision"); + } else { + 0_i32 + } + Some(PrecisionMode::Precision(0)) => *subseconds as i32 * 1000, + Some(PrecisionMode::Precision(3)) => *subseconds as i32, + Some(PrecisionMode::Precision(6)) => (subseconds / 1000) as i32, + Some(PrecisionMode::Precision(9)) => (subseconds / 1000 / 1000) as i32, + _ => { + return not_impl_err!( + "Unsupported Substrait interval day to second precision mode: {precision_mode:?}") + } + }; + + ScalarValue::new_interval_dt(*days, (seconds * 1000) + milliseconds) + } + Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { + ScalarValue::new_interval_ym(*years, *months) + } + Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month, + interval_day_to_second, + })) => match (interval_year_to_month, interval_day_to_second) { + ( + Some(IntervalYearToMonth { years, months }), + Some(IntervalDayToSecond { + days, + seconds, + subseconds, + precision_mode: + Some(interval_day_to_second::PrecisionMode::Precision(p)), + }), + ) => { + if *p < 0 || *p > 9 { + return plan_err!( + "Unsupported Substrait interval day to second precision: {}", + p + ); + } + let nanos = *subseconds * i64::pow(10, (9 - p) as u32); + ScalarValue::new_interval_mdn( + *years * 12 + months, + *days, + *seconds as i64 * NANOSECONDS + nanos, + ) + } + _ => return plan_err!("Substrait compound interval missing components"), + }, + Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), + Some(LiteralType::UserDefined(user_defined)) => { + if let Ok(value) = consumer.consume_user_defined_literal(user_defined) { + return Ok(value); + } + + // TODO: remove the code below once the producer has been updated + + // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed + let interval_month_day_nano = + |user_defined: &proto::expression::literal::UserDefined| -> datafusion::common::Result { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval month day nano value is empty"); + }; + let value_slice: [u8; 16] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval month day nano value" + ) + })?; + let months = + i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); + let days = i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); + let nanoseconds = + i64::from_le_bytes(value_slice[8..16].try_into().unwrap()); + Ok(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months, + days, + nanoseconds, + }, + ))) + }; + + if let Some(name) = consumer + .get_extensions() + .types + .get(&user_defined.type_reference) + { + match name.as_ref() { + // Kept for backwards compatibility - producers 
should use IntervalCompound instead + #[allow(deprecated)] + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type with ref {} and name {}", + user_defined.type_reference, + name + ) + } + } + } else { + #[allow(deprecated)] + match user_defined.type_reference { + // Kept for backwards compatibility, producers should useIntervalYearToMonth instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval year month value is empty"); + }; + let value_slice: [u8; 4] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval year month value" + ) + })?; + ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes( + value_slice, + ))) + } + // Kept for backwards compatibility, producers should useIntervalDayToSecond instead + INTERVAL_DAY_TIME_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval day time value is empty"); + }; + let value_slice: [u8; 8] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval day time value" + ) + })?; + let days = + i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); + let milliseconds = + i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days, + milliseconds, + })) + } + // Kept for backwards compatibility, producers should useIntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type literal with ref {}", + user_defined.type_reference + ) + } + } + } + } + _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), + }; + + Ok(scalar_value) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::utils::tests::test_consumer; + + #[test] + fn interval_compound_different_precision() -> datafusion::common::Result<()> { + // DF producer (and thus roundtrip) always uses precision = 9, + // this test exists to test with some other value. + let substrait = Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: 1, + months: 2, + }), + interval_day_to_second: Some(IntervalDayToSecond { + days: 3, + seconds: 4, + subseconds: 5, + precision_mode: Some( + interval_day_to_second::PrecisionMode::Precision(6), + ), + }), + })), + }; + + let consumer = test_consumer(); + assert_eq!( + from_substrait_literal_without_names(&consumer, &substrait)?, + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 14, + days: 3, + nanoseconds: 4_000_005_000 + })) + ); + + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs new file mode 100644 index 000000000000..b3ec2e37811f --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs @@ -0,0 +1,264 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_function; +mod cast; +mod field_reference; +mod function_arguments; +mod if_then; +mod literal; +mod scalar_function; +mod singular_or_list; +mod subquery; +mod window_function; + +pub use aggregate_function::*; +pub use cast::*; +pub use field_reference::*; +pub use function_arguments::*; +pub use if_then::*; +pub use literal::*; +pub use scalar_function::*; +pub use singular_or_list::*; +pub use subquery::*; +pub use window_function::*; + +use crate::extensions::Extensions; +use crate::logical_plan::consumer::utils::rename_field; +use crate::logical_plan::consumer::{ + from_substrait_named_struct, DefaultSubstraitConsumer, SubstraitConsumer, +}; +use datafusion::arrow::datatypes::Field; +use datafusion::common::{not_impl_err, plan_err, substrait_err, DFSchema, DFSchemaRef}; +use datafusion::execution::SessionState; +use datafusion::logical_expr::{Expr, ExprSchemable}; +use substrait::proto::expression::RexType; +use substrait::proto::expression_reference::ExprType; +use substrait::proto::{Expression, ExtendedExpression}; + +/// Convert Substrait Rex to DataFusion Expr +pub async fn from_substrait_rex( + consumer: &impl SubstraitConsumer, + expression: &Expression, + input_schema: &DFSchema, +) -> datafusion::common::Result { + match &expression.rex_type { + Some(t) => match t { + RexType::Literal(expr) => consumer.consume_literal(expr).await, + RexType::Selection(expr) => { + consumer.consume_field_reference(expr, input_schema).await + } + RexType::ScalarFunction(expr) => { + consumer.consume_scalar_function(expr, input_schema).await + } + RexType::WindowFunction(expr) => { + consumer.consume_window_function(expr, input_schema).await + } + RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await, + RexType::SwitchExpression(expr) => { + consumer.consume_switch(expr, input_schema).await + } + RexType::SingularOrList(expr) => { + consumer.consume_singular_or_list(expr, input_schema).await + } + + RexType::MultiOrList(expr) => { + consumer.consume_multi_or_list(expr, input_schema).await + } + + RexType::Cast(expr) => { + consumer.consume_cast(expr.as_ref(), input_schema).await + } + + RexType::Subquery(expr) => { + consumer.consume_subquery(expr.as_ref(), input_schema).await + } + RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await, + RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await, + RexType::DynamicParameter(expr) => { + consumer.consume_dynamic_parameter(expr, input_schema).await + } + }, + None => substrait_err!("Expression must set rex_type: {:?}", expression), + } +} + +/// Convert Substrait ExtendedExpression to ExprContainer +/// +/// A Substrait ExtendedExpression message contains one or more expressions, +/// with names for the outputs, and an input schema. These pieces are all included +/// in the ExprContainer. +/// +/// This is a top-level message and can be used to send expressions (not plans) +/// between systems. 
This is often useful for scenarios like pushdown where filter +/// expressions need to be sent to remote systems. +pub async fn from_substrait_extended_expr( + state: &SessionState, + extended_expr: &ExtendedExpression, +) -> datafusion::common::Result { + // Register function extension + let extensions = Extensions::try_from(&extended_expr.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } + + let consumer = DefaultSubstraitConsumer { + extensions: &extensions, + state, + }; + + let input_schema = DFSchemaRef::new(match &extended_expr.base_schema { + Some(base_schema) => from_substrait_named_struct(&consumer, base_schema), + None => { + plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message") + } + }?); + + // Parse expressions + let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len()); + for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() { + let scalar_expr = match &substrait_expr.expr_type { + Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr), + Some(ExprType::Measure(_)) => { + not_impl_err!("Measure expressions are not yet supported") + } + None => { + plan_err!("required property `expr_type` missing from Substrait ExpressionReference message") + } + }?; + let expr = consumer + .consume_expression(scalar_expr, &input_schema) + .await?; + let (output_type, expected_nullability) = + expr.data_type_and_nullable(&input_schema)?; + let output_field = Field::new("", output_type, expected_nullability); + let mut names_idx = 0; + let output_field = rename_field( + &output_field, + &substrait_expr.output_names, + expr_idx, + &mut names_idx, + /*rename_self=*/ true, + )?; + exprs.push((expr, output_field)); + } + + Ok(ExprContainer { + input_schema, + exprs, + }) +} + +/// An ExprContainer is a container for a collection of expressions with a common input schema +/// +/// In addition, each expression is associated with a field, which defines the +/// expression's output. The data type and nullability of the field are calculated from the +/// expression and the input schema. However the names of the field (and its nested fields) are +/// derived from the Substrait message. 
+pub struct ExprContainer { + /// The input schema for the expressions + pub input_schema: DFSchemaRef, + /// The expressions + /// + /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output + pub exprs: Vec<(Expr, Field)>, +} + +/// Convert Substrait Expressions to DataFusion Exprs +pub async fn from_substrait_rex_vec( + consumer: &impl SubstraitConsumer, + exprs: &Vec, + input_schema: &DFSchema, +) -> datafusion::common::Result> { + let mut expressions: Vec = vec![]; + for expr in exprs { + let expression = consumer.consume_expression(expr, input_schema).await?; + expressions.push(expression); + } + Ok(expressions) +} + +#[cfg(test)] +mod tests { + use crate::extensions::Extensions; + use crate::logical_plan::consumer::utils::tests::test_consumer; + use crate::logical_plan::consumer::*; + use datafusion::common::DFSchema; + use datafusion::logical_expr::Expr; + use substrait::proto::expression::window_function::BoundsType; + use substrait::proto::expression::RexType; + use substrait::proto::Expression; + + #[tokio::test] + async fn window_function_with_range_unit_and_no_order_by( + ) -> datafusion::common::Result<()> { + let substrait = Expression { + rex_type: Some(RexType::WindowFunction( + substrait::proto::expression::WindowFunction { + function_reference: 0, + bounds_type: BoundsType::Range as i32, + sorts: vec![], + ..Default::default() + }, + )), + }; + + let mut consumer = test_consumer(); + + // Just registering a single function (index 0) so that the plan + // does not throw a "function not found" error. + let mut extensions = Extensions::default(); + extensions.register_function("count".to_string()); + consumer.extensions = &extensions; + + match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { + Expr::WindowFunction(window_function) => { + assert_eq!(window_function.params.order_by.len(), 1) + } + _ => panic!("expr was not a WindowFunction"), + }; + + Ok(()) + } + + #[tokio::test] + async fn window_function_with_count() -> datafusion::common::Result<()> { + let substrait = Expression { + rex_type: Some(RexType::WindowFunction( + substrait::proto::expression::WindowFunction { + function_reference: 0, + ..Default::default() + }, + )), + }; + + let mut consumer = test_consumer(); + + let mut extensions = Extensions::default(); + extensions.register_function("count".to_string()); + consumer.extensions = &extensions; + + match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { + Expr::WindowFunction(window_function) => { + assert_eq!(window_function.params.args.len(), 1) + } + _ => panic!("expr was not a WindowFunction"), + }; + + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs new file mode 100644 index 000000000000..7797c935211f --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs @@ -0,0 +1,372 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{from_substrait_func_args, SubstraitConsumer}; +use datafusion::common::Result; +use datafusion::common::{ + not_impl_err, plan_err, substrait_err, DFSchema, DataFusionError, ScalarValue, +}; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::{expr, BinaryExpr, Expr, Like, Operator}; +use std::vec::Drain; +use substrait::proto::expression::ScalarFunction; +use substrait::proto::function_argument::ArgType; + +pub async fn from_scalar_function( + consumer: &impl SubstraitConsumer, + f: &ScalarFunction, + input_schema: &DFSchema, +) -> Result { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&f.function_reference) + else { + return plan_err!( + "Scalar function not found: function reference = {:?}", + f.function_reference + ); + }; + let fn_name = substrait_fun_name(fn_signature); + let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; + + // try to first match the requested function into registered udfs, then built-in ops + // and finally built-in expressions + if let Ok(func) = consumer.get_function_registry().udf(fn_name) { + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + func.to_owned(), + args, + ))) + } else if let Some(op) = name_to_op(fn_name) { + if args.len() < 2 { + return not_impl_err!( + "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", + f.arguments.len() + ); + } + // In those cases we build a balanced tree of BinaryExprs + arg_list_to_binary_op_tree(op, args) + } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { + builder.build(consumer, f, input_schema).await + } else { + not_impl_err!("Unsupported function name: {fn_name:?}") + } +} + +pub fn substrait_fun_name(name: &str) -> &str { + let name = match name.rsplit_once(':') { + // Since 0.32.0, Substrait requires the function names to be in a compound format + // https://substrait.io/extensions/#function-signature-compound-names + // for example, `add:i8_i8`. + // On the consumer side, we don't really care about the signature though, just the name. 
+ Some((name, _)) => name, + None => name, + }; + name +} + +pub fn name_to_op(name: &str) -> Option { + match name { + "equal" => Some(Operator::Eq), + "not_equal" => Some(Operator::NotEq), + "lt" => Some(Operator::Lt), + "lte" => Some(Operator::LtEq), + "gt" => Some(Operator::Gt), + "gte" => Some(Operator::GtEq), + "add" => Some(Operator::Plus), + "subtract" => Some(Operator::Minus), + "multiply" => Some(Operator::Multiply), + "divide" => Some(Operator::Divide), + "mod" => Some(Operator::Modulo), + "modulus" => Some(Operator::Modulo), + "and" => Some(Operator::And), + "or" => Some(Operator::Or), + "is_distinct_from" => Some(Operator::IsDistinctFrom), + "is_not_distinct_from" => Some(Operator::IsNotDistinctFrom), + "regex_match" => Some(Operator::RegexMatch), + "regex_imatch" => Some(Operator::RegexIMatch), + "regex_not_match" => Some(Operator::RegexNotMatch), + "regex_not_imatch" => Some(Operator::RegexNotIMatch), + "bitwise_and" => Some(Operator::BitwiseAnd), + "bitwise_or" => Some(Operator::BitwiseOr), + "str_concat" => Some(Operator::StringConcat), + "at_arrow" => Some(Operator::AtArrow), + "arrow_at" => Some(Operator::ArrowAt), + "bitwise_xor" => Some(Operator::BitwiseXor), + "bitwise_shift_right" => Some(Operator::BitwiseShiftRight), + "bitwise_shift_left" => Some(Operator::BitwiseShiftLeft), + _ => None, + } +} + +/// Build a balanced tree of binary operations from a binary operator and a list of arguments. +/// +/// For example, `OR` `(a, b, c, d, e)` will be converted to: `OR(OR(a, OR(b, c)), OR(d, e))`. +/// +/// `args` must not be empty. +fn arg_list_to_binary_op_tree(op: Operator, mut args: Vec) -> Result { + let n_args = args.len(); + let mut drained_args = args.drain(..); + arg_list_to_binary_op_tree_inner(op, &mut drained_args, n_args) +} + +/// Helper function for [`arg_list_to_binary_op_tree`] implementation +/// +/// `take_len` represents the number of elements to take from `args` before returning. +/// We use `take_len` to avoid recursively building a `Take>>` type. +fn arg_list_to_binary_op_tree_inner( + op: Operator, + args: &mut Drain, + take_len: usize, +) -> Result { + if take_len == 1 { + return args.next().ok_or_else(|| { + DataFusionError::Substrait( + "Expected one more available element in iterator, found none".to_string(), + ) + }); + } else if take_len == 0 { + return substrait_err!("Cannot build binary operation tree with 0 arguments"); + } + // Cut argument list in 2 balanced parts + let left_take = take_len / 2; + let right_take = take_len - left_take; + let left = arg_list_to_binary_op_tree_inner(op, args, left_take)?; + let right = arg_list_to_binary_op_tree_inner(op, args, right_take)?; + Ok(Expr::BinaryExpr(BinaryExpr { + left: Box::new(left), + op, + right: Box::new(right), + })) +} + +/// Build [`Expr`] from its name and required inputs. 
+struct BuiltinExprBuilder { + expr_name: String, +} + +impl BuiltinExprBuilder { + pub fn try_from_name(name: &str) -> Option { + match name { + "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" | "negative" | "negate" => Some(Self { + expr_name: name.to_string(), + }), + _ => None, + } + } + + pub async fn build( + self, + consumer: &impl SubstraitConsumer, + f: &ScalarFunction, + input_schema: &DFSchema, + ) -> Result { + match self.expr_name.as_str() { + "like" => Self::build_like_expr(consumer, false, f, input_schema).await, + "ilike" => Self::build_like_expr(consumer, true, f, input_schema).await, + "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" => { + Self::build_unary_expr(consumer, &self.expr_name, f, input_schema).await + } + _ => { + not_impl_err!("Unsupported builtin expression: {}", self.expr_name) + } + } + } + + async fn build_unary_expr( + consumer: &impl SubstraitConsumer, + fn_name: &str, + f: &ScalarFunction, + input_schema: &DFSchema, + ) -> Result { + if f.arguments.len() != 1 { + return substrait_err!("Expect one argument for {fn_name} expr"); + } + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return substrait_err!("Invalid arguments type for {fn_name} expr"); + }; + let arg = consumer + .consume_expression(expr_substrait, input_schema) + .await?; + let arg = Box::new(arg); + + let expr = match fn_name { + "not" => Expr::Not(arg), + "negative" | "negate" => Expr::Negative(arg), + "is_null" => Expr::IsNull(arg), + "is_not_null" => Expr::IsNotNull(arg), + "is_true" => Expr::IsTrue(arg), + "is_false" => Expr::IsFalse(arg), + "is_not_true" => Expr::IsNotTrue(arg), + "is_not_false" => Expr::IsNotFalse(arg), + "is_unknown" => Expr::IsUnknown(arg), + "is_not_unknown" => Expr::IsNotUnknown(arg), + _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), + }; + + Ok(expr) + } + + async fn build_like_expr( + consumer: &impl SubstraitConsumer, + case_insensitive: bool, + f: &ScalarFunction, + input_schema: &DFSchema, + ) -> Result { + let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; + if f.arguments.len() != 2 && f.arguments.len() != 3 { + return substrait_err!("Expect two or three arguments for `{fn_name}` expr"); + } + + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let expr = consumer + .consume_expression(expr_substrait, input_schema) + .await?; + let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let pattern = consumer + .consume_expression(pattern_substrait, input_schema) + .await?; + + // Default case: escape character is Literal(Utf8(None)) + let escape_char = if f.arguments.len() == 3 { + let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type + else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + + let escape_char_expr = consumer + .consume_expression(escape_char_substrait, input_schema) + .await?; + + match escape_char_expr { + Expr::Literal(ScalarValue::Utf8(escape_char_string), _) => { + // Convert Option to Option + escape_char_string.and_then(|s| s.chars().next()) + } + _ => { + return substrait_err!( + "Expect Utf8 literal for escape 
char, but found {escape_char_expr:?}" + ) + } + } + } else { + None + }; + + Ok(Expr::Like(Like { + negated: false, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char, + case_insensitive, + })) + } +} + +#[cfg(test)] +mod tests { + use super::arg_list_to_binary_op_tree; + use crate::extensions::Extensions; + use crate::logical_plan::consumer::tests::TEST_SESSION_STATE; + use crate::logical_plan::consumer::{DefaultSubstraitConsumer, SubstraitConsumer}; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::{DFSchema, Result, ScalarValue}; + use datafusion::logical_expr::{Expr, Operator}; + use insta::assert_snapshot; + use substrait::proto::expression::literal::LiteralType; + use substrait::proto::expression::{Literal, RexType, ScalarFunction}; + use substrait::proto::function_argument::ArgType; + use substrait::proto::{Expression, FunctionArgument}; + + /// Test that large argument lists for binary operations do not crash the consumer + #[tokio::test] + async fn test_binary_op_large_argument_list() -> Result<()> { + // Build substrait extensions (we are using only one function) + let mut extensions = Extensions::default(); + extensions.functions.insert(0, String::from("or:bool_bool")); + // Build substrait consumer + let consumer = DefaultSubstraitConsumer::new(&extensions, &TEST_SESSION_STATE); + + // Build arguments for the function call, this is basically an OR(true, true, ..., true) + let arg = FunctionArgument { + arg_type: Some(ArgType::Value(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::Boolean(true)), + })), + })), + }; + let arguments = vec![arg; 50000]; + let func = ScalarFunction { + function_reference: 0, + arguments, + ..Default::default() + }; + // Trivial input schema + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let df_schema = DFSchema::try_from(schema).unwrap(); + + // Consume the expression and ensure we don't crash + let _ = consumer.consume_scalar_function(&func, &df_schema).await?; + Ok(()) + } + + fn int64_literals(integers: &[i64]) -> Vec { + integers + .iter() + .map(|value| Expr::Literal(ScalarValue::Int64(Some(*value)), None)) + .collect() + } + + #[test] + fn arg_list_to_binary_op_tree_1_arg() -> Result<()> { + let expr = arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1]))?; + assert_snapshot!(expr.to_string(), @"Int64(1)"); + Ok(()) + } + + #[test] + fn arg_list_to_binary_op_tree_2_args() -> Result<()> { + let expr = arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1, 2]))?; + assert_snapshot!(expr.to_string(), @"Int64(1) OR Int64(2)"); + Ok(()) + } + + #[test] + fn arg_list_to_binary_op_tree_3_args() -> Result<()> { + let expr = arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1, 2, 3]))?; + assert_snapshot!(expr.to_string(), @"Int64(1) OR Int64(2) OR Int64(3)"); + Ok(()) + } + + #[test] + fn arg_list_to_binary_op_tree_4_args() -> Result<()> { + let expr = + arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1, 2, 3, 4]))?; + assert_snapshot!(expr.to_string(), @"Int64(1) OR Int64(2) OR Int64(3) OR Int64(4)"); + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/singular_or_list.rs b/datafusion/substrait/src/logical_plan/consumer/expr/singular_or_list.rs new file mode 100644 index 000000000000..6d44ebcce590 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/singular_or_list.rs @@ -0,0 +1,40 @@ +// 
Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{from_substrait_rex_vec, SubstraitConsumer}; +use datafusion::common::DFSchema; +use datafusion::logical_expr::expr::InList; +use datafusion::logical_expr::Expr; +use substrait::proto::expression::SingularOrList; + +pub async fn from_singular_or_list( + consumer: &impl SubstraitConsumer, + expr: &SingularOrList, + input_schema: &DFSchema, +) -> datafusion::common::Result { + let substrait_expr = expr.value.as_ref().unwrap(); + let substrait_list = expr.options.as_ref(); + Ok(Expr::InList(InList { + expr: Box::new( + consumer + .consume_expression(substrait_expr, input_schema) + .await?, + ), + list: from_substrait_rex_vec(consumer, substrait_list, input_schema).await?, + negated: false, + })) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs new file mode 100644 index 000000000000..f7e4c2bb0fbd --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
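As a rough sketch (not part of the diff; column `x` is hypothetical), the `SingularOrList` conversion above yields a plain, non-negated IN list:

use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{col, lit, Expr};

// x IN (1, 2, 3)
fn in_list_shape() -> Expr {
    Expr::InList(InList {
        expr: Box::new(col("x")),
        list: vec![lit(1_i64), lit(2_i64), lit(3_i64)],
        negated: false,
    })
}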
+ +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{substrait_err, DFSchema, Spans}; +use datafusion::logical_expr::expr::{Exists, InSubquery}; +use datafusion::logical_expr::{Expr, Subquery}; +use std::sync::Arc; +use substrait::proto::expression as substrait_expression; +use substrait::proto::expression::subquery::set_predicate::PredicateOp; +use substrait::proto::expression::subquery::SubqueryType; + +pub async fn from_subquery( + consumer: &impl SubstraitConsumer, + subquery: &substrait_expression::Subquery, + input_schema: &DFSchema, +) -> datafusion::common::Result { + match &subquery.subquery_type { + Some(subquery_type) => match subquery_type { + SubqueryType::InPredicate(in_predicate) => { + if in_predicate.needles.len() != 1 { + substrait_err!("InPredicate Subquery type must have exactly one Needle expression") + } else { + let needle_expr = &in_predicate.needles[0]; + let haystack_expr = &in_predicate.haystack; + if let Some(haystack_expr) = haystack_expr { + let haystack_expr = consumer.consume_rel(haystack_expr).await?; + let outer_refs = haystack_expr.all_out_ref_exprs(); + Ok(Expr::InSubquery(InSubquery { + expr: Box::new( + consumer + .consume_expression(needle_expr, input_schema) + .await?, + ), + subquery: Subquery { + subquery: Arc::new(haystack_expr), + outer_ref_columns: outer_refs, + spans: Spans::new(), + }, + negated: false, + })) + } else { + substrait_err!( + "InPredicate Subquery type must have a Haystack expression" + ) + } + } + } + SubqueryType::Scalar(query) => { + let plan = consumer + .consume_rel(&(query.input.clone()).unwrap_or_default()) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Spans::new(), + })) + } + SubqueryType::SetPredicate(predicate) => { + match predicate.predicate_op() { + // exist + PredicateOp::Exists => { + let relation = &predicate.tuples; + let plan = consumer + .consume_rel(&relation.clone().unwrap_or_default()) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::Exists(Exists::new( + Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Spans::new(), + }, + false, + ))) + } + other_type => substrait_err!( + "unimplemented type {:?} for set predicate", + other_type + ), + } + } + other_type => { + substrait_err!("Subquery type {:?} not implemented", other_type) + } + }, + None => { + substrait_err!("Subquery expression without SubqueryType is not allowed") + } + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs new file mode 100644 index 000000000000..80b643a547ee --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs @@ -0,0 +1,163 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{ + from_substrait_func_args, from_substrait_rex_vec, from_substrait_sorts, + substrait_fun_name, SubstraitConsumer, +}; +use datafusion::common::{ + not_impl_err, plan_datafusion_err, plan_err, substrait_err, DFSchema, ScalarValue, +}; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::expr::WindowFunctionParams; +use datafusion::logical_expr::{ + expr, Expr, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, +}; +use substrait::proto::expression::window_function::{Bound, BoundsType}; +use substrait::proto::expression::WindowFunction; +use substrait::proto::expression::{ + window_function::bound as SubstraitBound, window_function::bound::Kind as BoundKind, +}; + +pub async fn from_window_function( + consumer: &impl SubstraitConsumer, + window: &WindowFunction, + input_schema: &DFSchema, +) -> datafusion::common::Result { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&window.function_reference) + else { + return plan_err!( + "Window function not found: function reference = {:?}", + window.function_reference + ); + }; + let fn_name = substrait_fun_name(fn_signature); + + // check udwf first, then udaf, then built-in window and aggregate functions + let fun = if let Ok(udwf) = consumer.get_function_registry().udwf(fn_name) { + Ok(WindowFunctionDefinition::WindowUDF(udwf)) + } else if let Ok(udaf) = consumer.get_function_registry().udaf(fn_name) { + Ok(WindowFunctionDefinition::AggregateUDF(udaf)) + } else { + not_impl_err!( + "Window function {} is not supported: function anchor = {:?}", + fn_name, + window.function_reference + ) + }?; + + let mut order_by = + from_substrait_sorts(consumer, &window.sorts, input_schema).await?; + + let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| { + plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) + })? { + BoundsType::Rows => WindowFrameUnits::Rows, + BoundsType::Range => WindowFrameUnits::Range, + BoundsType::Unspecified => { + // If the plan does not specify the bounds type, then we use a simple logic to determine the units + // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary + // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row + if order_by.is_empty() { + WindowFrameUnits::Rows + } else { + WindowFrameUnits::Range + } + } + }; + let window_frame = datafusion::logical_expr::WindowFrame::new_bounds( + bound_units, + from_substrait_bound(&window.lower_bound, true)?, + from_substrait_bound(&window.upper_bound, false)?, + ); + + window_frame.regularize_order_bys(&mut order_by)?; + + // Datafusion does not support aggregate functions with no arguments, so + // we inject a dummy argument that does not affect the query, but allows + // us to bypass this limitation. 
+ let args = if fun.name() == "count" && window.arguments.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)] + } else { + from_substrait_func_args(consumer, &window.arguments, input_schema).await? + }; + + Ok(Expr::from(expr::WindowFunction { + fun, + params: WindowFunctionParams { + args, + partition_by: from_substrait_rex_vec( + consumer, + &window.partitions, + input_schema, + ) + .await?, + order_by, + window_frame, + null_treatment: None, + }, + })) +} + +fn from_substrait_bound( + bound: &Option, + is_lower: bool, +) -> datafusion::common::Result { + match bound { + Some(b) => match &b.kind { + Some(k) => match k { + BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => { + Ok(WindowFrameBound::CurrentRow) + } + BoundKind::Preceding(SubstraitBound::Preceding { offset }) => { + if *offset <= 0 { + return plan_err!("Preceding bound must be positive"); + } + Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } + BoundKind::Following(SubstraitBound::Following { offset }) => { + if *offset <= 0 { + return plan_err!("Following bound must be positive"); + } + Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } + BoundKind::Unbounded(SubstraitBound::Unbounded {}) => { + if is_lower { + Ok(WindowFrameBound::Preceding(ScalarValue::Null)) + } else { + Ok(WindowFrameBound::Following(ScalarValue::Null)) + } + } + }, + None => substrait_err!("WindowFunction missing Substrait Bound kind"), + }, + None => { + if is_lower { + Ok(WindowFrameBound::Preceding(ScalarValue::Null)) + } else { + Ok(WindowFrameBound::Following(ScalarValue::Null)) + } + } + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/mod.rs b/datafusion/substrait/src/logical_plan/consumer/mod.rs new file mode 100644 index 000000000000..0e01d6ded6e4 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/mod.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod expr; +mod plan; +mod rel; +mod substrait_consumer; +mod types; +mod utils; + +pub use expr::*; +pub use plan::*; +pub use rel::*; +pub use substrait_consumer::*; +pub use types::*; +pub use utils::*; diff --git a/datafusion/substrait/src/logical_plan/consumer/plan.rs b/datafusion/substrait/src/logical_plan/consumer/plan.rs new file mode 100644 index 000000000000..f994f792a17e --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/plan.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::utils::{make_renamed_schema, rename_expressions}; +use super::{DefaultSubstraitConsumer, SubstraitConsumer}; +use crate::extensions::Extensions; +use datafusion::common::{not_impl_err, plan_err}; +use datafusion::execution::SessionState; +use datafusion::logical_expr::{col, Aggregate, LogicalPlan, Projection}; +use std::sync::Arc; +use substrait::proto::{plan_rel, Plan}; + +/// Convert Substrait Plan to DataFusion LogicalPlan +pub async fn from_substrait_plan( + state: &SessionState, + plan: &Plan, +) -> datafusion::common::Result { + // Register function extension + let extensions = Extensions::try_from(&plan.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } + + let consumer = DefaultSubstraitConsumer { + extensions: &extensions, + state, + }; + from_substrait_plan_with_consumer(&consumer, plan).await +} + +/// Convert Substrait Plan to DataFusion LogicalPlan using the given consumer +pub async fn from_substrait_plan_with_consumer( + consumer: &impl SubstraitConsumer, + plan: &Plan, +) -> datafusion::common::Result { + match plan.relations.len() { + 1 => { + match plan.relations[0].rel_type.as_ref() { + Some(rt) => match rt { + plan_rel::RelType::Rel(rel) => Ok(consumer.consume_rel(rel).await?), + plan_rel::RelType::Root(root) => { + let plan = consumer.consume_rel(root.input.as_ref().unwrap()).await?; + if root.names.is_empty() { + // Backwards compatibility for plans missing names + return Ok(plan); + } + let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?; + if renamed_schema.has_equivalent_names_and_types(plan.schema()).is_ok() { + // Nothing to do if the schema is already equivalent + return Ok(plan); + } + match plan { + // If the last node of the plan produces expressions, bake the renames into those expressions. + // This isn't necessary for correctness, but helps with roundtrip tests. + LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)), + LogicalPlan::Aggregate(a) => { + let (group_fields, expr_fields) = renamed_schema.fields().split_at(a.group_expr.len()); + let new_group_exprs = rename_expressions(a.group_expr, a.input.schema(), group_fields)?; + let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), expr_fields)?; + Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?)) + }, + // There are probably more plans where we could bake things in, can add them later as needed. + // Otherwise, add a new Project to handle the renaming. 
+ _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?)) + } + } + }, + None => plan_err!("Cannot parse plan relation: None") + } + }, + _ => not_impl_err!( + "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}", + plan.relations.len() + ) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs new file mode 100644 index 000000000000..9421bb17c162 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{from_substrait_agg_func, from_substrait_sorts}; +use crate::logical_plan::consumer::{NameTracker, SubstraitConsumer}; +use datafusion::common::{not_impl_err, DFSchemaRef}; +use datafusion::logical_expr::{Expr, GroupingSet, LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::aggregate_function::AggregationInvocation; +use substrait::proto::aggregate_rel::Grouping; +use substrait::proto::AggregateRel; + +pub async fn from_aggregate_rel( + consumer: &impl SubstraitConsumer, + agg: &AggregateRel, +) -> datafusion::common::Result { + if let Some(input) = agg.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + let mut ref_group_exprs = vec![]; + + for e in &agg.grouping_expressions { + let x = consumer.consume_expression(e, input.schema()).await?; + ref_group_exprs.push(x); + } + + let mut group_exprs = vec![]; + let mut aggr_exprs = vec![]; + + match agg.groupings.len() { + 1 => { + group_exprs.extend_from_slice( + &from_substrait_grouping( + consumer, + &agg.groupings[0], + &ref_group_exprs, + input.schema(), + ) + .await?, + ); + } + _ => { + let mut grouping_sets = vec![]; + for grouping in &agg.groupings { + let grouping_set = from_substrait_grouping( + consumer, + grouping, + &ref_group_exprs, + input.schema(), + ) + .await?; + grouping_sets.push(grouping_set); + } + // Single-element grouping expression of type Expr::GroupingSet. + // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when + // parsed by the producer and consumer, since Substrait does not have a type dedicated + // to ROLLUP. Only vector of Groupings (grouping sets) is available. 
+ group_exprs + .push(Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets))); + } + }; + + for m in &agg.measures { + let filter = match &m.filter { + Some(fil) => Some(Box::new( + consumer.consume_expression(fil, input.schema()).await?, + )), + None => None, + }; + let agg_func = match &m.measure { + Some(f) => { + let distinct = match f.invocation { + _ if f.invocation == AggregationInvocation::Distinct as i32 => { + true + } + _ if f.invocation == AggregationInvocation::All as i32 => false, + _ => false, + }; + let order_by = if !f.sorts.is_empty() { + Some( + from_substrait_sorts(consumer, &f.sorts, input.schema()) + .await?, + ) + } else { + None + }; + + from_substrait_agg_func( + consumer, + f, + input.schema(), + filter, + order_by, + distinct, + ) + .await + } + None => { + not_impl_err!("Aggregate without aggregate function is not supported") + } + }; + aggr_exprs.push(agg_func?.as_ref().clone()); + } + + // Ensure that all expressions have a unique name + let mut name_tracker = NameTracker::new(); + let group_exprs = group_exprs + .iter() + .map(|e| name_tracker.get_uniquely_named_expr(e.clone())) + .collect::, _>>()?; + + input.aggregate(group_exprs, aggr_exprs)?.build() + } else { + not_impl_err!("Aggregate without an input is not valid") + } +} + +#[allow(deprecated)] +async fn from_substrait_grouping( + consumer: &impl SubstraitConsumer, + grouping: &Grouping, + expressions: &[Expr], + input_schema: &DFSchemaRef, +) -> datafusion::common::Result> { + let mut group_exprs = vec![]; + if !grouping.grouping_expressions.is_empty() { + for e in &grouping.grouping_expressions { + let expr = consumer.consume_expression(e, input_schema).await?; + group_exprs.push(expr); + } + return Ok(group_exprs); + } + for idx in &grouping.expression_references { + let e = &expressions[*idx as usize]; + group_exprs.push(e.clone()); + } + Ok(group_exprs) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/cross_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/cross_rel.rs new file mode 100644 index 000000000000..a91366e47742 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/cross_rel.rs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::logical_plan::consumer::utils::requalify_sides_if_needed; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::CrossRel; + +pub async fn from_cross_rel( + consumer: &impl SubstraitConsumer, + cross: &CrossRel, +) -> datafusion::common::Result { + let left = LogicalPlanBuilder::from( + consumer.consume_rel(cross.left.as_ref().unwrap()).await?, + ); + let right = LogicalPlanBuilder::from( + consumer.consume_rel(cross.right.as_ref().unwrap()).await?, + ); + let (left, right) = requalify_sides_if_needed(left, right)?; + left.cross_join(right.build()?)?.build() +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs new file mode 100644 index 000000000000..d326fff44bbb --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
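+
+// Illustrative sketch of the mapping implemented below (not exhaustive):
+//   ScatterByFields([field a]) with partition_count = 4  =>  Partitioning::Hash([a], 4)
+//   RoundRobin(..)             with partition_count = 4  =>  Partitioning::RoundRobinBatch(4)
+// SingleTarget, MultiTarget and Broadcast exchanges are rejected as not implemented.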
+ +use crate::logical_plan::consumer::from_substrait_field_reference; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, substrait_err}; +use datafusion::logical_expr::{LogicalPlan, Partitioning, Repartition}; +use std::sync::Arc; +use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::ExchangeRel; + +pub async fn from_exchange_rel( + consumer: &impl SubstraitConsumer, + exchange: &ExchangeRel, +) -> datafusion::common::Result<LogicalPlan> { + let Some(input) = exchange.input.as_ref() else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + let input = Arc::new(consumer.consume_rel(input).await?); + + let Some(exchange_kind) = &exchange.exchange_kind else { + return substrait_err!("Unexpected empty exchange_kind in ExchangeRel"); + }; + + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let partitioning_scheme = match exchange_kind { + ExchangeKind::ScatterByFields(scatter_fields) => { + let mut partition_columns = vec![]; + let input_schema = input.schema(); + for field_ref in &scatter_fields.fields { + let column = from_substrait_field_reference(field_ref, input_schema)?; + partition_columns.push(column); + } + Partitioning::Hash(partition_columns, exchange.partition_count as usize) + } + ExchangeKind::RoundRobin(_) => { + Partitioning::RoundRobinBatch(exchange.partition_count as usize) + } + ExchangeKind::SingleTarget(_) + | ExchangeKind::MultiTarget(_) + | ExchangeKind::Broadcast(_) => { + return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); + } + }; + Ok(LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + })) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs new file mode 100644 index 000000000000..74161d8600ea --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
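+
+// Illustrative example of the conversion implemented below: a FetchRel with
+// offset_mode = Offset(10) and count_mode = Count(5) becomes `LIMIT 5 OFFSET 10`,
+// while a count of -1 is treated as "fetch all remaining rows".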
+ +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::{not_impl_err, DFSchema, DFSchemaRef}; +use datafusion::logical_expr::{lit, LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::{fetch_rel, FetchRel}; + +#[async_recursion] +pub async fn from_fetch_rel( + consumer: &impl SubstraitConsumer, + fetch: &FetchRel, +) -> datafusion::common::Result { + if let Some(input) = fetch.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let offset = match &fetch.offset_mode { + Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), + Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => { + Some(consumer.consume_expression(expr, &empty_schema).await?) + } + None => None, + }; + let count = match &fetch.count_mode { + Some(fetch_rel::CountMode::Count(count)) => { + // -1 means that ALL records should be returned, equivalent to None + (*count != -1).then(|| lit(*count)) + } + Some(fetch_rel::CountMode::CountExpr(expr)) => { + Some(consumer.consume_expression(expr, &empty_schema).await?) + } + None => None, + }; + input.limit_by_expr(offset, count)?.build() + } else { + not_impl_err!("Fetch without an input is not valid") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/filter_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/filter_rel.rs new file mode 100644 index 000000000000..645b98278208 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/filter_rel.rs @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::FilterRel; + +#[async_recursion] +pub async fn from_filter_rel( + consumer: &impl SubstraitConsumer, + filter: &FilterRel, +) -> datafusion::common::Result<LogicalPlan> { + if let Some(input) = filter.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + if let Some(condition) = filter.condition.as_ref() { + let expr = consumer + .consume_expression(condition, input.schema()) + .await?; + input.filter(expr)?.build() + } else { + not_impl_err!("Filter without a condition is not valid") + } + } else { + not_impl_err!("Filter without an input is not valid") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs new file mode 100644 index 000000000000..881157dcfa66 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::utils::requalify_sides_if_needed; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, plan_err, Column, JoinType}; +use datafusion::logical_expr::utils::split_conjunction; +use datafusion::logical_expr::{ + BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, +}; +use substrait::proto::{join_rel, JoinRel}; + +pub async fn from_join_rel( + consumer: &impl SubstraitConsumer, + join: &JoinRel, +) -> datafusion::common::Result<LogicalPlan> { + if join.post_join_filter.is_some() { + return not_impl_err!("JoinRel with post_join_filter is not yet supported"); + } + + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + consumer.consume_rel(join.left.as_ref().unwrap()).await?, + ); + let right = LogicalPlanBuilder::from( + consumer.consume_rel(join.right.as_ref().unwrap()).await?, + ); + let (left, right) = requalify_sides_if_needed(left, right)?; + + let join_type = from_substrait_jointype(join.r#type)?; + // The join condition expression needs full input schema and not the output schema from join since we lose columns from + // certain join types such as semi and anti joins + let in_join_schema = left.schema().join(right.schema())?; + + // If join expression exists, parse the `on` condition expression, build join and return + // Otherwise, build join with only the filter, without join keys + match &join.expression.as_ref() { + Some(expr) => { + let on = consumer.consume_expression(expr, &in_join_schema).await?; + // The join expression can contain both equal and non-equal ops.
+ // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. + // So we extract each part as follows: + // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector + // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) + let (join_ons, nulls_equal_nulls, join_filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(&on); + let (left_cols, right_cols): (Vec<_>, Vec<_>) = + itertools::multiunzip(join_ons); + left.join_detailed( + right.build()?, + join_type, + (left_cols, right_cols), + join_filter, + nulls_equal_nulls, + )? + .build() + } + None => { + let on: Vec = vec![]; + left.join_detailed(right.build()?, join_type, (on.clone(), on), None, false)? + .build() + } + } +} + +fn split_eq_and_noneq_join_predicate_with_nulls_equality( + filter: &Expr, +) -> (Vec<(Column, Column)>, bool, Option) { + let exprs = split_conjunction(filter); + + let mut accum_join_keys: Vec<(Column, Column)> = vec![]; + let mut accum_filters: Vec = vec![]; + let mut nulls_equal_nulls = false; + + for expr in exprs { + #[allow(clippy::collapsible_match)] + match expr { + Expr::BinaryExpr(binary_expr) => match binary_expr { + x @ (BinaryExpr { + left, + op: Operator::Eq, + right, + } + | BinaryExpr { + left, + op: Operator::IsNotDistinctFrom, + right, + }) => { + nulls_equal_nulls = match x.op { + Operator::Eq => false, + Operator::IsNotDistinctFrom => true, + _ => unreachable!(), + }; + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(l), Expr::Column(r)) => { + accum_join_keys.push((l.clone(), r.clone())); + } + _ => accum_filters.push(expr.clone()), + } + } + _ => accum_filters.push(expr.clone()), + }, + _ => accum_filters.push(expr.clone()), + } + } + + let join_filter = accum_filters.into_iter().reduce(Expr::and); + (accum_join_keys, nulls_equal_nulls, join_filter) +} + +fn from_substrait_jointype(join_type: i32) -> datafusion::common::Result { + if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) { + match substrait_join_type { + join_rel::JoinType::Inner => Ok(JoinType::Inner), + join_rel::JoinType::Left => Ok(JoinType::Left), + join_rel::JoinType::Right => Ok(JoinType::Right), + join_rel::JoinType::Outer => Ok(JoinType::Full), + join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti), + join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi), + join_rel::JoinType::LeftMark => Ok(JoinType::LeftMark), + _ => plan_err!("unsupported join type {substrait_join_type:?}"), + } + } else { + plan_err!("invalid join type variant {join_type:?}") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/mod.rs b/datafusion/substrait/src/logical_plan/consumer/rel/mod.rs new file mode 100644 index 000000000000..a83ddd8997b2 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/mod.rs @@ -0,0 +1,173 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_rel; +mod cross_rel; +mod exchange_rel; +mod fetch_rel; +mod filter_rel; +mod join_rel; +mod project_rel; +mod read_rel; +mod set_rel; +mod sort_rel; + +pub use aggregate_rel::*; +pub use cross_rel::*; +pub use exchange_rel::*; +pub use fetch_rel::*; +pub use filter_rel::*; +pub use join_rel::*; +pub use project_rel::*; +pub use read_rel::*; +pub use set_rel::*; +pub use sort_rel::*; + +use crate::logical_plan::consumer::utils::NameTracker; +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::{not_impl_err, substrait_datafusion_err, substrait_err, Column}; +use datafusion::logical_expr::builder::project; +use datafusion::logical_expr::{Expr, LogicalPlan, Projection}; +use std::sync::Arc; +use substrait::proto::rel::RelType; +use substrait::proto::rel_common::{Emit, EmitKind}; +use substrait::proto::{rel_common, Rel, RelCommon}; + +/// Convert Substrait Rel to DataFusion DataFrame +#[async_recursion] +pub async fn from_substrait_rel( + consumer: &impl SubstraitConsumer, + relation: &Rel, +) -> datafusion::common::Result { + let plan: datafusion::common::Result = match &relation.rel_type { + Some(rel_type) => match rel_type { + RelType::Read(rel) => consumer.consume_read(rel).await, + RelType::Filter(rel) => consumer.consume_filter(rel).await, + RelType::Fetch(rel) => consumer.consume_fetch(rel).await, + RelType::Aggregate(rel) => consumer.consume_aggregate(rel).await, + RelType::Sort(rel) => consumer.consume_sort(rel).await, + RelType::Join(rel) => consumer.consume_join(rel).await, + RelType::Project(rel) => consumer.consume_project(rel).await, + RelType::Set(rel) => consumer.consume_set(rel).await, + RelType::ExtensionSingle(rel) => consumer.consume_extension_single(rel).await, + RelType::ExtensionMulti(rel) => consumer.consume_extension_multi(rel).await, + RelType::ExtensionLeaf(rel) => consumer.consume_extension_leaf(rel).await, + RelType::Cross(rel) => consumer.consume_cross(rel).await, + RelType::Window(rel) => { + consumer.consume_consistent_partition_window(rel).await + } + RelType::Exchange(rel) => consumer.consume_exchange(rel).await, + rt => not_impl_err!("{rt:?} rel not supported yet"), + }, + None => return substrait_err!("rel must set rel_type"), + }; + apply_emit_kind(retrieve_rel_common(relation), plan?) +} + +fn apply_emit_kind( + rel_common: Option<&RelCommon>, + plan: LogicalPlan, +) -> datafusion::common::Result { + match retrieve_emit_kind(rel_common) { + EmitKind::Direct(_) => Ok(plan), + EmitKind::Emit(Emit { output_mapping }) => { + // It is valid to reference the same field multiple times in the Emit + // In this case, we need to provide unique names to avoid collisions + let mut name_tracker = NameTracker::new(); + match plan { + // To avoid adding a projection on top of a projection, we apply special case + // handling to flatten Substrait Emits. This is only applicable if none of the + // expressions in the projection are volatile. 
This is to avoid issues like + // converting a single call of the random() function into multiple calls due to + // duplicate fields in the output_mapping. + LogicalPlan::Projection(proj) if !contains_volatile_expr(&proj) => { + let mut exprs: Vec = vec![]; + for field in output_mapping { + let expr = proj.expr + .get(field as usize) + .ok_or_else(|| substrait_datafusion_err!( + "Emit output field {} cannot be resolved in input schema {}", + field, proj.input.schema() + ))?; + exprs.push(name_tracker.get_uniquely_named_expr(expr.clone())?); + } + + let input = Arc::unwrap_or_clone(proj.input); + project(input, exprs) + } + // Otherwise we just handle the output_mapping as a projection + _ => { + let input_schema = plan.schema(); + + let mut exprs: Vec = vec![]; + for index in output_mapping.into_iter() { + let column = Expr::Column(Column::from( + input_schema.qualified_field(index as usize), + )); + let expr = name_tracker.get_uniquely_named_expr(column)?; + exprs.push(expr); + } + + project(plan, exprs) + } + } + } + } +} + +fn retrieve_rel_common(rel: &Rel) -> Option<&RelCommon> { + match rel.rel_type.as_ref() { + None => None, + Some(rt) => match rt { + RelType::Read(r) => r.common.as_ref(), + RelType::Filter(f) => f.common.as_ref(), + RelType::Fetch(f) => f.common.as_ref(), + RelType::Aggregate(a) => a.common.as_ref(), + RelType::Sort(s) => s.common.as_ref(), + RelType::Join(j) => j.common.as_ref(), + RelType::Project(p) => p.common.as_ref(), + RelType::Set(s) => s.common.as_ref(), + RelType::ExtensionSingle(e) => e.common.as_ref(), + RelType::ExtensionMulti(e) => e.common.as_ref(), + RelType::ExtensionLeaf(e) => e.common.as_ref(), + RelType::Cross(c) => c.common.as_ref(), + RelType::Reference(_) => None, + RelType::Write(w) => w.common.as_ref(), + RelType::Ddl(d) => d.common.as_ref(), + RelType::HashJoin(j) => j.common.as_ref(), + RelType::MergeJoin(j) => j.common.as_ref(), + RelType::NestedLoopJoin(j) => j.common.as_ref(), + RelType::Window(w) => w.common.as_ref(), + RelType::Exchange(e) => e.common.as_ref(), + RelType::Expand(e) => e.common.as_ref(), + RelType::Update(_) => None, + }, + } +} + +fn retrieve_emit_kind(rel_common: Option<&RelCommon>) -> EmitKind { + // the default EmitKind is Direct if it is not set explicitly + let default = EmitKind::Direct(rel_common::Direct {}); + rel_common + .and_then(|rc| rc.emit_kind.as_ref()) + .map_or(default, |ek| ek.clone()) +} + +fn contains_volatile_expr(proj: &Projection) -> bool { + proj.expr.iter().any(|e| e.is_volatile()) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs new file mode 100644 index 000000000000..8ece6392974e --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::utils::NameTracker; +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::{not_impl_err, Column}; +use datafusion::logical_expr::builder::project; +use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +use std::collections::HashSet; +use std::sync::Arc; +use substrait::proto::ProjectRel; + +#[async_recursion] +pub async fn from_project_rel( + consumer: &impl SubstraitConsumer, + p: &ProjectRel, +) -> datafusion::common::Result { + if let Some(input) = p.input.as_ref() { + let input = consumer.consume_rel(input).await?; + let original_schema = Arc::clone(input.schema()); + + // Ensure that all expressions have a unique display name, so that + // validate_unique_names does not fail when constructing the project. + let mut name_tracker = NameTracker::new(); + + // By default, a Substrait Project emits all inputs fields followed by all expressions. + // We build the explicit expressions first, and then the input expressions to avoid + // adding aliases to the explicit expressions (as part of ensuring unique names). + // + // This is helpful for plan visualization and tests, because when DataFusion produces + // Substrait Projects it adds an output mapping that excludes all input columns + // leaving only explicit expressions. + + let mut explicit_exprs: Vec = vec![]; + // For WindowFunctions, we need to wrap them in a Window relation. If there are duplicates, + // we can do the window'ing only once, then the project will duplicate the result. + // Order here doesn't matter since LPB::window_plan sorts the expressions. + let mut window_exprs: HashSet = HashSet::new(); + for expr in &p.expressions { + let e = consumer + .consume_expression(expr, input.clone().schema()) + .await?; + // if the expression is WindowFunction, wrap in a Window relation + if let Expr::WindowFunction(_) = &e { + // Adding the same expression here and in the project below + // works because the project's builder uses columnize_expr(..) + // to transform it into a column reference + window_exprs.insert(e.clone()); + } + explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); + } + + let input = if !window_exprs.is_empty() { + LogicalPlanBuilder::window_plan(input, window_exprs)? + } else { + input + }; + + let mut final_exprs: Vec = vec![]; + for index in 0..original_schema.fields().len() { + let e = Expr::Column(Column::from(original_schema.qualified_field(index))); + final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); + } + final_exprs.append(&mut explicit_exprs); + project(input, final_exprs) + } else { + not_impl_err!("Projection without an input is not supported") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs new file mode 100644 index 000000000000..f1cbd16d2d8f --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs @@ -0,0 +1,280 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::from_substrait_literal; +use crate::logical_plan::consumer::from_substrait_named_struct; +use crate::logical_plan::consumer::utils::ensure_schema_compatibility; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{ + not_impl_err, plan_err, substrait_datafusion_err, substrait_err, DFSchema, + DFSchemaRef, TableReference, +}; +use datafusion::datasource::provider_as_source; +use datafusion::logical_expr::utils::split_conjunction_owned; +use datafusion::logical_expr::{ + EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, Values, +}; +use std::sync::Arc; +use substrait::proto::expression::MaskExpression; +use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; +use substrait::proto::read_rel::ReadType; +use substrait::proto::{Expression, ReadRel}; +use url::Url; + +#[allow(deprecated)] +pub async fn from_read_rel( + consumer: &impl SubstraitConsumer, + read: &ReadRel, +) -> datafusion::common::Result { + async fn read_with_schema( + consumer: &impl SubstraitConsumer, + table_ref: TableReference, + schema: DFSchema, + projection: &Option, + filter: &Option>, + ) -> datafusion::common::Result { + let schema = schema.replace_qualifier(table_ref.clone()); + + let filters = if let Some(f) = filter { + let filter_expr = consumer.consume_expression(f, &schema).await?; + split_conjunction_owned(filter_expr) + } else { + vec![] + }; + + let plan = { + let provider = match consumer.resolve_table_ref(&table_ref).await? { + Some(ref provider) => Arc::clone(provider), + _ => return plan_err!("No table named '{table_ref}'"), + }; + + LogicalPlanBuilder::scan_with_filters( + table_ref, + provider_as_source(Arc::clone(&provider)), + None, + filters, + )? + .build()? 
+ }; + + ensure_schema_compatibility(plan.schema(), schema.clone())?; + + let schema = apply_masking(schema, projection)?; + + apply_projection(plan, schema) + } + + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for Read Relation") + })?; + + let substrait_schema = from_substrait_named_struct(consumer, named_struct)?; + + match &read.read_type { + Some(ReadType::NamedTable(nt)) => { + let table_reference = match nt.names.len() { + 0 => { + return plan_err!("No table name found in NamedTable"); + } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; + + read_with_schema( + consumer, + table_reference, + substrait_schema, + &read.projection, + &read.filter, + ) + .await + } + Some(ReadType::VirtualTable(vt)) => { + if vt.values.is_empty() { + return Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(substrait_schema), + })); + } + + let values = vt + .values + .iter() + .map(|row| { + let mut name_idx = 0; + let lits = row + .fields + .iter() + .map(|lit| { + name_idx += 1; // top-level names are provided through schema + Ok(Expr::Literal(from_substrait_literal( + consumer, + lit, + &named_struct.names, + &mut name_idx, + )?, None)) + }) + .collect::>()?; + if name_idx != named_struct.names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + named_struct.names.len() + ); + } + Ok(lits) + }) + .collect::>()?; + + Ok(LogicalPlan::Values(Values { + schema: DFSchemaRef::new(substrait_schema), + values, + })) + } + Some(ReadType::LocalFiles(lf)) => { + fn extract_filename(name: &str) -> Option { + let corrected_url = + if name.starts_with("file://") && !name.starts_with("file:///") { + name.replacen("file://", "file:///", 1) + } else { + name.to_string() + }; + + Url::parse(&corrected_url).ok().and_then(|url| { + let path = url.path(); + std::path::Path::new(path) + .file_name() + .map(|filename| filename.to_string_lossy().to_string()) + }) + } + + // we could use the file name to check the original table provider + // TODO: currently does not support multiple local files + let filename: Option = + lf.items.first().and_then(|x| match x.path_type.as_ref() { + Some(UriFile(name)) => extract_filename(name), + _ => None, + }); + + if lf.items.len() > 1 || filename.is_none() { + return not_impl_err!("Only single file reads are supported"); + } + let name = filename.unwrap(); + // directly use unwrap here since we could determine it is a valid one + let table_reference = TableReference::Bare { table: name.into() }; + + read_with_schema( + consumer, + table_reference, + substrait_schema, + &read.projection, + &read.filter, + ) + .await + } + _ => { + not_impl_err!("Unsupported ReadType: {:?}", read.read_type) + } + } +} + +pub fn apply_masking( + schema: DFSchema, + mask_expression: &::core::option::Option, +) -> datafusion::common::Result { + match mask_expression { + Some(MaskExpression { select, .. 
}) => match &select.as_ref() { + Some(projection) => { + let column_indices: Vec = projection + .struct_items + .iter() + .map(|item| item.field as usize) + .collect(); + + let fields = column_indices + .iter() + .map(|i| schema.qualified_field(*i)) + .map(|(qualifier, field)| { + (qualifier.cloned(), Arc::new(field.clone())) + }) + .collect(); + + Ok(DFSchema::new_with_metadata( + fields, + schema.metadata().clone(), + )?) + } + None => Ok(schema), + }, + None => Ok(schema), + } +} + +/// This function returns a DataFrame with fields adjusted if necessary in the event that the +/// Substrait schema is a subset of the DataFusion schema. +fn apply_projection( + plan: LogicalPlan, + substrait_schema: DFSchema, +) -> datafusion::common::Result { + let df_schema = plan.schema(); + + if df_schema.logically_equivalent_names_and_types(&substrait_schema) { + return Ok(plan); + } + + let df_schema = df_schema.to_owned(); + + match plan { + LogicalPlan::TableScan(mut scan) => { + let column_indices: Vec = substrait_schema + .strip_qualifiers() + .fields() + .iter() + .map(|substrait_field| { + Ok(df_schema + .index_of_column_by_name(None, substrait_field.name().as_str()) + .unwrap()) + }) + .collect::>()?; + + let fields = column_indices + .iter() + .map(|i| df_schema.qualified_field(*i)) + .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) + .collect(); + + scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata( + fields, + df_schema.metadata().clone(), + )?); + scan.projection = Some(column_indices); + + Ok(LogicalPlan::TableScan(scan)) + } + _ => plan_err!("DataFrame passed to apply_projection must be a TableScan"), + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/set_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/set_rel.rs new file mode 100644 index 000000000000..6688a80f5274 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/set_rel.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
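+
+// Illustrative examples of the set-operation folding implemented below:
+//   SetOp::UnionAll with inputs [A, B, C]      =>  A UNION ALL B UNION ALL C
+//   SetOp::MinusPrimary with inputs [A, B, C]  =>  (A EXCEPT B) EXCEPT C
+// Inputs are consumed left to right, which is why at least two inputs are required.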
+ +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, substrait_err}; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::set_rel::SetOp; +use substrait::proto::{Rel, SetRel}; + +pub async fn from_set_rel( + consumer: &impl SubstraitConsumer, + set: &SetRel, +) -> datafusion::common::Result { + if set.inputs.len() < 2 { + substrait_err!("Set operation requires at least two inputs") + } else { + match set.op() { + SetOp::UnionAll => union_rels(consumer, &set.inputs, true).await, + SetOp::UnionDistinct => union_rels(consumer, &set.inputs, false).await, + SetOp::IntersectionPrimary => LogicalPlanBuilder::intersect( + consumer.consume_rel(&set.inputs[0]).await?, + union_rels(consumer, &set.inputs[1..], true).await?, + false, + ), + SetOp::IntersectionMultiset => { + intersect_rels(consumer, &set.inputs, false).await + } + SetOp::IntersectionMultisetAll => { + intersect_rels(consumer, &set.inputs, true).await + } + SetOp::MinusPrimary => except_rels(consumer, &set.inputs, false).await, + SetOp::MinusPrimaryAll => except_rels(consumer, &set.inputs, true).await, + set_op => not_impl_err!("Unsupported set operator: {set_op:?}"), + } + } +} + +async fn union_rels( + consumer: &impl SubstraitConsumer, + rels: &[Rel], + is_all: bool, +) -> datafusion::common::Result { + let mut union_builder = Ok(LogicalPlanBuilder::from( + consumer.consume_rel(&rels[0]).await?, + )); + for input in &rels[1..] { + let rel_plan = consumer.consume_rel(input).await?; + + union_builder = if is_all { + union_builder?.union(rel_plan) + } else { + union_builder?.union_distinct(rel_plan) + }; + } + union_builder?.build() +} + +async fn intersect_rels( + consumer: &impl SubstraitConsumer, + rels: &[Rel], + is_all: bool, +) -> datafusion::common::Result { + let mut rel = consumer.consume_rel(&rels[0]).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::intersect( + rel, + consumer.consume_rel(input).await?, + is_all, + )? + } + + Ok(rel) +} + +async fn except_rels( + consumer: &impl SubstraitConsumer, + rels: &[Rel], + is_all: bool, +) -> datafusion::common::Result { + let mut rel = consumer.consume_rel(&rels[0]).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::except(rel, consumer.consume_rel(input).await?, is_all)? + } + + Ok(rel) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/sort_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/sort_rel.rs new file mode 100644 index 000000000000..56ca0ba03857 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/sort_rel.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::logical_plan::consumer::{from_substrait_sorts, SubstraitConsumer}; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::SortRel; + +pub async fn from_sort_rel( + consumer: &impl SubstraitConsumer, + sort: &SortRel, +) -> datafusion::common::Result { + if let Some(input) = sort.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + let sorts = from_substrait_sorts(consumer, &sort.sorts, input.schema()).await?; + input.sort(sorts)?.build() + } else { + not_impl_err!("Sort without an input is not valid") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs new file mode 100644 index 000000000000..5392dd77b576 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs @@ -0,0 +1,523 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{ + from_aggregate_rel, from_cast, from_cross_rel, from_exchange_rel, from_fetch_rel, + from_field_reference, from_filter_rel, from_if_then, from_join_rel, from_literal, + from_project_rel, from_read_rel, from_scalar_function, from_set_rel, + from_singular_or_list, from_sort_rel, from_subquery, from_substrait_rel, + from_substrait_rex, from_window_function, +}; +use crate::extensions::Extensions; +use async_trait::async_trait; +use datafusion::arrow::datatypes::DataType; +use datafusion::catalog::TableProvider; +use datafusion::common::{ + not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference, +}; +use datafusion::execution::{FunctionRegistry, SessionState}; +use datafusion::logical_expr::{Expr, Extension, LogicalPlan}; +use std::sync::Arc; +use substrait::proto; +use substrait::proto::expression as substrait_expression; +use substrait::proto::expression::{ + Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction, + SingularOrList, SwitchExpression, WindowFunction, +}; +use substrait::proto::{ + r#type, AggregateRel, ConsistentPartitionWindowRel, CrossRel, DynamicParameter, + ExchangeRel, Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, + FetchRel, FilterRel, JoinRel, ProjectRel, ReadRel, Rel, SetRel, SortRel, +}; + +#[async_trait] +/// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully +/// customizable Substrait serde. 
+/// +/// # Example Usage +/// +/// ``` +/// # use async_trait::async_trait; +/// # use datafusion::catalog::TableProvider; +/// # use datafusion::common::{not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference}; +/// # use datafusion::error::Result; +/// # use datafusion::execution::{FunctionRegistry, SessionState}; +/// # use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +/// # use std::sync::Arc; +/// # use substrait::proto; +/// # use substrait::proto::{ExtensionLeafRel, FilterRel, ProjectRel}; +/// # use datafusion::arrow::datatypes::DataType; +/// # use datafusion::logical_expr::expr::ScalarFunction; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::consumer::{ +/// # from_project_rel, from_substrait_rel, from_substrait_rex, SubstraitConsumer +/// # }; +/// +/// struct CustomSubstraitConsumer { +/// extensions: Arc, +/// state: Arc, +/// } +/// +/// #[async_trait] +/// impl SubstraitConsumer for CustomSubstraitConsumer { +/// async fn resolve_table_ref( +/// &self, +/// table_ref: &TableReference, +/// ) -> Result>> { +/// let table = table_ref.table().to_string(); +/// let schema = self.state.schema_for_ref(table_ref.clone())?; +/// let table_provider = schema.table(&table).await?; +/// Ok(table_provider) +/// } +/// +/// fn get_extensions(&self) -> &Extensions { +/// self.extensions.as_ref() +/// } +/// +/// fn get_function_registry(&self) -> &impl FunctionRegistry { +/// self.state.as_ref() +/// } +/// +/// // You can reuse existing consumer code to assist in handling advanced extensions +/// async fn consume_project(&self, rel: &ProjectRel) -> Result { +/// let df_plan = from_project_rel(self, rel).await?; +/// if let Some(advanced_extension) = rel.advanced_extension.as_ref() { +/// not_impl_err!( +/// "decode and handle an advanced extension: {:?}", +/// advanced_extension +/// ) +/// } else { +/// Ok(df_plan) +/// } +/// } +/// +/// // You can implement a fully custom consumer method if you need special handling +/// async fn consume_filter(&self, rel: &FilterRel) -> Result { +/// let input = self.consume_rel(rel.input.as_ref().unwrap()).await?; +/// let expression = +/// self.consume_expression(rel.condition.as_ref().unwrap(), input.schema()) +/// .await?; +/// // though this one is quite boring +/// LogicalPlanBuilder::from(input).filter(expression)?.build() +/// } +/// +/// // You can add handlers for extension relations +/// async fn consume_extension_leaf( +/// &self, +/// rel: &ExtensionLeafRel, +/// ) -> Result { +/// not_impl_err!( +/// "handle protobuf Any {} as you need", +/// rel.detail.as_ref().unwrap().type_url +/// ) +/// } +/// +/// // and handlers for user-define types +/// fn consume_user_defined_type(&self, typ: &proto::r#type::UserDefined) -> Result { +/// let type_string = self.extensions.types.get(&typ.type_reference).unwrap(); +/// match type_string.as_str() { +/// "u!foo" => not_impl_err!("handle foo conversion"), +/// "u!bar" => not_impl_err!("handle bar conversion"), +/// _ => substrait_err!("unexpected type") +/// } +/// } +/// +/// // and user-defined literals +/// fn consume_user_defined_literal(&self, literal: &proto::expression::literal::UserDefined) -> Result { +/// let type_string = self.extensions.types.get(&literal.type_reference).unwrap(); +/// match type_string.as_str() { +/// "u!foo" => not_impl_err!("handle foo conversion"), +/// "u!bar" => not_impl_err!("handle bar conversion"), +/// _ => substrait_err!("unexpected type") +/// } +/// } +/// } +/// ``` 
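+///
+/// A consumer like the one above can then be plugged into plan conversion
+/// (hypothetical sketch: `extensions`, `state` and `plan` are assumed to have been
+/// built elsewhere, e.g. from a `SessionState` and a decoded `substrait::proto::Plan`):
+///
+/// ```ignore
+/// # use datafusion_substrait::logical_plan::consumer::from_substrait_plan_with_consumer;
+/// let consumer = CustomSubstraitConsumer { extensions, state };
+/// let logical_plan = from_substrait_plan_with_consumer(&consumer, &plan).await?;
+/// ```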
+/// +pub trait SubstraitConsumer: Send + Sync + Sized { + async fn resolve_table_ref( + &self, + table_ref: &TableReference, + ) -> datafusion::common::Result>>; + + // TODO: Remove these two methods + // Ideally, the abstract consumer should not place any constraints on implementations. + // The functionality for which the Extensions and FunctionRegistry is needed should be abstracted + // out into methods on the trait. As an example, resolve_table_reference is such a method. + // See: https://github.com/apache/datafusion/issues/13863 + fn get_extensions(&self) -> &Extensions; + fn get_function_registry(&self) -> &impl FunctionRegistry; + + // Relation Methods + // There is one method per Substrait relation to allow for easy overriding of consumer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + /// All [Rel]s to be converted pass through this method. + /// You can provide your own implementation if you wish to customize the conversion behaviour. + async fn consume_rel(&self, rel: &Rel) -> datafusion::common::Result { + from_substrait_rel(self, rel).await + } + + async fn consume_read( + &self, + rel: &ReadRel, + ) -> datafusion::common::Result { + from_read_rel(self, rel).await + } + + async fn consume_filter( + &self, + rel: &FilterRel, + ) -> datafusion::common::Result { + from_filter_rel(self, rel).await + } + + async fn consume_fetch( + &self, + rel: &FetchRel, + ) -> datafusion::common::Result { + from_fetch_rel(self, rel).await + } + + async fn consume_aggregate( + &self, + rel: &AggregateRel, + ) -> datafusion::common::Result { + from_aggregate_rel(self, rel).await + } + + async fn consume_sort( + &self, + rel: &SortRel, + ) -> datafusion::common::Result { + from_sort_rel(self, rel).await + } + + async fn consume_join( + &self, + rel: &JoinRel, + ) -> datafusion::common::Result { + from_join_rel(self, rel).await + } + + async fn consume_project( + &self, + rel: &ProjectRel, + ) -> datafusion::common::Result { + from_project_rel(self, rel).await + } + + async fn consume_set(&self, rel: &SetRel) -> datafusion::common::Result { + from_set_rel(self, rel).await + } + + async fn consume_cross( + &self, + rel: &CrossRel, + ) -> datafusion::common::Result { + from_cross_rel(self, rel).await + } + + async fn consume_consistent_partition_window( + &self, + _rel: &ConsistentPartitionWindowRel, + ) -> datafusion::common::Result { + not_impl_err!("Consistent Partition Window Rel not supported") + } + + async fn consume_exchange( + &self, + rel: &ExchangeRel, + ) -> datafusion::common::Result { + from_exchange_rel(self, rel).await + } + + // Expression Methods + // There is one method per Substrait expression to allow for easy overriding of consumer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + /// All [Expression]s to be converted pass through this method. + /// You can provide your own implementation if you wish to customize the conversion behaviour. 
+ async fn consume_expression( + &self, + expr: &Expression, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_substrait_rex(self, expr, input_schema).await + } + + async fn consume_literal(&self, expr: &Literal) -> datafusion::common::Result { + from_literal(self, expr).await + } + + async fn consume_field_reference( + &self, + expr: &FieldReference, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_field_reference(self, expr, input_schema).await + } + + async fn consume_scalar_function( + &self, + expr: &ScalarFunction, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_scalar_function(self, expr, input_schema).await + } + + async fn consume_window_function( + &self, + expr: &WindowFunction, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_window_function(self, expr, input_schema).await + } + + async fn consume_if_then( + &self, + expr: &IfThen, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_if_then(self, expr, input_schema).await + } + + async fn consume_switch( + &self, + _expr: &SwitchExpression, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Switch expression not supported") + } + + async fn consume_singular_or_list( + &self, + expr: &SingularOrList, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_singular_or_list(self, expr, input_schema).await + } + + async fn consume_multi_or_list( + &self, + _expr: &MultiOrList, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Multi Or List expression not supported") + } + + async fn consume_cast( + &self, + expr: &substrait_expression::Cast, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_cast(self, expr, input_schema).await + } + + async fn consume_subquery( + &self, + expr: &substrait_expression::Subquery, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_subquery(self, expr, input_schema).await + } + + async fn consume_nested( + &self, + _expr: &Nested, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Nested expression not supported") + } + + async fn consume_enum( + &self, + _expr: &Enum, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Enum expression not supported") + } + + async fn consume_dynamic_parameter( + &self, + _expr: &DynamicParameter, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Dynamic Parameter expression not supported") + } + + // User-Defined Functionality + + // The details of extension relations, and how to handle them, are fully up to users to specify. 
+ // The following methods allow users to customize the consumer behaviour + + async fn consume_extension_leaf( + &self, + rel: &ExtensionLeafRel, + ) -> datafusion::common::Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionLeafRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionLeafRel") + } + + async fn consume_extension_single( + &self, + rel: &ExtensionSingleRel, + ) -> datafusion::common::Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionSingleRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionSingleRel") + } + + async fn consume_extension_multi( + &self, + rel: &ExtensionMultiRel, + ) -> datafusion::common::Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionMultiRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionMultiRel") + } + + // Users can bring their own types to Substrait which require custom handling + + fn consume_user_defined_type( + &self, + user_defined_type: &r#type::UserDefined, + ) -> datafusion::common::Result { + substrait_err!( + "Missing handler for user-defined type: {}", + user_defined_type.type_reference + ) + } + + fn consume_user_defined_literal( + &self, + user_defined_literal: &proto::expression::literal::UserDefined, + ) -> datafusion::common::Result { + substrait_err!( + "Missing handler for user-defined literals {}", + user_defined_literal.type_reference + ) + } +} + +/// Default SubstraitConsumer for converting standard Substrait without user-defined extensions. +/// +/// Used as the consumer in [crate::logical_plan::consumer::from_substrait_plan] +pub struct DefaultSubstraitConsumer<'a> { + pub(super) extensions: &'a Extensions, + pub(super) state: &'a SessionState, +} + +impl<'a> DefaultSubstraitConsumer<'a> { + pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self { + DefaultSubstraitConsumer { extensions, state } + } +} + +#[async_trait] +impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { + async fn resolve_table_ref( + &self, + table_ref: &TableReference, + ) -> datafusion::common::Result>> { + let table = table_ref.table().to_string(); + let schema = self.state.schema_for_ref(table_ref.clone())?; + let table_provider = schema.table(&table).await?; + Ok(table_provider) + } + + fn get_extensions(&self) -> &Extensions { + self.extensions + } + + fn get_function_registry(&self) -> &impl FunctionRegistry { + self.state + } + + async fn consume_extension_leaf( + &self, + rel: &ExtensionLeafRel, + ) -> datafusion::common::Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } + + async fn consume_extension_single( + &self, + rel: &ExtensionSingleRel, + ) -> datafusion::common::Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + let Some(input_rel) = &rel.input else { + return substrait_err!( + "ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead" + ); + }; + let input_plan = 
self.consume_rel(input_rel).await?; + let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } + + async fn consume_extension_multi( + &self, + rel: &ExtensionMultiRel, + ) -> datafusion::common::Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionMultiRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + let mut inputs = Vec::with_capacity(rel.inputs.len()); + for input in &rel.inputs { + let input_plan = self.consume_rel(input).await?; + inputs.push(input_plan); + } + let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs new file mode 100644 index 000000000000..7bc30e433d86 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
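+
+// A few illustrative mappings handled below (not exhaustive):
+//   i32 (default variation)            =>  DataType::Int32
+//   i32 (unsigned integer variation)   =>  DataType::UInt32
+//   string (view container variation)  =>  DataType::Utf8View
+//   precision_timestamp(3)             =>  DataType::Timestamp(TimeUnit::Millisecond, None)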
+ +use super::utils::{next_struct_field_name, DEFAULT_TIMEZONE}; +use super::SubstraitConsumer; +#[allow(deprecated)] +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, + INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, + TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, + TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::datatypes::{ + DataType, Field, Fields, IntervalUnit, Schema, TimeUnit, +}; +use datafusion::common::{ + not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, +}; +use std::sync::Arc; +use substrait::proto::{r#type, NamedStruct, Type}; + +pub(crate) fn from_substrait_type_without_names( + consumer: &impl SubstraitConsumer, + dt: &Type, +) -> datafusion::common::Result { + from_substrait_type(consumer, dt, &[], &mut 0) +} + +pub fn from_substrait_type( + consumer: &impl SubstraitConsumer, + dt: &Type, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + match &dt.kind { + Some(s_kind) => match s_kind { + r#type::Kind::Bool(_) => Ok(DataType::Boolean), + r#type::Kind::I8(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int8), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt8), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::I16(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int16), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt16), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::I32(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt32), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::I64(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int64), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt64), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::Fp32(_) => Ok(DataType::Float32), + r#type::Kind::Fp64(_) => Ok(DataType::Float64), + r#type::Kind::Timestamp(ts) => { + // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead + #[allow(deprecated)] + match ts.type_variation_reference { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Second, None)) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + } + } + r#type::Kind::PrecisionTimestamp(pts) => { + let unit = match pts.precision { + 0 => Ok(TimeUnit::Second), + 3 => Ok(TimeUnit::Millisecond), + 6 => 
Ok(TimeUnit::Microsecond), + 9 => Ok(TimeUnit::Nanosecond), + p => not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ), + }?; + Ok(DataType::Timestamp(unit, None)) + } + r#type::Kind::PrecisionTimestampTz(pts) => { + let unit = match pts.precision { + 0 => Ok(TimeUnit::Second), + 3 => Ok(TimeUnit::Millisecond), + 6 => Ok(TimeUnit::Microsecond), + 9 => Ok(TimeUnit::Nanosecond), + p => not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestampTz" + ), + }?; + Ok(DataType::Timestamp(unit, Some(DEFAULT_TIMEZONE.into()))) + } + r#type::Kind::Date(date) => match date.type_variation_reference { + DATE_32_TYPE_VARIATION_REF => Ok(DataType::Date32), + DATE_64_TYPE_VARIATION_REF => Ok(DataType::Date64), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::Binary(binary) => match binary.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::BinaryView), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::FixedBinary(fixed) => { + Ok(DataType::FixedSizeBinary(fixed.length)) + } + r#type::Kind::String(string) => match string.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeUtf8), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8View), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::List(list) => { + let inner_type = list.r#type.as_ref().ok_or_else(|| { + substrait_datafusion_err!("List type must have inner type") + })?; + let field = Arc::new(Field::new_list_field( + from_substrait_type(consumer, inner_type, dfs_names, name_idx)?, + // We ignore Substrait's nullability here to match to_substrait_literal + // which always creates nullable lists + true, + )); + match list.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeList(field)), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + )?, + } + } + r#type::Kind::Map(map) => { + let key_type = map.key.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have key type") + })?; + let value_type = map.value.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have value type") + })?; + let key_field = Arc::new(Field::new( + "key", + from_substrait_type(consumer, key_type, dfs_names, name_idx)?, + false, + )); + let value_field = Arc::new(Field::new( + "value", + from_substrait_type(consumer, value_type, dfs_names, name_idx)?, + true, + )); + Ok(DataType::Map( + Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), + false, // whether keys are sorted + )) + } + r#type::Kind::Decimal(d) => match d.type_variation_reference { + DECIMAL_128_TYPE_VARIATION_REF => { + Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) + } + DECIMAL_256_TYPE_VARIATION_REF => { + Ok(DataType::Decimal256(d.precision as u8, d.scale as i8)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::IntervalYear(_) => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + 
r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)), + r#type::Kind::IntervalCompound(_) => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + r#type::Kind::UserDefined(u) => { + if let Ok(data_type) = consumer.consume_user_defined_type(u) { + return Ok(data_type); + } + + // TODO: remove the code below once the producer has been updated + if let Some(name) = consumer.get_extensions().types.get(&u.type_reference) + { + #[allow(deprecated)] + match name.as_ref() { + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + _ => not_impl_err!( + "Unsupported Substrait user defined type with ref {} and variation {}", + u.type_reference, + u.type_variation_reference + ), + } + } else { + #[allow(deprecated)] + match u.type_reference { + // Kept for backwards compatibility, producers should use IntervalYear instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + // Kept for backwards compatibility, producers should use IntervalDay instead + INTERVAL_DAY_TIME_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + _ => not_impl_err!( + "Unsupported Substrait user defined type with ref {} and variation {}", + u.type_reference, + u.type_variation_reference + ), + } + } + } + r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( + consumer, s, dfs_names, name_idx, + )?)), + r#type::Kind::Varchar(_) => Ok(DataType::Utf8), + r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), + _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"), + }, + _ => not_impl_err!("`None` Substrait kind is not supported"), + } +} + +/// Convert Substrait NamedStruct to DataFusion DFSchemaRef +pub fn from_substrait_named_struct( + consumer: &impl SubstraitConsumer, + base_schema: &NamedStruct, +) -> datafusion::common::Result { + let mut name_idx = 0; + let fields = from_substrait_struct_type( + consumer, + base_schema.r#struct.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Named struct must contain a struct") + })?, + &base_schema.names, + &mut name_idx, + ); + if name_idx != base_schema.names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + base_schema.names.len() + ); + } + DFSchema::try_from(Schema::new(fields?)) +} + +fn from_substrait_struct_type( + consumer: &impl SubstraitConsumer, + s: &r#type::Struct, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + let mut fields = vec![]; + for (i, f) in s.types.iter().enumerate() { + let field = Field::new( + next_struct_field_name(i, dfs_names, name_idx)?, + from_substrait_type(consumer, f, dfs_names, name_idx)?, + true, // We assume everything to be nullable since that's easier than ensuring it matches + ); + fields.push(field); + } + Ok(fields.into()) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/utils.rs b/datafusion/substrait/src/logical_plan/consumer/utils.rs new file mode 100644 index 000000000000..a267971ff8d3 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/utils.rs @@ -0,0 +1,563 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema}; +use datafusion::common::{ + not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, + TableReference, +}; +use datafusion::logical_expr::expr::Sort; +use datafusion::logical_expr::{Cast, Expr, ExprSchemable, LogicalPlanBuilder}; +use std::collections::HashSet; +use std::sync::Arc; +use substrait::proto::sort_field::SortDirection; +use substrait::proto::sort_field::SortKind::{ComparisonFunctionReference, Direction}; +use substrait::proto::SortField; + +// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which +// is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone +// results in correct points on the timeline, and we pick UTC as a reasonable default. +// However, DF uses the timezone also for some arithmetic and display purposes (see e.g. +// https://github.com/apache/arrow-rs/blob/ee5694078c86c8201549654246900a4232d531a9/arrow-cast/src/cast/mod.rs#L1749). +pub(super) const DEFAULT_TIMEZONE: &str = "UTC"; + +/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise +/// conflict with the columns from the other. +/// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For +/// Substrait the names don't matter since it only refers to columns by indices, however DataFusion +/// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names). +pub(super) fn requalify_sides_if_needed( + left: LogicalPlanBuilder, + right: LogicalPlanBuilder, +) -> datafusion::common::Result<(LogicalPlanBuilder, LogicalPlanBuilder)> { + let left_cols = left.schema().columns(); + let right_cols = right.schema().columns(); + if left_cols.iter().any(|l| { + right_cols.iter().any(|r| { + l == r || (l.name == r.name && (l.relation.is_none() || r.relation.is_none())) + }) + }) { + // These names have no connection to the original plan, but they'll make the columns + // (mostly) unique. + Ok(( + left.alias(TableReference::bare("left"))?, + right.alias(TableReference::bare("right"))?, + )) + } else { + Ok((left, right)) + } +} + +pub(super) fn next_struct_field_name( + column_idx: usize, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + if dfs_names.is_empty() { + // If names are not given, create dummy names + // c0, c1, ... align with e.g. 
SqlToRel::create_named_struct + Ok(format!("c{column_idx}")) + } else { + let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| { + substrait_datafusion_err!("Named schema must contain names for all fields") + })?; + *name_idx += 1; + Ok(name) + } +} + +pub(super) fn rename_field( + field: &Field, + dfs_names: &Vec, + unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}" + name_idx: &mut usize, // Index into dfs_names + rename_self: bool, // Some fields (e.g. list items) don't have names in Substrait and this will be false to keep old name +) -> datafusion::common::Result { + let name = if rename_self { + next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)? + } else { + field.name().to_string() + }; + match field.data_type() { + DataType::Struct(children) => { + let children = children + .iter() + .enumerate() + .map(|(child_idx, f)| { + rename_field( + f.as_ref(), + dfs_names, + child_idx, + name_idx, + /*rename_self=*/ true, + ) + }) + .collect::>()?; + Ok(field + .to_owned() + .with_name(name) + .with_data_type(DataType::Struct(children))) + } + DataType::List(inner) => { + let renamed_inner = rename_field( + inner.as_ref(), + dfs_names, + 0, + name_idx, + /*rename_self=*/ false, + )?; + Ok(field + .to_owned() + .with_data_type(DataType::List(FieldRef::new(renamed_inner))) + .with_name(name)) + } + DataType::LargeList(inner) => { + let renamed_inner = rename_field( + inner.as_ref(), + dfs_names, + 0, + name_idx, + /*rename_self= */ false, + )?; + Ok(field + .to_owned() + .with_data_type(DataType::LargeList(FieldRef::new(renamed_inner))) + .with_name(name)) + } + DataType::Map(inner, sorted) => match inner.data_type() { + DataType::Struct(key_and_value) if key_and_value.len() == 2 => { + let renamed_keys = rename_field( + key_and_value[0].as_ref(), + dfs_names, + 0, + name_idx, + /*rename_self=*/ false, + )?; + let renamed_values = rename_field( + key_and_value[1].as_ref(), + dfs_names, + 0, + name_idx, + /*rename_self=*/ false, + )?; + Ok(field + .to_owned() + .with_data_type(DataType::Map( + Arc::new(Field::new( + inner.name(), + DataType::Struct(Fields::from(vec![ + renamed_keys, + renamed_values, + ])), + inner.is_nullable(), + )), + *sorted, + )) + .with_name(name)) + } + _ => substrait_err!("Map fields must contain a Struct with exactly 2 fields"), + }, + _ => Ok(field.to_owned().with_name(name)), + } +} + +/// Produce a version of the given schema with names matching the given list of names. +/// Substrait doesn't deal with column (incl. nested struct field) names within the schema, +/// but it does give us the list of expected names at the end of the plan, so we use this +/// to rename the schema to match the expected names. +pub(super) fn make_renamed_schema( + schema: &DFSchemaRef, + dfs_names: &Vec, +) -> datafusion::common::Result { + let mut name_idx = 0; + + let (qualifiers, fields): (_, Vec) = schema + .iter() + .enumerate() + .map(|(field_idx, (q, f))| { + let renamed_f = rename_field( + f.as_ref(), + dfs_names, + field_idx, + &mut name_idx, + /*rename_self=*/ true, + )?; + Ok((q.cloned(), renamed_f)) + }) + .collect::>>()? 
+ .into_iter() + .unzip(); + + if name_idx != dfs_names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + dfs_names.len()); + } + + DFSchema::from_field_specific_qualified_schema( + qualifiers, + &Arc::new(Schema::new(fields)), + ) +} + +/// Ensure the expressions have the right name(s) according to the new schema. +/// This includes the top-level (column) name, which will be renamed through aliasing if needed, +/// as well as nested names (if the expression produces any struct types), which will be renamed +/// through casting if needed. +pub(super) fn rename_expressions( + exprs: impl IntoIterator, + input_schema: &DFSchema, + new_schema_fields: &[Arc], +) -> datafusion::common::Result> { + exprs + .into_iter() + .zip(new_schema_fields) + .map(|(old_expr, new_field)| { + // Check if type (i.e. nested struct field names) match, use Cast to rename if needed + let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() { + Expr::Cast(Cast::new( + Box::new(old_expr), + new_field.data_type().to_owned(), + )) + } else { + old_expr + }; + // Alias column if needed to fix the top-level name + match &new_expr { + // If expr is a column reference, alias_if_changed would cause an aliasing if the old expr has a qualifier + Expr::Column(c) if &c.name == new_field.name() => Ok(new_expr), + _ => new_expr.alias_if_changed(new_field.name().to_owned()), + } + }) + .collect() +} + +/// Ensures that the given Substrait schema is compatible with the schema as given by DataFusion +/// +/// This means: +/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The +/// DataFusion schema may have MORE fields, but not the other way around. +/// 2. All fields are compatible. See [`ensure_field_compatibility`] for details +pub(super) fn ensure_schema_compatibility( + table_schema: &DFSchema, + substrait_schema: DFSchema, +) -> datafusion::common::Result<()> { + substrait_schema + .strip_qualifiers() + .fields() + .iter() + .try_for_each(|substrait_field| { + let df_field = + table_schema.field_with_unqualified_name(substrait_field.name())?; + ensure_field_compatibility(df_field, substrait_field) + }) +} + +/// Ensures that the given Substrait field is compatible with the given DataFusion field +/// +/// A field is compatible between Substrait and DataFusion if: +/// 1. They have logically equivalent types. +/// 2. They have the same nullability OR the Substrait field is nullable and the DataFusion fields +/// is not nullable. +/// +/// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not +/// nullable. As such if DataFusion has that field as nullable the plan should be rejected. +fn ensure_field_compatibility( + datafusion_field: &Field, + substrait_field: &Field, +) -> datafusion::common::Result<()> { + if !DFSchema::datatype_is_logically_equal( + datafusion_field.data_type(), + substrait_field.data_type(), + ) { + return substrait_err!( + "Field '{}' in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).", + substrait_field.name(), + substrait_field.data_type(), + datafusion_field.data_type() + ); + } + + if !compatible_nullabilities( + datafusion_field.is_nullable(), + substrait_field.is_nullable(), + ) { + // TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now. 
+ return substrait_err!( + "Field '{}' is nullable in the DataFusion schema but not nullable in the Substrait schema.", + substrait_field.name() + ); + } + Ok(()) +} + +/// Returns true if the DataFusion and Substrait nullabilities are compatible, false otherwise +fn compatible_nullabilities( + datafusion_nullability: bool, + substrait_nullability: bool, +) -> bool { + // DataFusion and Substrait have the same nullability + (datafusion_nullability == substrait_nullability) + // DataFusion is not nullable and Substrait is nullable + || (!datafusion_nullability && substrait_nullability) +} + +pub(super) struct NameTracker { + seen_names: HashSet, +} + +pub(super) enum NameTrackerStatus { + NeverSeen, + SeenBefore, +} + +impl NameTracker { + pub(super) fn new() -> Self { + NameTracker { + seen_names: HashSet::default(), + } + } + pub(super) fn get_unique_name( + &mut self, + name: String, + ) -> (String, NameTrackerStatus) { + match self.seen_names.insert(name.clone()) { + true => (name, NameTrackerStatus::NeverSeen), + false => { + let mut counter = 0; + loop { + let candidate_name = format!("{name}__temp__{counter}"); + if self.seen_names.insert(candidate_name.clone()) { + return (candidate_name, NameTrackerStatus::SeenBefore); + } + counter += 1; + } + } + } + } + + pub(super) fn get_uniquely_named_expr( + &mut self, + expr: Expr, + ) -> datafusion::common::Result { + match self.get_unique_name(expr.name_for_alias()?) { + (_, NameTrackerStatus::NeverSeen) => Ok(expr), + (name, NameTrackerStatus::SeenBefore) => Ok(expr.alias(name)), + } + } +} + +/// Convert Substrait Sorts to DataFusion Exprs +pub async fn from_substrait_sorts( + consumer: &impl SubstraitConsumer, + substrait_sorts: &Vec, + input_schema: &DFSchema, +) -> datafusion::common::Result> { + let mut sorts: Vec = vec![]; + for s in substrait_sorts { + let expr = consumer + .consume_expression(s.expr.as_ref().unwrap(), input_schema) + .await?; + let asc_nullfirst = match &s.sort_kind { + Some(k) => match k { + Direction(d) => { + let Ok(direction) = SortDirection::try_from(*d) else { + return not_impl_err!( + "Unsupported Substrait SortDirection value {d}" + ); + }; + + match direction { + SortDirection::AscNullsFirst => Ok((true, true)), + SortDirection::AscNullsLast => Ok((true, false)), + SortDirection::DescNullsFirst => Ok((false, true)), + SortDirection::DescNullsLast => Ok((false, false)), + SortDirection::Clustered => not_impl_err!( + "Sort with direction clustered is not yet supported" + ), + SortDirection::Unspecified => { + not_impl_err!("Unspecified sort direction is invalid") + } + } + } + ComparisonFunctionReference(_) => not_impl_err!( + "Sort using comparison function reference is not supported" + ), + }, + None => not_impl_err!("Sort without sort kind is invalid"), + }; + let (asc, nulls_first) = asc_nullfirst.unwrap(); + sorts.push(Sort { + expr, + asc, + nulls_first, + }); + } + Ok(sorts) +} + +#[cfg(test)] +pub(crate) mod tests { + use super::make_renamed_schema; + use crate::extensions::Extensions; + use crate::logical_plan::consumer::DefaultSubstraitConsumer; + use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::common::DFSchema; + use datafusion::error::Result; + use datafusion::execution::SessionState; + use datafusion::prelude::SessionContext; + use datafusion::sql::TableReference; + use std::collections::HashMap; + use std::sync::{Arc, LazyLock}; + + pub(crate) static TEST_SESSION_STATE: LazyLock = + LazyLock::new(|| SessionContext::default().state()); + pub(crate) static 
TEST_EXTENSIONS: LazyLock = + LazyLock::new(Extensions::default); + pub(crate) fn test_consumer() -> DefaultSubstraitConsumer<'static> { + let extensions = &TEST_EXTENSIONS; + let state = &TEST_SESSION_STATE; + DefaultSubstraitConsumer::new(extensions, state) + } + + #[tokio::test] + async fn rename_schema() -> Result<()> { + let table_ref = TableReference::bare("test"); + let fields = vec![ + ( + Some(table_ref.clone()), + Arc::new(Field::new("0", DataType::Int32, false)), + ), + ( + Some(table_ref.clone()), + Arc::new(Field::new_struct( + "1", + vec![ + Field::new("2", DataType::Int32, false), + Field::new_struct( + "3", + vec![Field::new("4", DataType::Int32, false)], + false, + ), + ], + false, + )), + ), + ( + Some(table_ref.clone()), + Arc::new(Field::new_list( + "5", + Arc::new(Field::new_struct( + "item", + vec![Field::new("6", DataType::Int32, false)], + false, + )), + false, + )), + ), + ( + Some(table_ref.clone()), + Arc::new(Field::new_map( + "7", + "entries", + Arc::new(Field::new_struct( + "keys", + vec![Field::new("8", DataType::Int32, false)], + false, + )), + Arc::new(Field::new_struct( + "values", + vec![Field::new("9", DataType::Int32, false)], + false, + )), + false, + false, + )), + ), + ]; + + let schema = Arc::new(DFSchema::new_with_metadata(fields, HashMap::default())?); + let dfs_names = vec![ + "a".to_string(), + "b".to_string(), + "c".to_string(), + "d".to_string(), + "e".to_string(), + "f".to_string(), + "g".to_string(), + "h".to_string(), + "i".to_string(), + "j".to_string(), + ]; + let renamed_schema = make_renamed_schema(&schema, &dfs_names)?; + + assert_eq!(renamed_schema.fields().len(), 4); + assert_eq!( + *renamed_schema.field(0), + Field::new("a", DataType::Int32, false) + ); + assert_eq!( + *renamed_schema.field(1), + Field::new_struct( + "b", + vec![ + Field::new("c", DataType::Int32, false), + Field::new_struct( + "d", + vec![Field::new("e", DataType::Int32, false)], + false, + ) + ], + false, + ) + ); + assert_eq!( + *renamed_schema.field(2), + Field::new_list( + "f", + Arc::new(Field::new_struct( + "item", + vec![Field::new("g", DataType::Int32, false)], + false, + )), + false, + ) + ); + assert_eq!( + *renamed_schema.field(3), + Field::new_map( + "h", + "entries", + Arc::new(Field::new_struct( + "keys", + vec![Field::new("i", DataType::Int32, false)], + false, + )), + Arc::new(Field::new_struct( + "values", + vec![Field::new("j", DataType::Int32, false)], + false, + )), + false, + false, + ) + ); + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs deleted file mode 100644 index 07bf0cb96aa3..000000000000 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ /dev/null @@ -1,2915 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
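The `NameTracker` shown above keeps projection output names unique by suffixing repeated names. A standalone sketch of that convention, with an illustrative helper name rather than the actual struct:

```rust
use std::collections::HashSet;

// First occurrence keeps its name; repeats get a "__temp__{n}" suffix,
// mirroring NameTracker::get_unique_name.
fn unique_name(seen: &mut HashSet<String>, name: &str) -> String {
    if seen.insert(name.to_string()) {
        return name.to_string();
    }
    let mut n = 0;
    loop {
        let candidate = format!("{name}__temp__{n}");
        if seen.insert(candidate.clone()) {
            return candidate;
        }
        n += 1;
    }
}

fn main() {
    let mut seen = HashSet::new();
    assert_eq!(unique_name(&mut seen, "a"), "a");
    assert_eq!(unique_name(&mut seen, "a"), "a__temp__0");
    assert_eq!(unique_name(&mut seen, "a"), "a__temp__1");
}
```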
See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; -use substrait::proto::expression_reference::ExprType; - -use datafusion::arrow::datatypes::{Field, IntervalUnit}; -use datafusion::logical_expr::{ - Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, - Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, - TryCast, Union, Values, Window, WindowFrameUnits, -}; -use datafusion::{ - arrow::datatypes::{DataType, TimeUnit}, - error::{DataFusionError, Result}, - logical_expr::{WindowFrame, WindowFrameBound}, - prelude::JoinType, - scalar::ScalarValue, -}; - -use crate::extensions::Extensions; -use crate::variation_const::{ - DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, - DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, - VIEW_CONTAINER_TYPE_VARIATION_REF, -}; -use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; -use datafusion::arrow::temporal_conversions::NANOSECONDS; -use datafusion::common::{ - exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, - substrait_err, Column, DFSchema, DFSchemaRef, ToDFSchema, -}; -use datafusion::execution::registry::SerializerRegistry; -use datafusion::execution::SessionState; -use datafusion::logical_expr::expr::{ - AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, WindowFunction, WindowFunctionParams, -}; -use datafusion::logical_expr::utils::conjunction; -use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; -use datafusion::prelude::Expr; -use pbjson_types::Any as ProtoAny; -use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; -use substrait::proto::expression::cast::FailureBehavior; -use substrait::proto::expression::field_reference::{RootReference, RootType}; -use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; -use substrait::proto::expression::literal::map::KeyValue; -use substrait::proto::expression::literal::{ - IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, Map, - PrecisionTimestamp, Struct, -}; -use substrait::proto::expression::subquery::InPredicate; -use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::expression::ScalarFunction; -use substrait::proto::read_rel::VirtualTable; -use substrait::proto::rel_common::EmitKind; -use substrait::proto::rel_common::EmitKind::Emit; -use substrait::proto::{ - fetch_rel, rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, - RelCommon, -}; -use substrait::{ - proto::{ - aggregate_function::AggregationInvocation, - aggregate_rel::{Grouping, Measure}, - expression::{ - field_reference::ReferenceType, - if_then::IfClause, - literal::{Decimal, LiteralType}, - mask_expression::{StructItem, StructSelect}, - reference_segment, - window_function::bound as SubstraitBound, - window_function::bound::Kind as BoundKind, - window_function::Bound, - FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - SingularOrList, WindowFunction as SubstraitWindowFunction, - }, - function_argument::ArgType, - join_rel, plan_rel, r#type, - read_rel::{NamedTable, ReadType}, - rel::RelType, - set_rel, - sort_field::{SortDirection, SortKind}, - AggregateFunction, AggregateRel, 
AggregationPhase, Expression, ExtensionLeafRel, - ExtensionMultiRel, ExtensionSingleRel, FetchRel, FilterRel, FunctionArgument, - JoinRel, NamedStruct, Plan, PlanRel, ProjectRel, ReadRel, Rel, RelRoot, SetRel, - SortField, SortRel, - }, - version, -}; - -/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. -/// It can be implemented by users to allow for custom handling of relations, expressions, etc. -/// -/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully -/// customizable Substrait serde. -/// -/// # Example Usage -/// -/// ``` -/// # use std::sync::Arc; -/// # use substrait::proto::{Expression, Rel}; -/// # use substrait::proto::rel::RelType; -/// # use datafusion::common::DFSchemaRef; -/// # use datafusion::error::Result; -/// # use datafusion::execution::SessionState; -/// # use datafusion::logical_expr::{Between, Extension, Projection}; -/// # use datafusion_substrait::extensions::Extensions; -/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; -/// -/// struct CustomSubstraitProducer { -/// extensions: Extensions, -/// state: Arc, -/// } -/// -/// impl SubstraitProducer for CustomSubstraitProducer { -/// -/// fn register_function(&mut self, signature: String) -> u32 { -/// self.extensions.register_function(signature) -/// } -/// -/// fn get_extensions(self) -> Extensions { -/// self.extensions -/// } -/// -/// // You can set additional metadata on the Rels you produce -/// fn handle_projection(&mut self, plan: &Projection) -> Result> { -/// let mut rel = from_projection(self, plan)?; -/// match rel.rel_type { -/// Some(RelType::Project(mut project)) => { -/// let mut project = project.clone(); -/// // set common metadata or advanced extension -/// project.common = None; -/// project.advanced_extension = None; -/// Ok(Box::new(Rel { -/// rel_type: Some(RelType::Project(project)), -/// })) -/// } -/// rel_type => Ok(Box::new(Rel { rel_type })), -/// } -/// } -/// -/// // You can tweak how you convert expressions for your target system -/// fn handle_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { -/// // add your own encoding for Between -/// todo!() -/// } -/// -/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait -/// fn handle_extension(&mut self, _plan: &Extension) -> Result> { -/// // implement your own serializer into Substrait -/// todo!() -/// } -/// } -/// ``` -pub trait SubstraitProducer: Send + Sync + Sized { - /// Within a Substrait plan, functions are referenced using function anchors that are stored at - /// the top level of the [Plan] within - /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) - /// messages. - /// - /// When given a function signature, this method should return the existing anchor for it if - /// there is one. Otherwise, it should generate a new anchor. - fn register_function(&mut self, signature: String) -> u32; - - /// Consume the producer to generate the [Extensions] for the Substrait plan based on the - /// functions that have been registered - fn get_extensions(self) -> Extensions; - - // Logical Plan Methods - // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. 
- - fn handle_plan(&mut self, plan: &LogicalPlan) -> Result> { - to_substrait_rel(self, plan) - } - - fn handle_projection(&mut self, plan: &Projection) -> Result> { - from_projection(self, plan) - } - - fn handle_filter(&mut self, plan: &Filter) -> Result> { - from_filter(self, plan) - } - - fn handle_window(&mut self, plan: &Window) -> Result> { - from_window(self, plan) - } - - fn handle_aggregate(&mut self, plan: &Aggregate) -> Result> { - from_aggregate(self, plan) - } - - fn handle_sort(&mut self, plan: &Sort) -> Result> { - from_sort(self, plan) - } - - fn handle_join(&mut self, plan: &Join) -> Result> { - from_join(self, plan) - } - - fn handle_repartition(&mut self, plan: &Repartition) -> Result> { - from_repartition(self, plan) - } - - fn handle_union(&mut self, plan: &Union) -> Result> { - from_union(self, plan) - } - - fn handle_table_scan(&mut self, plan: &TableScan) -> Result> { - from_table_scan(self, plan) - } - - fn handle_empty_relation(&mut self, plan: &EmptyRelation) -> Result> { - from_empty_relation(plan) - } - - fn handle_subquery_alias(&mut self, plan: &SubqueryAlias) -> Result> { - from_subquery_alias(self, plan) - } - - fn handle_limit(&mut self, plan: &Limit) -> Result> { - from_limit(self, plan) - } - - fn handle_values(&mut self, plan: &Values) -> Result> { - from_values(self, plan) - } - - fn handle_distinct(&mut self, plan: &Distinct) -> Result> { - from_distinct(self, plan) - } - - fn handle_extension(&mut self, _plan: &Extension) -> Result> { - substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") - } - - // Expression Methods - // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. 
- - fn handle_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result { - to_substrait_rex(self, expr, schema) - } - - fn handle_alias( - &mut self, - alias: &Alias, - schema: &DFSchemaRef, - ) -> Result { - from_alias(self, alias, schema) - } - - fn handle_column( - &mut self, - column: &Column, - schema: &DFSchemaRef, - ) -> Result { - from_column(column, schema) - } - - fn handle_literal(&mut self, value: &ScalarValue) -> Result { - from_literal(self, value) - } - - fn handle_binary_expr( - &mut self, - expr: &BinaryExpr, - schema: &DFSchemaRef, - ) -> Result { - from_binary_expr(self, expr, schema) - } - - fn handle_like(&mut self, like: &Like, schema: &DFSchemaRef) -> Result { - from_like(self, like, schema) - } - - /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative - fn handle_unary_expr( - &mut self, - expr: &Expr, - schema: &DFSchemaRef, - ) -> Result { - from_unary_expr(self, expr, schema) - } - - fn handle_between( - &mut self, - between: &Between, - schema: &DFSchemaRef, - ) -> Result { - from_between(self, between, schema) - } - - fn handle_case(&mut self, case: &Case, schema: &DFSchemaRef) -> Result { - from_case(self, case, schema) - } - - fn handle_cast(&mut self, cast: &Cast, schema: &DFSchemaRef) -> Result { - from_cast(self, cast, schema) - } - - fn handle_try_cast( - &mut self, - cast: &TryCast, - schema: &DFSchemaRef, - ) -> Result { - from_try_cast(self, cast, schema) - } - - fn handle_scalar_function( - &mut self, - scalar_fn: &expr::ScalarFunction, - schema: &DFSchemaRef, - ) -> Result { - from_scalar_function(self, scalar_fn, schema) - } - - fn handle_aggregate_function( - &mut self, - agg_fn: &expr::AggregateFunction, - schema: &DFSchemaRef, - ) -> Result { - from_aggregate_function(self, agg_fn, schema) - } - - fn handle_window_function( - &mut self, - window_fn: &WindowFunction, - schema: &DFSchemaRef, - ) -> Result { - from_window_function(self, window_fn, schema) - } - - fn handle_in_list( - &mut self, - in_list: &InList, - schema: &DFSchemaRef, - ) -> Result { - from_in_list(self, in_list, schema) - } - - fn handle_in_subquery( - &mut self, - in_subquery: &InSubquery, - schema: &DFSchemaRef, - ) -> Result { - from_in_subquery(self, in_subquery, schema) - } -} - -pub struct DefaultSubstraitProducer<'a> { - extensions: Extensions, - serializer_registry: &'a dyn SerializerRegistry, -} - -impl<'a> DefaultSubstraitProducer<'a> { - pub fn new(state: &'a SessionState) -> Self { - DefaultSubstraitProducer { - extensions: Extensions::default(), - serializer_registry: state.serializer_registry().as_ref(), - } - } -} - -impl SubstraitProducer for DefaultSubstraitProducer<'_> { - fn register_function(&mut self, fn_name: String) -> u32 { - self.extensions.register_function(fn_name) - } - - fn get_extensions(self) -> Extensions { - self.extensions - } - - fn handle_extension(&mut self, plan: &Extension) -> Result> { - let extension_bytes = self - .serializer_registry - .serialize_logical_plan(plan.node.as_ref())?; - let detail = ProtoAny { - type_url: plan.node.name().to_string(), - value: extension_bytes.into(), - }; - let mut inputs_rel = plan - .node - .inputs() - .into_iter() - .map(|plan| self.handle_plan(plan)) - .collect::>>()?; - let rel_type = match inputs_rel.len() { - 0 => RelType::ExtensionLeaf(ExtensionLeafRel { - common: None, - detail: Some(detail), - }), - 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel { - common: None, - detail: Some(detail), - input: 
Some(inputs_rel.pop().unwrap()), - })), - _ => RelType::ExtensionMulti(ExtensionMultiRel { - common: None, - detail: Some(detail), - inputs: inputs_rel.into_iter().map(|r| *r).collect(), - }), - }; - Ok(Box::new(Rel { - rel_type: Some(rel_type), - })) - } -} - -/// Convert DataFusion LogicalPlan to Substrait Plan -pub fn to_substrait_plan(plan: &LogicalPlan, state: &SessionState) -> Result> { - // Parse relation nodes - // Generate PlanRel(s) - // Note: Only 1 relation tree is currently supported - - let mut producer: DefaultSubstraitProducer = DefaultSubstraitProducer::new(state); - let plan_rels = vec![PlanRel { - rel_type: Some(plan_rel::RelType::Root(RelRoot { - input: Some(*producer.handle_plan(plan)?), - names: to_substrait_named_struct(plan.schema())?.names, - })), - }]; - - // Return parsed plan - let extensions = producer.get_extensions(); - Ok(Box::new(Plan { - version: Some(version::version_with_producer("datafusion")), - extension_uris: vec![], - extensions: extensions.into(), - relations: plan_rels, - advanced_extensions: None, - expected_type_urls: vec![], - parameter_bindings: vec![], - })) -} - -/// Serializes a collection of expressions to a Substrait ExtendedExpression message -/// -/// The ExtendedExpression message is a top-level message that can be used to send -/// expressions (not plans) between systems. -/// -/// Each expression is also given names for the output type. These are provided as a -/// field and not a String (since the names may be nested, e.g. a struct). The data -/// type and nullability of this field is redundant (those can be determined by the -/// Expr) and will be ignored. -/// -/// Substrait also requires the input schema of the expressions to be included in the -/// message. The field names of the input schema will be serialized. 
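As the documentation above describes, the extended-expression message bundles the expressions, their output names, and the input schema they refer to. A rough usage sketch: the column `a`, the output field name, and the import path are assumptions for illustration, and the function is assumed to keep this public signature after the module move.

```rust
use std::sync::Arc;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::DFSchema;
use datafusion::prelude::{col, lit, SessionContext};
use datafusion_substrait::logical_plan::producer::to_substrait_extended_expr;

fn main() -> datafusion::error::Result<()> {
    let state = SessionContext::new().state();
    // Input schema the expression refers to ("a" is a made-up column).
    let schema = Arc::new(DFSchema::try_from(Schema::new(vec![Field::new(
        "a",
        DataType::Int64,
        false,
    )]))?);
    let expr = col("a").gt(lit(10_i64));
    // The field only supplies the output name; its type and nullability are ignored.
    let out = Field::new("a_gt_10", DataType::Boolean, false);
    let msg = to_substrait_extended_expr(&[(&expr, &out)], &schema, &state)?;
    println!("{} referred expressions", msg.referred_expr.len());
    Ok(())
}
```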
-pub fn to_substrait_extended_expr( - exprs: &[(&Expr, &Field)], - schema: &DFSchemaRef, - state: &SessionState, -) -> Result> { - let mut producer = DefaultSubstraitProducer::new(state); - let substrait_exprs = exprs - .iter() - .map(|(expr, field)| { - let substrait_expr = producer.handle_expr(expr, schema)?; - let mut output_names = Vec::new(); - flatten_names(field, false, &mut output_names)?; - Ok(ExpressionReference { - output_names, - expr_type: Some(ExprType::Expression(substrait_expr)), - }) - }) - .collect::>>()?; - let substrait_schema = to_substrait_named_struct(schema)?; - - let extensions = producer.get_extensions(); - Ok(Box::new(ExtendedExpression { - advanced_extensions: None, - expected_type_urls: vec![], - extension_uris: vec![], - extensions: extensions.into(), - version: Some(version::version_with_producer("datafusion")), - referred_expr: substrait_exprs, - base_schema: Some(substrait_schema), - })) -} - -pub fn to_substrait_rel( - producer: &mut impl SubstraitProducer, - plan: &LogicalPlan, -) -> Result> { - match plan { - LogicalPlan::Projection(plan) => producer.handle_projection(plan), - LogicalPlan::Filter(plan) => producer.handle_filter(plan), - LogicalPlan::Window(plan) => producer.handle_window(plan), - LogicalPlan::Aggregate(plan) => producer.handle_aggregate(plan), - LogicalPlan::Sort(plan) => producer.handle_sort(plan), - LogicalPlan::Join(plan) => producer.handle_join(plan), - LogicalPlan::Repartition(plan) => producer.handle_repartition(plan), - LogicalPlan::Union(plan) => producer.handle_union(plan), - LogicalPlan::TableScan(plan) => producer.handle_table_scan(plan), - LogicalPlan::EmptyRelation(plan) => producer.handle_empty_relation(plan), - LogicalPlan::Subquery(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::SubqueryAlias(plan) => producer.handle_subquery_alias(plan), - LogicalPlan::Limit(plan) => producer.handle_limit(plan), - LogicalPlan::Statement(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Values(plan) => producer.handle_values(plan), - LogicalPlan::Explain(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Analyze(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Extension(plan) => producer.handle_extension(plan), - LogicalPlan::Distinct(plan) => producer.handle_distinct(plan), - LogicalPlan::Dml(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Ddl(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Copy(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::DescribeTable(plan) => { - not_impl_err!("Unsupported plan type: {plan:?}")? - } - LogicalPlan::Unnest(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::RecursiveQuery(plan) => { - not_impl_err!("Unsupported plan type: {plan:?}")? 
- } - } -} - -pub fn from_table_scan( - producer: &mut impl SubstraitProducer, - scan: &TableScan, -) -> Result> { - let projection = scan.projection.as_ref().map(|p| { - p.iter() - .map(|i| StructItem { - field: *i as i32, - child: None, - }) - .collect() - }); - - let projection = projection.map(|struct_items| MaskExpression { - select: Some(StructSelect { struct_items }), - maintain_singular_struct: false, - }); - - let table_schema = scan.source.schema().to_dfschema_ref()?; - let base_schema = to_substrait_named_struct(&table_schema)?; - - let filter_option = if scan.filters.is_empty() { - None - } else { - let table_schema_qualified = Arc::new( - DFSchema::try_from_qualified_schema( - scan.table_name.clone(), - &(scan.source.schema()), - ) - .unwrap(), - ); - - let combined_expr = conjunction(scan.filters.clone()).unwrap(); - let filter_expr = - producer.handle_expr(&combined_expr, &table_schema_qualified)?; - Some(Box::new(filter_expr)) - }; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(base_schema), - filter: filter_option, - best_effort_filter: None, - projection, - advanced_extension: None, - read_type: Some(ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - })), - }))), - })) -} - -pub fn from_empty_relation(e: &EmptyRelation) -> Result> { - if e.produce_one_row { - return not_impl_err!("Producing a row from empty relation is unsupported"); - } - #[allow(deprecated)] - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(to_substrait_named_struct(&e.schema)?), - filter: None, - best_effort_filter: None, - projection: None, - advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { - values: vec![], - expressions: vec![], - })), - }))), - })) -} - -pub fn from_values( - producer: &mut impl SubstraitProducer, - v: &Values, -) -> Result> { - let values = v - .values - .iter() - .map(|row| { - let fields = row - .iter() - .map(|v| match v { - Expr::Literal(sv) => to_substrait_literal(producer, sv), - Expr::Alias(alias) => match alias.expr.as_ref() { - // The schema gives us the names, so we can skip aliases - Expr::Literal(sv) => to_substrait_literal(producer, sv), - _ => Err(substrait_datafusion_err!( - "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() - )), - }, - _ => Err(substrait_datafusion_err!( - "Only literal types and aliases are supported in Virtual Tables, got: {}", v.variant_name() - )), - }) - .collect::>()?; - Ok(Struct { fields }) - }) - .collect::>()?; - #[allow(deprecated)] - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(to_substrait_named_struct(&v.schema)?), - filter: None, - best_effort_filter: None, - projection: None, - advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { - values, - expressions: vec![], - })), - }))), - })) -} - -pub fn from_projection( - producer: &mut impl SubstraitProducer, - p: &Projection, -) -> Result> { - let expressions = p - .expr - .iter() - .map(|e| producer.handle_expr(e, p.input.schema())) - .collect::>>()?; - - let emit_kind = create_project_remapping( - expressions.len(), - p.input.as_ref().schema().fields().len(), - ); - let common = RelCommon { - emit_kind: Some(emit_kind), - hint: None, - advanced_extension: None, - }; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Project(Box::new(ProjectRel { - common: Some(common), 
- input: Some(producer.handle_plan(p.input.as_ref())?), - expressions, - advanced_extension: None, - }))), - })) -} - -pub fn from_filter( - producer: &mut impl SubstraitProducer, - filter: &Filter, -) -> Result> { - let input = producer.handle_plan(filter.input.as_ref())?; - let filter_expr = producer.handle_expr(&filter.predicate, filter.input.schema())?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Filter(Box::new(FilterRel { - common: None, - input: Some(input), - condition: Some(Box::new(filter_expr)), - advanced_extension: None, - }))), - })) -} - -pub fn from_limit( - producer: &mut impl SubstraitProducer, - limit: &Limit, -) -> Result> { - let input = producer.handle_plan(limit.input.as_ref())?; - let empty_schema = Arc::new(DFSchema::empty()); - let offset_mode = limit - .skip - .as_ref() - .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) - .transpose()? - .map(Box::new) - .map(fetch_rel::OffsetMode::OffsetExpr); - let count_mode = limit - .fetch - .as_ref() - .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) - .transpose()? - .map(Box::new) - .map(fetch_rel::CountMode::CountExpr); - Ok(Box::new(Rel { - rel_type: Some(RelType::Fetch(Box::new(FetchRel { - common: None, - input: Some(input), - offset_mode, - count_mode, - advanced_extension: None, - }))), - })) -} - -pub fn from_sort(producer: &mut impl SubstraitProducer, sort: &Sort) -> Result> { - let Sort { expr, input, fetch } = sort; - let sort_fields = expr - .iter() - .map(|e| substrait_sort_field(producer, e, input.schema())) - .collect::>>()?; - - let input = producer.handle_plan(input.as_ref())?; - - let sort_rel = Box::new(Rel { - rel_type: Some(RelType::Sort(Box::new(SortRel { - common: None, - input: Some(input), - sorts: sort_fields, - advanced_extension: None, - }))), - }); - - match fetch { - Some(amount) => { - let count_mode = - Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { - rex_type: Some(RexType::Literal(Literal { - nullable: false, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - literal_type: Some(LiteralType::I64(*amount as i64)), - })), - }))); - Ok(Box::new(Rel { - rel_type: Some(RelType::Fetch(Box::new(FetchRel { - common: None, - input: Some(sort_rel), - offset_mode: None, - count_mode, - advanced_extension: None, - }))), - })) - } - None => Ok(sort_rel), - } -} - -pub fn from_aggregate( - producer: &mut impl SubstraitProducer, - agg: &Aggregate, -) -> Result> { - let input = producer.handle_plan(agg.input.as_ref())?; - let (grouping_expressions, groupings) = - to_substrait_groupings(producer, &agg.group_expr, agg.input.schema())?; - let measures = agg - .aggr_expr - .iter() - .map(|e| to_substrait_agg_measure(producer, e, agg.input.schema())) - .collect::>>()?; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { - common: None, - input: Some(input), - grouping_expressions, - groupings, - measures, - advanced_extension: None, - }))), - })) -} - -pub fn from_distinct( - producer: &mut impl SubstraitProducer, - distinct: &Distinct, -) -> Result> { - match distinct { - Distinct::All(plan) => { - // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = producer.handle_plan(plan.as_ref())?; - // Get grouping keys from the input relation's number of output fields - let grouping = (0..plan.schema().fields().len()) - .map(substrait_field_ref) - .collect::>>()?; - - #[allow(deprecated)] - Ok(Box::new(Rel { - rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { - common: None, - input: 
Some(input), - grouping_expressions: vec![], - groupings: vec![Grouping { - grouping_expressions: grouping, - expression_references: vec![], - }], - measures: vec![], - advanced_extension: None, - }))), - })) - } - Distinct::On(_) => not_impl_err!("Cannot convert Distinct::On"), - } -} - -pub fn from_join(producer: &mut impl SubstraitProducer, join: &Join) -> Result> { - let left = producer.handle_plan(join.left.as_ref())?; - let right = producer.handle_plan(join.right.as_ref())?; - let join_type = to_substrait_jointype(join.join_type); - // we only support basic joins so return an error for anything not yet supported - match join.join_constraint { - JoinConstraint::On => {} - JoinConstraint::Using => return not_impl_err!("join constraint: `using`"), - } - let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); - - // convert filter if present - let join_filter = match &join.filter { - Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?), - None => None, - }; - - // map the left and right columns to binary expressions in the form `l = r` - // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` - let eq_op = if join.null_equals_null { - Operator::IsNotDistinctFrom - } else { - Operator::Eq - }; - let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?; - - // create conjunction between `join_on` and `join_filter` to embed all join conditions, - // whether equal or non-equal in a single expression - let join_expr = match &join_on { - Some(on_expr) => match &join_filter { - Some(filter) => Some(Box::new(make_binary_op_scalar_func( - producer, - on_expr, - filter, - Operator::And, - ))), - None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist - }, - None => match &join_filter { - Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist - None => None, - }, - }; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Join(Box::new(JoinRel { - common: None, - left: Some(left), - right: Some(right), - r#type: join_type as i32, - expression: join_expr, - post_join_filter: None, - advanced_extension: None, - }))), - })) -} - -pub fn from_subquery_alias( - producer: &mut impl SubstraitProducer, - alias: &SubqueryAlias, -) -> Result> { - // Do nothing if encounters SubqueryAlias - // since there is no corresponding relation type in Substrait - producer.handle_plan(alias.input.as_ref()) -} - -pub fn from_union( - producer: &mut impl SubstraitProducer, - union: &Union, -) -> Result> { - let input_rels = union - .inputs - .iter() - .map(|input| producer.handle_plan(input.as_ref())) - .collect::>>()? 
- .into_iter() - .map(|ptr| *ptr) - .collect(); - Ok(Box::new(Rel { - rel_type: Some(RelType::Set(SetRel { - common: None, - inputs: input_rels, - op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL - advanced_extension: None, - })), - })) -} - -pub fn from_window( - producer: &mut impl SubstraitProducer, - window: &Window, -) -> Result> { - let input = producer.handle_plan(window.input.as_ref())?; - - // create a field reference for each input field - let mut expressions = (0..window.input.schema().fields().len()) - .map(substrait_field_ref) - .collect::>>()?; - - // process and add each window function expression - for expr in &window.window_expr { - expressions.push(producer.handle_expr(expr, window.input.schema())?); - } - - let emit_kind = - create_project_remapping(expressions.len(), window.input.schema().fields().len()); - let common = RelCommon { - emit_kind: Some(emit_kind), - hint: None, - advanced_extension: None, - }; - let project_rel = Box::new(ProjectRel { - common: Some(common), - input: Some(input), - expressions, - advanced_extension: None, - }); - - Ok(Box::new(Rel { - rel_type: Some(RelType::Project(project_rel)), - })) -} - -pub fn from_repartition( - producer: &mut impl SubstraitProducer, - repartition: &Repartition, -) -> Result> { - let input = producer.handle_plan(repartition.input.as_ref())?; - let partition_count = match repartition.partitioning_scheme { - Partitioning::RoundRobinBatch(num) => num, - Partitioning::Hash(_, num) => num, - Partitioning::DistributeBy(_) => { - return not_impl_err!( - "Physical plan does not support DistributeBy partitioning" - ) - } - }; - // ref: https://substrait.io/relations/physical_relations/#exchange-types - let exchange_kind = match &repartition.partitioning_scheme { - Partitioning::RoundRobinBatch(_) => { - ExchangeKind::RoundRobin(RoundRobin::default()) - } - Partitioning::Hash(exprs, _) => { - let fields = exprs - .iter() - .map(|e| try_to_substrait_field_reference(e, repartition.input.schema())) - .collect::>>()?; - ExchangeKind::ScatterByFields(ScatterFields { fields }) - } - Partitioning::DistributeBy(_) => { - return not_impl_err!( - "Physical plan does not support DistributeBy partitioning" - ) - } - }; - let exchange_rel = ExchangeRel { - common: None, - input: Some(input), - exchange_kind: Some(exchange_kind), - advanced_extension: None, - partition_count: partition_count as i32, - targets: vec![], - }; - Ok(Box::new(Rel { - rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), - })) -} - -/// By default, a Substrait Project outputs all input fields followed by all expressions. -/// A DataFusion Projection only outputs expressions. In order to keep the Substrait -/// plan consistent with DataFusion, we must apply an output mapping that skips the input -/// fields so that the Substrait Project will only output the expression fields. -fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind { - let expression_field_start = input_field_count; - let expression_field_end = expression_field_start + expr_count; - let output_mapping = (expression_field_start..expression_field_end) - .map(|i| i as i32) - .collect(); - Emit(rel_common::Emit { output_mapping }) -} - -// Substrait wants a list of all field names, including nested fields from structs, -// also from within e.g. lists and maps. However, it does not want the list and map field names -// themselves - only proper structs fields are considered to have useful names. 
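To make the naming rule described above concrete: for the nested field sketched below (names are arbitrary), the flattened Substrait name list would be `["b", "c", "d", "e"]`; the list's inner `item` field contributes no name of its own, but the struct fields nested inside it do.

```rust
use std::sync::Arc;
use datafusion::arrow::datatypes::{DataType, Field};

fn main() {
    // A struct column "b" containing "c" and a list "d" of structs with "e".
    let _nested = Field::new_struct(
        "b",
        vec![
            Field::new("c", DataType::Int64, true),
            Field::new_list(
                "d",
                Arc::new(Field::new_struct(
                    "item",
                    vec![Field::new("e", DataType::Int64, true)],
                    true,
                )),
                true,
            ),
        ],
        true,
    );
}
```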
-fn flatten_names(field: &Field, skip_self: bool, names: &mut Vec) -> Result<()> { - if !skip_self { - names.push(field.name().to_string()); - } - match field.data_type() { - DataType::Struct(fields) => { - for field in fields { - flatten_names(field, false, names)?; - } - Ok(()) - } - DataType::List(l) => flatten_names(l, true, names), - DataType::LargeList(l) => flatten_names(l, true, names), - DataType::Map(m, _) => match m.data_type() { - DataType::Struct(key_and_value) if key_and_value.len() == 2 => { - flatten_names(&key_and_value[0], true, names)?; - flatten_names(&key_and_value[1], true, names) - } - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), - }, - _ => Ok(()), - }?; - Ok(()) -} - -fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { - let mut names = Vec::with_capacity(schema.fields().len()); - for field in schema.fields() { - flatten_names(field, false, &mut names)?; - } - - let field_types = r#type::Struct { - types: schema - .fields() - .iter() - .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) - .collect::>()?, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Required as i32, - }; - - Ok(NamedStruct { - names, - r#struct: Some(field_types), - }) -} - -fn to_substrait_join_expr( - producer: &mut impl SubstraitProducer, - join_conditions: &Vec<(Expr, Expr)>, - eq_op: Operator, - join_schema: &DFSchemaRef, -) -> Result> { - // Only support AND conjunction for each binary expression in join conditions - let mut exprs: Vec = vec![]; - for (left, right) in join_conditions { - let l = producer.handle_expr(left, join_schema)?; - let r = producer.handle_expr(right, join_schema)?; - // AND with existing expression - exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); - } - - let join_expr: Option = - exprs.into_iter().reduce(|acc: Expression, e: Expression| { - make_binary_op_scalar_func(producer, &acc, &e, Operator::And) - }); - Ok(join_expr) -} - -fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { - match join_type { - JoinType::Inner => join_rel::JoinType::Inner, - JoinType::Left => join_rel::JoinType::Left, - JoinType::Right => join_rel::JoinType::Right, - JoinType::Full => join_rel::JoinType::Outer, - JoinType::LeftAnti => join_rel::JoinType::LeftAnti, - JoinType::LeftSemi => join_rel::JoinType::LeftSemi, - JoinType::LeftMark => join_rel::JoinType::LeftMark, - JoinType::RightAnti | JoinType::RightSemi => { - unimplemented!() - } - } -} - -pub fn operator_to_name(op: Operator) -> &'static str { - match op { - Operator::Eq => "equal", - Operator::NotEq => "not_equal", - Operator::Lt => "lt", - Operator::LtEq => "lte", - Operator::Gt => "gt", - Operator::GtEq => "gte", - Operator::Plus => "add", - Operator::Minus => "subtract", - Operator::Multiply => "multiply", - Operator::Divide => "divide", - Operator::Modulo => "modulus", - Operator::And => "and", - Operator::Or => "or", - Operator::IsDistinctFrom => "is_distinct_from", - Operator::IsNotDistinctFrom => "is_not_distinct_from", - Operator::RegexMatch => "regex_match", - Operator::RegexIMatch => "regex_imatch", - Operator::RegexNotMatch => "regex_not_match", - Operator::RegexNotIMatch => "regex_not_imatch", - Operator::LikeMatch => "like_match", - Operator::ILikeMatch => "like_imatch", - Operator::NotLikeMatch => "like_not_match", - Operator::NotILikeMatch => "like_not_imatch", - Operator::BitwiseAnd => "bitwise_and", - Operator::BitwiseOr => "bitwise_or", - Operator::StringConcat => "str_concat", - 
Operator::AtArrow => "at_arrow", - Operator::ArrowAt => "arrow_at", - Operator::Arrow => "arrow", - Operator::LongArrow => "long_arrow", - Operator::HashArrow => "hash_arrow", - Operator::HashLongArrow => "hash_long_arrow", - Operator::AtAt => "at_at", - Operator::IntegerDivide => "integer_divide", - Operator::HashMinus => "hash_minus", - Operator::AtQuestion => "at_question", - Operator::Question => "question", - Operator::QuestionAnd => "question_and", - Operator::QuestionPipe => "question_pipe", - Operator::BitwiseXor => "bitwise_xor", - Operator::BitwiseShiftRight => "bitwise_shift_right", - Operator::BitwiseShiftLeft => "bitwise_shift_left", - } -} - -pub fn parse_flat_grouping_exprs( - producer: &mut impl SubstraitProducer, - exprs: &[Expr], - schema: &DFSchemaRef, - ref_group_exprs: &mut Vec, -) -> Result { - let mut expression_references = vec![]; - let mut grouping_expressions = vec![]; - - for e in exprs { - let rex = producer.handle_expr(e, schema)?; - grouping_expressions.push(rex.clone()); - ref_group_exprs.push(rex); - expression_references.push((ref_group_exprs.len() - 1) as u32); - } - #[allow(deprecated)] - Ok(Grouping { - grouping_expressions, - expression_references, - }) -} - -pub fn to_substrait_groupings( - producer: &mut impl SubstraitProducer, - exprs: &[Expr], - schema: &DFSchemaRef, -) -> Result<(Vec, Vec)> { - let mut ref_group_exprs = vec![]; - let groupings = match exprs.len() { - 1 => match &exprs[0] { - Expr::GroupingSet(gs) => match gs { - GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented( - "GroupingSet CUBE is not yet supported".to_string(), - )), - GroupingSet::GroupingSets(sets) => Ok(sets - .iter() - .map(|set| { - parse_flat_grouping_exprs( - producer, - set, - schema, - &mut ref_group_exprs, - ) - }) - .collect::>>()?), - GroupingSet::Rollup(set) => { - let mut sets: Vec> = vec![vec![]]; - for i in 0..set.len() { - sets.push(set[..=i].to_vec()); - } - Ok(sets - .iter() - .rev() - .map(|set| { - parse_flat_grouping_exprs( - producer, - set, - schema, - &mut ref_group_exprs, - ) - }) - .collect::>>()?) - } - }, - _ => Ok(vec![parse_flat_grouping_exprs( - producer, - exprs, - schema, - &mut ref_group_exprs, - )?]), - }, - _ => Ok(vec![parse_flat_grouping_exprs( - producer, - exprs, - schema, - &mut ref_group_exprs, - )?]), - }?; - Ok((ref_group_exprs, groupings)) -} - -pub fn from_aggregate_function( - producer: &mut impl SubstraitProducer, - agg_fn: &expr::AggregateFunction, - schema: &DFSchemaRef, -) -> Result { - let expr::AggregateFunction { - func, - params: - AggregateFunctionParams { - args, - distinct, - filter, - order_by, - null_treatment: _null_treatment, - }, - } = agg_fn; - let sorts = if let Some(order_by) = order_by { - order_by - .iter() - .map(|expr| to_substrait_sort_field(producer, expr, schema)) - .collect::>>()? 
- } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), - }); - } - let function_anchor = producer.register_function(func.name().to_string()); - #[allow(deprecated)] - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(producer.handle_expr(f, schema)?), - None => None, - }, - }) -} - -pub fn to_substrait_agg_measure( - producer: &mut impl SubstraitProducer, - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - match expr { - Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema), - Expr::Alias(Alias { expr, .. }) => { - to_substrait_agg_measure(producer, expr, schema) - } - _ => internal_err!( - "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", - expr, - expr.variant_name() - ), - } -} - -/// Converts sort expression to corresponding substrait `SortField` -fn to_substrait_sort_field( - producer: &mut impl SubstraitProducer, - sort: &expr::Sort, - schema: &DFSchemaRef, -) -> Result { - let sort_kind = match (sort.asc, sort.nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(producer.handle_expr(&sort.expr, schema)?), - sort_kind: Some(SortKind::Direction(sort_kind.into())), - }) -} - -/// Return Substrait scalar function with two arguments -pub fn make_binary_op_scalar_func( - producer: &mut impl SubstraitProducer, - lhs: &Expression, - rhs: &Expression, - op: Operator, -) -> Expression { - let function_anchor = producer.register_function(operator_to_name(op).to_string()); - #[allow(deprecated)] - Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![ - FunctionArgument { - arg_type: Some(ArgType::Value(lhs.clone())), - }, - FunctionArgument { - arg_type: Some(ArgType::Value(rhs.clone())), - }, - ], - output_type: None, - args: vec![], - options: vec![], - })), - } -} - -/// Convert DataFusion Expr to Substrait Rex -/// -/// # Arguments -/// * `producer` - SubstraitProducer implementation which the handles the actual conversion -/// * `expr` - DataFusion expression to convert into a Substrait expression -/// * `schema` - DataFusion input schema for looking up columns -pub fn to_substrait_rex( - producer: &mut impl SubstraitProducer, - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - match expr { - Expr::Alias(expr) => producer.handle_alias(expr, schema), - Expr::Column(expr) => producer.handle_column(expr, schema), - Expr::ScalarVariable(_, _) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } - Expr::Literal(expr) => producer.handle_literal(expr), - Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema), - Expr::Like(expr) => producer.handle_like(expr, schema), - Expr::SimilarTo(_) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::Not(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotNull(_) => 
producer.handle_unary_expr(expr, schema), - Expr::IsNull(_) => producer.handle_unary_expr(expr, schema), - Expr::IsTrue(_) => producer.handle_unary_expr(expr, schema), - Expr::IsFalse(_) => producer.handle_unary_expr(expr, schema), - Expr::IsUnknown(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotTrue(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotFalse(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotUnknown(_) => producer.handle_unary_expr(expr, schema), - Expr::Negative(_) => producer.handle_unary_expr(expr, schema), - Expr::Between(expr) => producer.handle_between(expr, schema), - Expr::Case(expr) => producer.handle_case(expr, schema), - Expr::Cast(expr) => producer.handle_cast(expr, schema), - Expr::TryCast(expr) => producer.handle_try_cast(expr, schema), - Expr::ScalarFunction(expr) => producer.handle_scalar_function(expr, schema), - Expr::AggregateFunction(_) => { - internal_err!( - "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" - ) - } - Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema), - Expr::InList(expr) => producer.handle_in_list(expr, schema), - Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), - Expr::ScalarSubquery(expr) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } - #[expect(deprecated)] - Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::OuterReferenceColumn(_, _) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } - Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - } -} - -pub fn from_in_list( - producer: &mut impl SubstraitProducer, - in_list: &InList, - schema: &DFSchemaRef, -) -> Result { - let InList { - expr, - list, - negated, - } = in_list; - let substrait_list = list - .iter() - .map(|x| producer.handle_expr(x, schema)) - .collect::>>()?; - let substrait_expr = producer.handle_expr(expr, schema)?; - - let substrait_or_list = Expression { - rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { - value: Some(Box::new(substrait_expr)), - options: substrait_list, - }))), - }; - - if *negated { - let function_anchor = producer.register_function("not".to_string()); - - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_or_list)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_or_list) - } -} - -pub fn from_scalar_function( - producer: &mut impl SubstraitProducer, - fun: &expr::ScalarFunction, - schema: &DFSchemaRef, -) -> Result { - let mut arguments: Vec = vec![]; - for arg in &fun.args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), - }); - } - - let function_anchor = producer.register_function(fun.name().to_string()); - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - options: vec![], - args: vec![], - })), - }) -} - -pub fn from_between( - producer: &mut impl SubstraitProducer, - between: &Between, - 
schema: &DFSchemaRef, -) -> Result { - let Between { - expr, - negated, - low, - high, - } = between; - if *negated { - // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; - let substrait_low = producer.handle_expr(low.as_ref(), schema)?; - let substrait_high = producer.handle_expr(high.as_ref(), schema)?; - - let l_expr = make_binary_op_scalar_func( - producer, - &substrait_expr, - &substrait_low, - Operator::Lt, - ); - let r_expr = make_binary_op_scalar_func( - producer, - &substrait_high, - &substrait_expr, - Operator::Lt, - ); - - Ok(make_binary_op_scalar_func( - producer, - &l_expr, - &r_expr, - Operator::Or, - )) - } else { - // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; - let substrait_low = producer.handle_expr(low.as_ref(), schema)?; - let substrait_high = producer.handle_expr(high.as_ref(), schema)?; - - let l_expr = make_binary_op_scalar_func( - producer, - &substrait_low, - &substrait_expr, - Operator::LtEq, - ); - let r_expr = make_binary_op_scalar_func( - producer, - &substrait_expr, - &substrait_high, - Operator::LtEq, - ); - - Ok(make_binary_op_scalar_func( - producer, - &l_expr, - &r_expr, - Operator::And, - )) - } -} -pub fn from_column(col: &Column, schema: &DFSchemaRef) -> Result { - let index = schema.index_of_column(col)?; - substrait_field_ref(index) -} - -pub fn from_binary_expr( - producer: &mut impl SubstraitProducer, - expr: &BinaryExpr, - schema: &DFSchemaRef, -) -> Result { - let BinaryExpr { left, op, right } = expr; - let l = producer.handle_expr(left, schema)?; - let r = producer.handle_expr(right, schema)?; - Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) -} -pub fn from_case( - producer: &mut impl SubstraitProducer, - case: &Case, - schema: &DFSchemaRef, -) -> Result { - let Case { - expr, - when_then_expr, - else_expr, - } = case; - let mut ifs: Vec = vec![]; - // Parse base - if let Some(e) = expr { - // Base expression exists - ifs.push(IfClause { - r#if: Some(producer.handle_expr(e, schema)?), - then: None, - }); - } - // Parse `when`s - for (r#if, then) in when_then_expr { - ifs.push(IfClause { - r#if: Some(producer.handle_expr(r#if, schema)?), - then: Some(producer.handle_expr(then, schema)?), - }); - } - - // Parse outer `else` - let r#else: Option> = match else_expr { - Some(e) => Some(Box::new(producer.handle_expr(e, schema)?)), - None => None, - }; - - Ok(Expression { - rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), - }) -} - -pub fn from_cast( - producer: &mut impl SubstraitProducer, - cast: &Cast, - schema: &DFSchemaRef, -) -> Result { - let Cast { expr, data_type } = cast; - Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(producer.handle_expr(expr, schema)?)), - failure_behavior: FailureBehavior::ThrowException.into(), - }, - ))), - }) -} - -pub fn from_try_cast( - producer: &mut impl SubstraitProducer, - cast: &TryCast, - schema: &DFSchemaRef, -) -> Result { - let TryCast { expr, data_type } = cast; - Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(producer.handle_expr(expr, schema)?)), - failure_behavior: FailureBehavior::ReturnNull.into(), - }, - ))), - }) -} - 
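The BETWEEN handling above rewrites the predicate rather than emitting a dedicated Substrait construct. A minimal sketch of that desugaring, evaluated on plain integers purely to illustrate the equivalence (not code from the crate):

    fn between(x: i64, low: i64, high: i64, negated: bool) -> bool {
        if negated {
            // `x NOT BETWEEN low AND high`  ==>  x < low OR high < x
            x < low || high < x
        } else {
            // `x BETWEEN low AND high`      ==>  low <= x AND x <= high
            low <= x && x <= high
        }
    }

    fn main() {
        assert!(between(3, 1, 5, false));
        assert!(!between(7, 1, 5, false));
        assert!(between(7, 1, 5, true));
    }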
-pub fn from_literal( - producer: &mut impl SubstraitProducer, - value: &ScalarValue, -) -> Result { - to_substrait_literal_expr(producer, value) -} - -pub fn from_alias( - producer: &mut impl SubstraitProducer, - alias: &Alias, - schema: &DFSchemaRef, -) -> Result { - producer.handle_expr(alias.expr.as_ref(), schema) -} - -pub fn from_window_function( - producer: &mut impl SubstraitProducer, - window_fn: &WindowFunction, - schema: &DFSchemaRef, -) -> Result { - let WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }, - } = window_fn; - // function reference - let function_anchor = producer.register_function(fun.to_string()); - // arguments - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), - }); - } - // partition by expressions - let partition_by = partition_by - .iter() - .map(|e| producer.handle_expr(e, schema)) - .collect::>>()?; - // order by expressions - let order_by = order_by - .iter() - .map(|e| substrait_sort_field(producer, e, schema)) - .collect::>>()?; - // window frame - let bounds = to_substrait_bounds(window_frame)?; - let bound_type = to_substrait_bound_type(window_frame)?; - Ok(make_substrait_window_function( - function_anchor, - arguments, - partition_by, - order_by, - bounds, - bound_type, - )) -} - -pub fn from_like( - producer: &mut impl SubstraitProducer, - like: &Like, - schema: &DFSchemaRef, -) -> Result { - let Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - } = like; - make_substrait_like_expr( - producer, - *case_insensitive, - *negated, - expr, - pattern, - *escape_char, - schema, - ) -} - -pub fn from_in_subquery( - producer: &mut impl SubstraitProducer, - subquery: &InSubquery, - schema: &DFSchemaRef, -) -> Result { - let InSubquery { - expr, - subquery, - negated, - } = subquery; - let substrait_expr = producer.handle_expr(expr, schema)?; - - let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?; - - let substrait_subquery = Expression { - rex_type: Some(RexType::Subquery(Box::new( - substrait::proto::expression::Subquery { - subquery_type: Some( - substrait::proto::expression::subquery::SubqueryType::InPredicate( - Box::new(InPredicate { - needles: (vec![substrait_expr]), - haystack: Some(subquery_plan), - }), - ), - ), - }, - ))), - }; - if *negated { - let function_anchor = producer.register_function("not".to_string()); - - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_subquery)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_subquery) - } -} - -pub fn from_unary_expr( - producer: &mut impl SubstraitProducer, - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - let (fn_name, arg) = match expr { - Expr::Not(arg) => ("not", arg), - Expr::IsNull(arg) => ("is_null", arg), - Expr::IsNotNull(arg) => ("is_not_null", arg), - Expr::IsTrue(arg) => ("is_true", arg), - Expr::IsFalse(arg) => ("is_false", arg), - Expr::IsUnknown(arg) => ("is_unknown", arg), - Expr::IsNotTrue(arg) => ("is_not_true", arg), - Expr::IsNotFalse(arg) => ("is_not_false", arg), - Expr::IsNotUnknown(arg) => ("is_not_unknown", arg), - Expr::Negative(arg) => ("negate", arg), - expr => not_impl_err!("Unsupported expression: {expr:?}")?, - }; 
- to_substrait_unary_scalar_fn(producer, fn_name, arg, schema) -} - -fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { - let nullability = if nullable { - r#type::Nullability::Nullable as i32 - } else { - r#type::Nullability::Required as i32 - }; - match dt { - DataType::Null => internal_err!("Null cast is not valid"), - DataType::Boolean => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Bool(r#type::Boolean { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int16 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt16 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - // Float16 is not supported in Substrait - DataType::Float32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Fp32(r#type::Fp32 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Float64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Fp64(r#type::Fp64 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Timestamp(unit, tz) => { - let precision = match unit { - TimeUnit::Second => 0, - TimeUnit::Millisecond => 3, - TimeUnit::Microsecond => 6, - TimeUnit::Nanosecond => 9, - }; - let kind = match tz { - None => r#type::Kind::PrecisionTimestamp(r#type::PrecisionTimestamp { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision, - }), - Some(_) => { - // If timezone is present, no matter what the actual tz value is, it indicates the - // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. - // As the timezone is lost, this conversion may be lossy for downstream use of the value. 
- r#type::Kind::PrecisionTimestampTz(r#type::PrecisionTimestampTz { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision, - }) - } - }; - Ok(substrait::proto::Type { kind: Some(kind) }) - } - DataType::Date32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_32_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Date64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_64_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Interval(interval_unit) => { - match interval_unit { - IntervalUnit::YearMonth => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::IntervalYear(r#type::IntervalYear { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - IntervalUnit::DayTime => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision: Some(3), // DayTime precision is always milliseconds - })), - }), - IntervalUnit::MonthDayNano => { - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::IntervalCompound( - r#type::IntervalCompound { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision: 9, // nanos - }, - )), - }) - } - } - } - DataType::Binary => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary { - length: *length, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::LargeBinary => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::BinaryView => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Utf8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::LargeUtf8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Utf8View => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::List(inner) => { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::List(Box::new(r#type::List { - r#type: Some(Box::new(inner_type)), - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - }))), - }) - } - DataType::LargeList(inner) => { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::List(Box::new(r#type::List { - r#type: Some(Box::new(inner_type)), - type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, - nullability, - }))), - }) - } - DataType::Map(inner, _) => match inner.data_type() { - DataType::Struct(key_and_value) if 
key_and_value.len() == 2 => { - let key_type = to_substrait_type( - key_and_value[0].data_type(), - key_and_value[0].is_nullable(), - )?; - let value_type = to_substrait_type( - key_and_value[1].data_type(), - key_and_value[1].is_nullable(), - )?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Map(Box::new(r#type::Map { - key: Some(Box::new(key_type)), - value: Some(Box::new(value_type)), - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - }))), - }) - } - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), - }, - DataType::Struct(fields) => { - let field_types = fields - .iter() - .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) - .collect::>>()?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Struct(r#type::Struct { - types: field_types, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }) - } - DataType::Decimal128(p, s) => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Decimal(r#type::Decimal { - type_variation_reference: DECIMAL_128_TYPE_VARIATION_REF, - nullability, - scale: *s as i32, - precision: *p as i32, - })), - }), - DataType::Decimal256(p, s) => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Decimal(r#type::Decimal { - type_variation_reference: DECIMAL_256_TYPE_VARIATION_REF, - nullability, - scale: *s as i32, - precision: *p as i32, - })), - }), - _ => not_impl_err!("Unsupported cast type: {dt:?}"), - } -} - -fn make_substrait_window_function( - function_reference: u32, - arguments: Vec, - partitions: Vec, - sorts: Vec, - bounds: (Bound, Bound), - bounds_type: BoundsType, -) -> Expression { - #[allow(deprecated)] - Expression { - rex_type: Some(RexType::WindowFunction(SubstraitWindowFunction { - function_reference, - arguments, - partitions, - sorts, - options: vec![], - output_type: None, - phase: 0, // default to AGGREGATION_PHASE_UNSPECIFIED - invocation: 0, // TODO: fix - lower_bound: Some(bounds.0), - upper_bound: Some(bounds.1), - args: vec![], - bounds_type: bounds_type as i32, - })), - } -} - -fn make_substrait_like_expr( - producer: &mut impl SubstraitProducer, - ignore_case: bool, - negated: bool, - expr: &Expr, - pattern: &Expr, - escape_char: Option, - schema: &DFSchemaRef, -) -> Result { - let function_anchor = if ignore_case { - producer.register_function("ilike".to_string()) - } else { - producer.register_function("like".to_string()) - }; - let expr = producer.handle_expr(expr, schema)?; - let pattern = producer.handle_expr(pattern, schema)?; - let escape_char = to_substrait_literal_expr( - producer, - &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), - )?; - let arguments = vec![ - FunctionArgument { - arg_type: Some(ArgType::Value(expr)), - }, - FunctionArgument { - arg_type: Some(ArgType::Value(pattern)), - }, - FunctionArgument { - arg_type: Some(ArgType::Value(escape_char)), - }, - ]; - - #[allow(deprecated)] - let substrait_like = Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }; - - if negated { - let function_anchor = producer.register_function("not".to_string()); - - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_like)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else 
{ - Ok(substrait_like) - } -} - -fn to_substrait_bound_offset(value: &ScalarValue) -> Option { - match value { - ScalarValue::UInt8(Some(v)) => Some(*v as i64), - ScalarValue::UInt16(Some(v)) => Some(*v as i64), - ScalarValue::UInt32(Some(v)) => Some(*v as i64), - ScalarValue::UInt64(Some(v)) => Some(*v as i64), - ScalarValue::Int8(Some(v)) => Some(*v as i64), - ScalarValue::Int16(Some(v)) => Some(*v as i64), - ScalarValue::Int32(Some(v)) => Some(*v as i64), - ScalarValue::Int64(Some(v)) => Some(*v), - _ => None, - } -} - -fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { - match bound { - WindowFrameBound::CurrentRow => Bound { - kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})), - }, - WindowFrameBound::Preceding(s) => match to_substrait_bound_offset(s) { - Some(offset) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })), - }, - None => Bound { - kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), - }, - }, - WindowFrameBound::Following(s) => match to_substrait_bound_offset(s) { - Some(offset) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { offset })), - }, - None => Bound { - kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), - }, - }, - } -} - -fn to_substrait_bound_type(window_frame: &WindowFrame) -> Result { - match window_frame.units { - WindowFrameUnits::Rows => Ok(BoundsType::Rows), // ROWS - WindowFrameUnits::Range => Ok(BoundsType::Range), // RANGE - // TODO: Support GROUPS - unit => not_impl_err!("Unsupported window frame unit: {unit:?}"), - } -} - -fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { - Ok(( - to_substrait_bound(&window_frame.start_bound), - to_substrait_bound(&window_frame.end_bound), - )) -} - -fn to_substrait_literal( - producer: &mut impl SubstraitProducer, - value: &ScalarValue, -) -> Result { - if value.is_null() { - return Ok(Literal { - nullable: true, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - literal_type: Some(LiteralType::Null(to_substrait_type( - &value.data_type(), - true, - )?)), - }); - } - let (literal_type, type_variation_reference) = match value { - ScalarValue::Boolean(Some(b)) => { - (LiteralType::Boolean(*b), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::Int8(Some(n)) => { - (LiteralType::I8(*n as i32), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::UInt8(Some(n)) => ( - LiteralType::I8(*n as i32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Int16(Some(n)) => { - (LiteralType::I16(*n as i32), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::UInt16(Some(n)) => ( - LiteralType::I16(*n as i32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_VARIATION_REF), - ScalarValue::UInt32(Some(n)) => ( - LiteralType::I32(*n as i32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_VARIATION_REF), - ScalarValue::UInt64(Some(n)) => ( - LiteralType::I64(*n as i64), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Float32(Some(f)) => { - (LiteralType::Fp32(*f), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::Float64(Some(f)) => { - (LiteralType::Fp64(*f), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::TimestampSecond(Some(t), None) => ( - LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 0, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMillisecond(Some(t), None) => ( - 
LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 3, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMicrosecond(Some(t), None) => ( - LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 6, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampNanosecond(Some(t), None) => ( - LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 9, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - // If timezone is present, no matter what the actual tz value is, it indicates the - // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. - // As the timezone is lost, this conversion may be lossy for downstream use of the value. - ScalarValue::TimestampSecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 0, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMillisecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 3, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMicrosecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 6, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampNanosecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 9, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::Date32(Some(d)) => { - (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF) - } - // Date64 literal is not supported in Substrait - ScalarValue::IntervalYearMonth(Some(i)) => ( - LiteralType::IntervalYearToMonth(IntervalYearToMonth { - // DF only tracks total months, but there should always be 12 months in a year - years: *i / 12, - months: *i % 12, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::IntervalMonthDayNano(Some(i)) => ( - LiteralType::IntervalCompound(IntervalCompound { - interval_year_to_month: Some(IntervalYearToMonth { - years: i.months / 12, - months: i.months % 12, - }), - interval_day_to_second: Some(IntervalDayToSecond { - days: i.days, - seconds: (i.nanoseconds / NANOSECONDS) as i32, - subseconds: i.nanoseconds % NANOSECONDS, - precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds - }), - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::IntervalDayTime(Some(i)) => ( - LiteralType::IntervalDayToSecond(IntervalDayToSecond { - days: i.days, - seconds: i.milliseconds / 1000, - subseconds: (i.milliseconds % 1000) as i64, - precision_mode: Some(PrecisionMode::Precision(3)), // 3 for milliseconds - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::Binary(Some(b)) => ( - LiteralType::Binary(b.clone()), - DEFAULT_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::LargeBinary(Some(b)) => ( - LiteralType::Binary(b.clone()), - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::BinaryView(Some(b)) => ( - LiteralType::Binary(b.clone()), - VIEW_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::FixedSizeBinary(_, Some(b)) => ( - LiteralType::FixedBinary(b.clone()), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::Utf8(Some(s)) => ( - LiteralType::String(s.clone()), - DEFAULT_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::LargeUtf8(Some(s)) => ( - LiteralType::String(s.clone()), - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::Utf8View(Some(s)) => ( - LiteralType::String(s.clone()), - VIEW_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::Decimal128(v, p, s) if 
v.is_some() => ( - LiteralType::Decimal(Decimal { - value: v.unwrap().to_le_bytes().to_vec(), - precision: *p as i32, - scale: *s as i32, - }), - DECIMAL_128_TYPE_VARIATION_REF, - ), - ScalarValue::List(l) => ( - convert_array_to_literal_list(producer, l)?, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::LargeList(l) => ( - convert_array_to_literal_list(producer, l)?, - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::Map(m) => { - let map = if m.is_empty() || m.value(0).is_empty() { - let mt = to_substrait_type(m.data_type(), m.is_nullable())?; - let mt = match mt { - substrait::proto::Type { - kind: Some(r#type::Kind::Map(mt)), - } => Ok(mt.as_ref().to_owned()), - _ => exec_err!("Unexpected type for a map: {mt:?}"), - }?; - LiteralType::EmptyMap(mt) - } else { - let keys = (0..m.keys().len()) - .map(|i| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(&m.keys(), i)?, - ) - }) - .collect::>>()?; - let values = (0..m.values().len()) - .map(|i| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(&m.values(), i)?, - ) - }) - .collect::>>()?; - - let key_values = keys - .into_iter() - .zip(values.into_iter()) - .map(|(k, v)| { - Ok(KeyValue { - key: Some(k), - value: Some(v), - }) - }) - .collect::>>()?; - LiteralType::Map(Map { key_values }) - }; - (map, DEFAULT_CONTAINER_TYPE_VARIATION_REF) - } - ScalarValue::Struct(s) => ( - LiteralType::Struct(Struct { - fields: s - .columns() - .iter() - .map(|col| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(col, 0)?, - ) - }) - .collect::>>()?, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - _ => ( - not_impl_err!("Unsupported literal: {value:?}")?, - DEFAULT_TYPE_VARIATION_REF, - ), - }; - - Ok(Literal { - nullable: false, - type_variation_reference, - literal_type: Some(literal_type), - }) -} - -fn convert_array_to_literal_list( - producer: &mut impl SubstraitProducer, - array: &GenericListArray, -) -> Result { - assert_eq!(array.len(), 1); - let nested_array = array.value(0); - - let values = (0..nested_array.len()) - .map(|i| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(&nested_array, i)?, - ) - }) - .collect::>>()?; - - if values.is_empty() { - let lt = match to_substrait_type(array.data_type(), array.is_nullable())? { - substrait::proto::Type { - kind: Some(r#type::Kind::List(lt)), - } => lt.as_ref().to_owned(), - _ => unreachable!(), - }; - Ok(LiteralType::EmptyList(lt)) - } else { - Ok(LiteralType::List(List { values })) - } -} - -fn to_substrait_literal_expr( - producer: &mut impl SubstraitProducer, - value: &ScalarValue, -) -> Result { - let literal = to_substrait_literal(producer, value)?; - Ok(Expression { - rex_type: Some(RexType::Literal(literal)), - }) -} - -/// Util to generate substrait [RexType::ScalarFunction] with one argument -fn to_substrait_unary_scalar_fn( - producer: &mut impl SubstraitProducer, - fn_name: &str, - arg: &Expr, - schema: &DFSchemaRef, -) -> Result { - let function_anchor = producer.register_function(fn_name.to_string()); - let substrait_expr = producer.handle_expr(arg, schema)?; - - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_expr)), - }], - output_type: None, - options: vec![], - ..Default::default() - })), - }) -} - -/// Try to convert an [Expr] to a [FieldReference]. -/// Returns `Err` if the [Expr] is not a [Expr::Column]. 
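The interval literals above are split with simple integer arithmetic: DataFusion tracks total months, total milliseconds, or total nanoseconds, while Substrait wants year/month and second/subsecond components. A minimal sketch of that split (helper names are illustrative, not from the crate):

    fn split_year_month(total_months: i32) -> (i32, i32) {
        // whole years, then the leftover months
        (total_months / 12, total_months % 12)
    }

    fn split_day_time_ms(milliseconds: i32) -> (i32, i64) {
        // whole seconds, then the leftover milliseconds as the "subseconds" value
        (milliseconds / 1000, (milliseconds % 1000) as i64)
    }

    fn main() {
        assert_eq!(split_year_month(17), (1, 5));            // 17 months = 1 year, 5 months
        assert_eq!(split_day_time_ms(123_456), (123, 456));  // 123.456 seconds
    }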
-fn try_to_substrait_field_reference( - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - match expr { - Expr::Column(col) => { - let index = schema.index_of_column(col)?; - Ok(FieldReference { - reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { - reference_type: Some(reference_segment::ReferenceType::StructField( - Box::new(reference_segment::StructField { - field: index as i32, - child: None, - }), - )), - })), - root_type: Some(RootType::RootReference(RootReference {})), - }) - } - _ => substrait_err!("Expect a `Column` expr, but found {expr:?}"), - } -} - -fn substrait_sort_field( - producer: &mut impl SubstraitProducer, - sort: &SortExpr, - schema: &DFSchemaRef, -) -> Result { - let SortExpr { - expr, - asc, - nulls_first, - } = sort; - let e = producer.handle_expr(expr, schema)?; - let d = match (asc, nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(e), - sort_kind: Some(SortKind::Direction(d as i32)), - }) -} - -fn substrait_field_ref(index: usize) -> Result { - Ok(Expression { - rex_type: Some(RexType::Selection(Box::new(FieldReference { - reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { - reference_type: Some(reference_segment::ReferenceType::StructField( - Box::new(reference_segment::StructField { - field: index as i32, - child: None, - }), - )), - })), - root_type: Some(RootType::RootReference(RootReference {})), - }))), - }) -} - -#[cfg(test)] -mod test { - use super::*; - use crate::logical_plan::consumer::{ - from_substrait_extended_expr, from_substrait_literal_without_names, - from_substrait_named_struct, from_substrait_type_without_names, - DefaultSubstraitConsumer, - }; - use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; - use datafusion::arrow; - use datafusion::arrow::array::{ - GenericListArray, Int64Builder, MapBuilder, StringBuilder, - }; - use datafusion::arrow::datatypes::{Field, Fields, Schema}; - use datafusion::common::scalar::ScalarStructBuilder; - use datafusion::common::DFSchema; - use datafusion::execution::{SessionState, SessionStateBuilder}; - use datafusion::prelude::SessionContext; - use std::sync::LazyLock; - - static TEST_SESSION_STATE: LazyLock = - LazyLock::new(|| SessionContext::default().state()); - static TEST_EXTENSIONS: LazyLock = LazyLock::new(Extensions::default); - fn test_consumer() -> DefaultSubstraitConsumer<'static> { - let extensions = &TEST_EXTENSIONS; - let state = &TEST_SESSION_STATE; - DefaultSubstraitConsumer::new(extensions, state) - } - - #[test] - fn round_trip_literals() -> Result<()> { - round_trip_literal(ScalarValue::Boolean(None))?; - round_trip_literal(ScalarValue::Boolean(Some(true)))?; - round_trip_literal(ScalarValue::Boolean(Some(false)))?; - - round_trip_literal(ScalarValue::Int8(None))?; - round_trip_literal(ScalarValue::Int8(Some(i8::MIN)))?; - round_trip_literal(ScalarValue::Int8(Some(i8::MAX)))?; - round_trip_literal(ScalarValue::UInt8(None))?; - round_trip_literal(ScalarValue::UInt8(Some(u8::MIN)))?; - round_trip_literal(ScalarValue::UInt8(Some(u8::MAX)))?; - - round_trip_literal(ScalarValue::Int16(None))?; - round_trip_literal(ScalarValue::Int16(Some(i16::MIN)))?; - round_trip_literal(ScalarValue::Int16(Some(i16::MAX)))?; - round_trip_literal(ScalarValue::UInt16(None))?; - round_trip_literal(ScalarValue::UInt16(Some(u16::MIN)))?; - 
round_trip_literal(ScalarValue::UInt16(Some(u16::MAX)))?; - - round_trip_literal(ScalarValue::Int32(None))?; - round_trip_literal(ScalarValue::Int32(Some(i32::MIN)))?; - round_trip_literal(ScalarValue::Int32(Some(i32::MAX)))?; - round_trip_literal(ScalarValue::UInt32(None))?; - round_trip_literal(ScalarValue::UInt32(Some(u32::MIN)))?; - round_trip_literal(ScalarValue::UInt32(Some(u32::MAX)))?; - - round_trip_literal(ScalarValue::Int64(None))?; - round_trip_literal(ScalarValue::Int64(Some(i64::MIN)))?; - round_trip_literal(ScalarValue::Int64(Some(i64::MAX)))?; - round_trip_literal(ScalarValue::UInt64(None))?; - round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; - round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; - - for (ts, tz) in [ - (Some(12345), None), - (None, None), - (Some(12345), Some("UTC".into())), - (None, Some("UTC".into())), - ] { - round_trip_literal(ScalarValue::TimestampSecond(ts, tz.clone()))?; - round_trip_literal(ScalarValue::TimestampMillisecond(ts, tz.clone()))?; - round_trip_literal(ScalarValue::TimestampMicrosecond(ts, tz.clone()))?; - round_trip_literal(ScalarValue::TimestampNanosecond(ts, tz))?; - } - - round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( - &[ScalarValue::Float32(Some(1.0))], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( - &[], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null( - Field::new_list_field(DataType::Float32, true).into(), - 1, - ))))?; - round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( - &[ScalarValue::Float32(Some(1.0))], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( - &[], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::LargeList(Arc::new( - GenericListArray::new_null( - Field::new_list_field(DataType::Float32, true).into(), - 1, - ), - )))?; - - // Null map - let mut map_builder = - MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); - map_builder.append(false)?; - round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; - - // Empty map - let mut map_builder = - MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); - map_builder.append(true)?; - round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; - - // Valid map - let mut map_builder = - MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); - map_builder.keys().append_value("key1"); - map_builder.keys().append_value("key2"); - map_builder.values().append_value(1); - map_builder.values().append_value(2); - map_builder.append(true)?; - round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; - - let c0 = Field::new("c0", DataType::Boolean, true); - let c1 = Field::new("c1", DataType::Int32, true); - let c2 = Field::new("c2", DataType::Utf8, true); - round_trip_literal( - ScalarStructBuilder::new() - .with_scalar(c0.to_owned(), ScalarValue::Boolean(Some(true))) - .with_scalar(c1.to_owned(), ScalarValue::Int32(Some(1))) - .with_scalar(c2.to_owned(), ScalarValue::Utf8(None)) - .build()?, - )?; - round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?; - - round_trip_literal(ScalarValue::IntervalYearMonth(Some(17)))?; - round_trip_literal(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano::new(17, 25, 1234567890), - )))?; - round_trip_literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( - 57, 123456, - ))))?; - - Ok(()) - } - - fn 
round_trip_literal(scalar: ScalarValue) -> Result<()> { - println!("Checking round trip of {scalar:?}"); - let state = SessionContext::default().state(); - let mut producer = DefaultSubstraitProducer::new(&state); - let substrait_literal = to_substrait_literal(&mut producer, &scalar)?; - let roundtrip_scalar = - from_substrait_literal_without_names(&test_consumer(), &substrait_literal)?; - assert_eq!(scalar, roundtrip_scalar); - Ok(()) - } - - #[test] - fn round_trip_types() -> Result<()> { - round_trip_type(DataType::Boolean)?; - round_trip_type(DataType::Int8)?; - round_trip_type(DataType::UInt8)?; - round_trip_type(DataType::Int16)?; - round_trip_type(DataType::UInt16)?; - round_trip_type(DataType::Int32)?; - round_trip_type(DataType::UInt32)?; - round_trip_type(DataType::Int64)?; - round_trip_type(DataType::UInt64)?; - round_trip_type(DataType::Float32)?; - round_trip_type(DataType::Float64)?; - - for tz in [None, Some("UTC".into())] { - round_trip_type(DataType::Timestamp(TimeUnit::Second, tz.clone()))?; - round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone()))?; - round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone()))?; - round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, tz))?; - } - - round_trip_type(DataType::Date32)?; - round_trip_type(DataType::Date64)?; - round_trip_type(DataType::Binary)?; - round_trip_type(DataType::FixedSizeBinary(10))?; - round_trip_type(DataType::LargeBinary)?; - round_trip_type(DataType::BinaryView)?; - round_trip_type(DataType::Utf8)?; - round_trip_type(DataType::LargeUtf8)?; - round_trip_type(DataType::Utf8View)?; - round_trip_type(DataType::Decimal128(10, 2))?; - round_trip_type(DataType::Decimal256(30, 2))?; - - round_trip_type(DataType::List( - Field::new_list_field(DataType::Int32, true).into(), - ))?; - round_trip_type(DataType::LargeList( - Field::new_list_field(DataType::Int32, true).into(), - ))?; - - round_trip_type(DataType::Map( - Field::new_struct( - "entries", - [ - Field::new("key", DataType::Utf8, false).into(), - Field::new("value", DataType::Int32, true).into(), - ], - false, - ) - .into(), - false, - ))?; - - round_trip_type(DataType::Struct( - vec![ - Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, true), - ] - .into(), - ))?; - - round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; - round_trip_type(DataType::Interval(IntervalUnit::MonthDayNano))?; - round_trip_type(DataType::Interval(IntervalUnit::DayTime))?; - - Ok(()) - } - - fn round_trip_type(dt: DataType) -> Result<()> { - println!("Checking round trip of {dt:?}"); - - // As DataFusion doesn't consider nullability as a property of the type, but field, - // it doesn't matter if we set nullability to true or false here. 
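The comment above relies on nullability being carried by the Arrow Field rather than the DataType; a small sketch of that distinction, using the datafusion::arrow re-export as the tests above do:

    use datafusion::arrow::datatypes::{DataType, Field};

    fn main() {
        let required = Field::new("a", DataType::Int32, false);
        let nullable = Field::new("b", DataType::Int32, true);
        // Same DataType, different nullability: only the Field records it.
        assert_eq!(required.data_type(), nullable.data_type());
        assert!(!required.is_nullable());
        assert!(nullable.is_nullable());
    }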
- let substrait = to_substrait_type(&dt, true)?; - let consumer = test_consumer(); - let roundtrip_dt = from_substrait_type_without_names(&consumer, &substrait)?; - assert_eq!(dt, roundtrip_dt); - Ok(()) - } - - #[test] - fn to_field_reference() -> Result<()> { - let expression = substrait_field_ref(2)?; - - match &expression.rex_type { - Some(RexType::Selection(field_ref)) => { - assert_eq!( - field_ref - .root_type - .clone() - .expect("root type should be set"), - RootType::RootReference(RootReference {}) - ); - } - - _ => panic!("Should not be anything other than field reference"), - } - Ok(()) - } - - #[test] - fn named_struct_names() -> Result<()> { - let schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ - Field::new("int", DataType::Int32, true), - Field::new( - "struct", - DataType::Struct(Fields::from(vec![Field::new( - "inner", - DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), - true, - )])), - true, - ), - Field::new("trailer", DataType::Float64, true), - ]))?); - - let named_struct = to_substrait_named_struct(&schema)?; - - // Struct field names should be flattened DFS style - // List field names should be omitted - assert_eq!( - named_struct.names, - vec!["int", "struct", "inner", "trailer"] - ); - - let roundtrip_schema = - from_substrait_named_struct(&test_consumer(), &named_struct)?; - assert_eq!(schema.as_ref(), &roundtrip_schema); - Ok(()) - } - - #[tokio::test] - async fn extended_expressions() -> Result<()> { - let state = SessionStateBuilder::default().build(); - - // One expression, empty input schema - let expr = Expr::Literal(ScalarValue::Int32(Some(42))); - let field = Field::new("out", DataType::Int32, false); - let empty_schema = DFSchemaRef::new(DFSchema::empty()); - let substrait = - to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)?; - let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; - - assert_eq!(roundtrip_expr.input_schema, empty_schema); - assert_eq!(roundtrip_expr.exprs.len(), 1); - - let (rt_expr, rt_field) = roundtrip_expr.exprs.first().unwrap(); - assert_eq!(rt_field, &field); - assert_eq!(rt_expr, &expr); - - // Multiple expressions, with column references - let expr1 = Expr::Column("c0".into()); - let expr2 = Expr::Column("c1".into()); - let out1 = Field::new("out1", DataType::Int32, true); - let out2 = Field::new("out2", DataType::Utf8, true); - let input_schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ - Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, true), - ]))?); - - let substrait = to_substrait_extended_expr( - &[(&expr1, &out1), (&expr2, &out2)], - &input_schema, - &state, - )?; - let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; - - assert_eq!(roundtrip_expr.input_schema, input_schema); - assert_eq!(roundtrip_expr.exprs.len(), 2); - - let mut exprs = roundtrip_expr.exprs.into_iter(); - - let (rt_expr, rt_field) = exprs.next().unwrap(); - assert_eq!(rt_field, out1); - assert_eq!(rt_expr, expr1); - - let (rt_expr, rt_field) = exprs.next().unwrap(); - assert_eq!(rt_field, out2); - assert_eq!(rt_expr, expr2); - - Ok(()) - } - - #[tokio::test] - async fn invalid_extended_expression() { - let state = SessionStateBuilder::default().build(); - - // Not ok if input schema is missing field referenced by expr - let expr = Expr::Column("missing".into()); - let field = Field::new("out", DataType::Int32, false); - let empty_schema = DFSchemaRef::new(DFSchema::empty()); - - let err = 
to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state); - - assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); - } -} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/aggregate_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/aggregate_function.rs new file mode 100644 index 000000000000..0619b497532d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/aggregate_function.rs @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::expr; +use datafusion::logical_expr::expr::AggregateFunctionParams; +use substrait::proto::aggregate_function::AggregationInvocation; +use substrait::proto::aggregate_rel::Measure; +use substrait::proto::function_argument::ArgType; +use substrait::proto::sort_field::{SortDirection, SortKind}; +use substrait::proto::{ + AggregateFunction, AggregationPhase, FunctionArgument, SortField, +}; + +pub fn from_aggregate_function( + producer: &mut impl SubstraitProducer, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let expr::AggregateFunction { + func, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment: _null_treatment, + }, + } = agg_fn; + let sorts = if let Some(order_by) = order_by { + order_by + .iter() + .map(|expr| to_substrait_sort_field(producer, expr, schema)) + .collect::>>()? 
+ } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + let function_anchor = producer.register_function(func.name().to_string()); + #[allow(deprecated)] + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(producer.handle_expr(f, schema)?), + None => None, + }, + }) +} + +/// Converts sort expression to corresponding substrait `SortField` +fn to_substrait_sort_field( + producer: &mut impl SubstraitProducer, + sort: &expr::Sort, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let sort_kind = match (sort.asc, sort.nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: Some(producer.handle_expr(&sort.expr, schema)?), + sort_kind: Some(SortKind::Direction(sort_kind.into())), + }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs new file mode 100644 index 000000000000..9741dcdd1095 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{to_substrait_type, SubstraitProducer}; +use crate::variation_const::DEFAULT_TYPE_VARIATION_REF; +use datafusion::common::{DFSchemaRef, ScalarValue}; +use datafusion::logical_expr::{Cast, Expr, TryCast}; +use substrait::proto::expression::cast::FailureBehavior; +use substrait::proto::expression::literal::LiteralType; +use substrait::proto::expression::{Literal, RexType}; +use substrait::proto::Expression; + +pub fn from_cast( + producer: &mut impl SubstraitProducer, + cast: &Cast, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Cast { expr, data_type } = cast; + // since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null + if let Expr::Literal(lit, _) = expr.as_ref() { + // only the untyped(a null scalar value) null literal need this special handling + // since all other kind of nulls are already typed and can be handled by substrait + // e.g. 
null:: or null:: + if matches!(lit, ScalarValue::Null) { + let lit = Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null(to_substrait_type( + data_type, true, + )?)), + }; + return Ok(Expression { + rex_type: Some(RexType::Literal(lit)), + }); + } + } + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(producer.handle_expr(expr, schema)?)), + failure_behavior: FailureBehavior::ThrowException.into(), + }, + ))), + }) +} + +pub fn from_try_cast( + producer: &mut impl SubstraitProducer, + cast: &TryCast, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let TryCast { expr, data_type } = cast; + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(producer.handle_expr(expr, schema)?)), + failure_behavior: FailureBehavior::ReturnNull.into(), + }, + ))), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::producer::to_substrait_extended_expr; + use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::common::DFSchema; + use datafusion::execution::SessionStateBuilder; + use datafusion::logical_expr::ExprSchemable; + use substrait::proto::expression_reference::ExprType; + + #[tokio::test] + async fn fold_cast_null() { + let state = SessionStateBuilder::default().build(); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let field = Field::new("out", DataType::Int32, false); + + let expr = Expr::Literal(ScalarValue::Null, None) + .cast_to(&DataType::Int32, &empty_schema) + .unwrap(); + + let typed_null = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state) + .unwrap(); + + if let ExprType::Expression(expr) = + typed_null.referred_expr[0].expr_type.as_ref().unwrap() + { + let lit = Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null( + to_substrait_type(&DataType::Int32, true).unwrap(), + )), + }; + let expected = Expression { + rex_type: Some(RexType::Literal(lit)), + }; + assert_eq!(*expr, expected); + } else { + panic!("Expected expression type"); + } + + // a typed null should not be folded + let expr = Expr::Literal(ScalarValue::Int64(None), None) + .cast_to(&DataType::Int32, &empty_schema) + .unwrap(); + + let typed_null = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state) + .unwrap(); + + if let ExprType::Expression(expr) = + typed_null.referred_expr[0].expr_type.as_ref().unwrap() + { + let cast_expr = substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(&DataType::Int32, true).unwrap()), + input: Some(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null( + to_substrait_type(&DataType::Int64, true).unwrap(), + )), + })), + })), + failure_behavior: FailureBehavior::ThrowException as i32, + }; + let expected = Expression { + rex_type: Some(RexType::Cast(Box::new(cast_expr))), + }; + assert_eq!(*expr, expected); + } else { + panic!("Expected expression type"); + } + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs b/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs new file mode 100644 index 000000000000..d1d80ca545ff 
--- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::{substrait_err, Column, DFSchemaRef}; +use datafusion::logical_expr::Expr; +use substrait::proto::expression::field_reference::{ + ReferenceType, RootReference, RootType, +}; +use substrait::proto::expression::{ + reference_segment, FieldReference, ReferenceSegment, RexType, +}; +use substrait::proto::Expression; + +pub fn from_column( + col: &Column, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let index = schema.index_of_column(col)?; + substrait_field_ref(index) +} + +pub(crate) fn substrait_field_ref( + index: usize, +) -> datafusion::common::Result { + Ok(Expression { + rex_type: Some(RexType::Selection(Box::new(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: Some(RootType::RootReference(RootReference {})), + }))), + }) +} + +/// Try to convert an [Expr] to a [FieldReference]. +/// Returns `Err` if the [Expr] is not a [Expr::Column]. +pub(crate) fn try_to_substrait_field_reference( + expr: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + match expr { + Expr::Column(col) => { + let index = schema.index_of_column(col)?; + Ok(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: Some(RootType::RootReference(RootReference {})), + }) + } + _ => substrait_err!("Expect a `Column` expr, but found {expr:?}"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::common::Result; + + #[test] + fn to_field_reference() -> Result<()> { + let expression = substrait_field_ref(2)?; + + match &expression.rex_type { + Some(RexType::Selection(field_ref)) => { + assert_eq!( + field_ref + .root_type + .clone() + .expect("root type should be set"), + RootType::RootReference(RootReference {}) + ); + } + + _ => panic!("Should not be anything other than field reference"), + } + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/if_then.rs b/datafusion/substrait/src/logical_plan/producer/expr/if_then.rs new file mode 100644 index 000000000000..a34959ead76d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/if_then.rs @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::Case; +use substrait::proto::expression::if_then::IfClause; +use substrait::proto::expression::{IfThen, RexType}; +use substrait::proto::Expression; + +pub fn from_case( + producer: &mut impl SubstraitProducer, + case: &Case, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Case { + expr, + when_then_expr, + else_expr, + } = case; + let mut ifs: Vec = vec![]; + // Parse base + if let Some(e) = expr { + // Base expression exists + ifs.push(IfClause { + r#if: Some(producer.handle_expr(e, schema)?), + then: None, + }); + } + // Parse `when`s + for (r#if, then) in when_then_expr { + ifs.push(IfClause { + r#if: Some(producer.handle_expr(r#if, schema)?), + then: Some(producer.handle_expr(then, schema)?), + }); + } + + // Parse outer `else` + let r#else: Option> = match else_expr { + Some(e) => Some(Box::new(producer.handle_expr(e, schema)?)), + None => None, + }; + + Ok(Expression { + rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), + }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/literal.rs b/datafusion/substrait/src/logical_plan/producer/expr/literal.rs new file mode 100644 index 000000000000..31f4866bdc85 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/literal.rs @@ -0,0 +1,483 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
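Not part of the diff: a minimal sketch of the type-variation scheme this new `literal.rs` relies on, mirroring the `round_trip_literal` helper in the test module further down. It assumes it would sit inside this file's `#[cfg(test)]` module (where `to_substrait_literal` is visible); the function name is hypothetical.

```rust
use datafusion::common::ScalarValue;
use datafusion::prelude::SessionContext;
use substrait::proto::expression::literal::LiteralType;

use crate::logical_plan::producer::DefaultSubstraitProducer;
use crate::variation_const::UNSIGNED_INTEGER_TYPE_VARIATION_REF;

// Hypothetical test-style sketch: UInt8 has no direct Substrait kind, so the
// producer encodes it as I8 plus the unsigned-integer type variation reference
// (see the ScalarValue::UInt8 match arm below).
fn uint8_literal_uses_unsigned_variation() -> datafusion::common::Result<()> {
    let state = SessionContext::default().state();
    let mut producer = DefaultSubstraitProducer::new(&state);

    let lit = to_substrait_literal(&mut producer, &ScalarValue::UInt8(Some(7)))?;
    assert_eq!(lit.type_variation_reference, UNSIGNED_INTEGER_TYPE_VARIATION_REF);
    assert!(matches!(lit.literal_type, Some(LiteralType::I8(7))));
    assert!(!lit.nullable);
    Ok(())
}
```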
+ +use crate::logical_plan::producer::{to_substrait_type, SubstraitProducer}; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; +use datafusion::arrow::temporal_conversions::NANOSECONDS; +use datafusion::common::{exec_err, not_impl_err, ScalarValue}; +use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; +use substrait::proto::expression::literal::map::KeyValue; +use substrait::proto::expression::literal::{ + Decimal, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, + LiteralType, Map, PrecisionTimestamp, Struct, +}; +use substrait::proto::expression::{Literal, RexType}; +use substrait::proto::{r#type, Expression}; + +pub fn from_literal( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> datafusion::common::Result { + to_substrait_literal_expr(producer, value) +} + +pub(crate) fn to_substrait_literal_expr( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> datafusion::common::Result { + let literal = to_substrait_literal(producer, value)?; + Ok(Expression { + rex_type: Some(RexType::Literal(literal)), + }) +} + +pub(crate) fn to_substrait_literal( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> datafusion::common::Result { + if value.is_null() { + return Ok(Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null(to_substrait_type( + &value.data_type(), + true, + )?)), + }); + } + let (literal_type, type_variation_reference) = match value { + ScalarValue::Boolean(Some(b)) => { + (LiteralType::Boolean(*b), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::Int8(Some(n)) => { + (LiteralType::I8(*n as i32), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::UInt8(Some(n)) => ( + LiteralType::I8(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int16(Some(n)) => { + (LiteralType::I16(*n as i32), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::UInt16(Some(n)) => ( + LiteralType::I16(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_VARIATION_REF), + ScalarValue::UInt32(Some(n)) => ( + LiteralType::I32(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_VARIATION_REF), + ScalarValue::UInt64(Some(n)) => ( + LiteralType::I64(*n as i64), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Float32(Some(f)) => { + (LiteralType::Fp32(*f), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::Float64(Some(f)) => { + (LiteralType::Fp64(*f), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::TimestampSecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 0, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMillisecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 3, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMicrosecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 6, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampNanosecond(Some(t), None) => ( + 
LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 9, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + // If timezone is present, no matter what the actual tz value is, it indicates the + // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. + // As the timezone is lost, this conversion may be lossy for downstream use of the value. + ScalarValue::TimestampSecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 0, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMillisecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 3, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMicrosecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 6, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampNanosecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 9, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Date32(Some(d)) => { + (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF) + } + // Date64 literal is not supported in Substrait + ScalarValue::IntervalYearMonth(Some(i)) => ( + LiteralType::IntervalYearToMonth(IntervalYearToMonth { + // DF only tracks total months, but there should always be 12 months in a year + years: *i / 12, + months: *i % 12, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::IntervalMonthDayNano(Some(i)) => ( + LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: i.months / 12, + months: i.months % 12, + }), + interval_day_to_second: Some(IntervalDayToSecond { + days: i.days, + seconds: (i.nanoseconds / NANOSECONDS) as i32, + subseconds: i.nanoseconds % NANOSECONDS, + precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds + }), + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::IntervalDayTime(Some(i)) => ( + LiteralType::IntervalDayToSecond(IntervalDayToSecond { + days: i.days, + seconds: i.milliseconds / 1000, + subseconds: (i.milliseconds % 1000) as i64, + precision_mode: Some(PrecisionMode::Precision(3)), // 3 for milliseconds + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Binary(Some(b)) => ( + LiteralType::Binary(b.clone()), + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeBinary(Some(b)) => ( + LiteralType::Binary(b.clone()), + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::BinaryView(Some(b)) => ( + LiteralType::Binary(b.clone()), + VIEW_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::FixedSizeBinary(_, Some(b)) => ( + LiteralType::FixedBinary(b.clone()), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Utf8(Some(s)) => ( + LiteralType::String(s.clone()), + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeUtf8(Some(s)) => ( + LiteralType::String(s.clone()), + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::Utf8View(Some(s)) => ( + LiteralType::String(s.clone()), + VIEW_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::Decimal128(v, p, s) if v.is_some() => ( + LiteralType::Decimal(Decimal { + value: v.unwrap().to_le_bytes().to_vec(), + precision: *p as i32, + scale: *s as i32, + }), + DECIMAL_128_TYPE_VARIATION_REF, + ), + ScalarValue::List(l) => ( + convert_array_to_literal_list(producer, l)?, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeList(l) => ( + 
convert_array_to_literal_list(producer, l)?, + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::Map(m) => { + let map = if m.is_empty() || m.value(0).is_empty() { + let mt = to_substrait_type(m.data_type(), m.is_nullable())?; + let mt = match mt { + substrait::proto::Type { + kind: Some(r#type::Kind::Map(mt)), + } => Ok(mt.as_ref().to_owned()), + _ => exec_err!("Unexpected type for a map: {mt:?}"), + }?; + LiteralType::EmptyMap(mt) + } else { + let keys = (0..m.keys().len()) + .map(|i| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(&m.keys(), i)?, + ) + }) + .collect::>>()?; + let values = (0..m.values().len()) + .map(|i| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(&m.values(), i)?, + ) + }) + .collect::>>()?; + + let key_values = keys + .into_iter() + .zip(values.into_iter()) + .map(|(k, v)| { + Ok(KeyValue { + key: Some(k), + value: Some(v), + }) + }) + .collect::>>()?; + LiteralType::Map(Map { key_values }) + }; + (map, DEFAULT_CONTAINER_TYPE_VARIATION_REF) + } + ScalarValue::Struct(s) => ( + LiteralType::Struct(Struct { + fields: s + .columns() + .iter() + .map(|col| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(col, 0)?, + ) + }) + .collect::>>()?, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + _ => ( + not_impl_err!("Unsupported literal: {value:?}")?, + DEFAULT_TYPE_VARIATION_REF, + ), + }; + + Ok(Literal { + nullable: false, + type_variation_reference, + literal_type: Some(literal_type), + }) +} + +fn convert_array_to_literal_list( + producer: &mut impl SubstraitProducer, + array: &GenericListArray, +) -> datafusion::common::Result { + assert_eq!(array.len(), 1); + let nested_array = array.value(0); + + let values = (0..nested_array.len()) + .map(|i| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(&nested_array, i)?, + ) + }) + .collect::>>()?; + + if values.is_empty() { + let lt = match to_substrait_type(array.data_type(), array.is_nullable())? 
{ + substrait::proto::Type { + kind: Some(r#type::Kind::List(lt)), + } => lt.as_ref().to_owned(), + _ => unreachable!(), + }; + Ok(LiteralType::EmptyList(lt)) + } else { + Ok(LiteralType::List(List { values })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::from_substrait_literal_without_names; + use crate::logical_plan::consumer::tests::test_consumer; + use crate::logical_plan::producer::DefaultSubstraitProducer; + use datafusion::arrow::array::{Int64Builder, MapBuilder, StringBuilder}; + use datafusion::arrow::datatypes::{ + DataType, Field, IntervalDayTime, IntervalMonthDayNano, + }; + use datafusion::common::scalar::ScalarStructBuilder; + use datafusion::common::Result; + use datafusion::prelude::SessionContext; + use std::sync::Arc; + + #[test] + fn round_trip_literals() -> Result<()> { + round_trip_literal(ScalarValue::Boolean(None))?; + round_trip_literal(ScalarValue::Boolean(Some(true)))?; + round_trip_literal(ScalarValue::Boolean(Some(false)))?; + + round_trip_literal(ScalarValue::Int8(None))?; + round_trip_literal(ScalarValue::Int8(Some(i8::MIN)))?; + round_trip_literal(ScalarValue::Int8(Some(i8::MAX)))?; + round_trip_literal(ScalarValue::UInt8(None))?; + round_trip_literal(ScalarValue::UInt8(Some(u8::MIN)))?; + round_trip_literal(ScalarValue::UInt8(Some(u8::MAX)))?; + + round_trip_literal(ScalarValue::Int16(None))?; + round_trip_literal(ScalarValue::Int16(Some(i16::MIN)))?; + round_trip_literal(ScalarValue::Int16(Some(i16::MAX)))?; + round_trip_literal(ScalarValue::UInt16(None))?; + round_trip_literal(ScalarValue::UInt16(Some(u16::MIN)))?; + round_trip_literal(ScalarValue::UInt16(Some(u16::MAX)))?; + + round_trip_literal(ScalarValue::Int32(None))?; + round_trip_literal(ScalarValue::Int32(Some(i32::MIN)))?; + round_trip_literal(ScalarValue::Int32(Some(i32::MAX)))?; + round_trip_literal(ScalarValue::UInt32(None))?; + round_trip_literal(ScalarValue::UInt32(Some(u32::MIN)))?; + round_trip_literal(ScalarValue::UInt32(Some(u32::MAX)))?; + + round_trip_literal(ScalarValue::Int64(None))?; + round_trip_literal(ScalarValue::Int64(Some(i64::MIN)))?; + round_trip_literal(ScalarValue::Int64(Some(i64::MAX)))?; + round_trip_literal(ScalarValue::UInt64(None))?; + round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; + round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; + + for (ts, tz) in [ + (Some(12345), None), + (None, None), + (Some(12345), Some("UTC".into())), + (None, Some("UTC".into())), + ] { + round_trip_literal(ScalarValue::TimestampSecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampMillisecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampMicrosecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampNanosecond(ts, tz))?; + } + + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null( + Field::new_list_field(DataType::Float32, true).into(), + 1, + ))))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(Arc::new( + GenericListArray::new_null( + Field::new_list_field(DataType::Float32, 
true).into(), + 1, + ), + )))?; + + // Null map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(false)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Empty map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Valid map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.keys().append_value("key1"); + map_builder.keys().append_value("key2"); + map_builder.values().append_value(1); + map_builder.values().append_value(2); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + let c0 = Field::new("c0", DataType::Boolean, true); + let c1 = Field::new("c1", DataType::Int32, true); + let c2 = Field::new("c2", DataType::Utf8, true); + round_trip_literal( + ScalarStructBuilder::new() + .with_scalar(c0.to_owned(), ScalarValue::Boolean(Some(true))) + .with_scalar(c1.to_owned(), ScalarValue::Int32(Some(1))) + .with_scalar(c2.to_owned(), ScalarValue::Utf8(None)) + .build()?, + )?; + round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?; + + round_trip_literal(ScalarValue::IntervalYearMonth(Some(17)))?; + round_trip_literal(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano::new(17, 25, 1234567890), + )))?; + round_trip_literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( + 57, 123456, + ))))?; + + Ok(()) + } + + fn round_trip_literal(scalar: ScalarValue) -> Result<()> { + println!("Checking round trip of {scalar:?}"); + let state = SessionContext::default().state(); + let mut producer = DefaultSubstraitProducer::new(&state); + let substrait_literal = to_substrait_literal(&mut producer, &scalar)?; + let roundtrip_scalar = + from_substrait_literal_without_names(&test_consumer(), &substrait_literal)?; + assert_eq!(scalar, roundtrip_scalar); + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs new file mode 100644 index 000000000000..42e1f962f1d1 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
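Not part of the diff: a hedged usage sketch for `to_substrait_extended_expr`, which this new `expr/mod.rs` defines below. The import path assumes the `pub use` re-exports added in this PR surface it under `datafusion_substrait::logical_plan::producer`; adjust if the crate exposes it elsewhere.

```rust
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{DFSchema, DFSchemaRef, ScalarValue};
use datafusion::execution::SessionStateBuilder;
use datafusion::logical_expr::Expr;
use datafusion_substrait::logical_plan::producer::to_substrait_extended_expr;

// Sketch: serialize a single literal expression into an ExtendedExpression.
// The Field only supplies output names; its data type and nullability are
// ignored, as the doc comment below explains.
fn serialize_one_expr() -> datafusion::common::Result<()> {
    let state = SessionStateBuilder::default().build();
    let empty_schema = DFSchemaRef::new(DFSchema::empty());

    let expr = Expr::Literal(ScalarValue::Int32(Some(42)), None);
    let out = Field::new("answer", DataType::Int32, false);

    let extended = to_substrait_extended_expr(&[(&expr, &out)], &empty_schema, &state)?;
    assert_eq!(extended.referred_expr.len(), 1);
    assert_eq!(extended.referred_expr[0].output_names, vec!["answer".to_string()]);
    Ok(())
}
```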
+ +mod aggregate_function; +mod cast; +mod field_reference; +mod if_then; +mod literal; +mod scalar_function; +mod singular_or_list; +mod subquery; +mod window_function; + +pub use aggregate_function::*; +pub use cast::*; +pub use field_reference::*; +pub use if_then::*; +pub use literal::*; +pub use scalar_function::*; +pub use singular_or_list::*; +pub use subquery::*; +pub use window_function::*; + +use crate::logical_plan::producer::utils::flatten_names; +use crate::logical_plan::producer::{ + to_substrait_named_struct, DefaultSubstraitProducer, SubstraitProducer, +}; +use datafusion::arrow::datatypes::Field; +use datafusion::common::{internal_err, not_impl_err, DFSchemaRef}; +use datafusion::execution::SessionState; +use datafusion::logical_expr::expr::Alias; +use datafusion::logical_expr::Expr; +use substrait::proto::expression_reference::ExprType; +use substrait::proto::{Expression, ExpressionReference, ExtendedExpression}; +use substrait::version; + +/// Serializes a collection of expressions to a Substrait ExtendedExpression message +/// +/// The ExtendedExpression message is a top-level message that can be used to send +/// expressions (not plans) between systems. +/// +/// Each expression is also given names for the output type. These are provided as a +/// field and not a String (since the names may be nested, e.g. a struct). The data +/// type and nullability of this field is redundant (those can be determined by the +/// Expr) and will be ignored. +/// +/// Substrait also requires the input schema of the expressions to be included in the +/// message. The field names of the input schema will be serialized. +pub fn to_substrait_extended_expr( + exprs: &[(&Expr, &Field)], + schema: &DFSchemaRef, + state: &SessionState, +) -> datafusion::common::Result> { + let mut producer = DefaultSubstraitProducer::new(state); + let substrait_exprs = exprs + .iter() + .map(|(expr, field)| { + let substrait_expr = producer.handle_expr(expr, schema)?; + let mut output_names = Vec::new(); + flatten_names(field, false, &mut output_names)?; + Ok(ExpressionReference { + output_names, + expr_type: Some(ExprType::Expression(substrait_expr)), + }) + }) + .collect::>>()?; + let substrait_schema = to_substrait_named_struct(schema)?; + + let extensions = producer.get_extensions(); + Ok(Box::new(ExtendedExpression { + advanced_extensions: None, + expected_type_urls: vec![], + extension_uris: vec![], + extensions: extensions.into(), + version: Some(version::version_with_producer("datafusion")), + referred_expr: substrait_exprs, + base_schema: Some(substrait_schema), + })) +} + +/// Convert DataFusion Expr to Substrait Rex +/// +/// # Arguments +/// * `producer` - SubstraitProducer implementation which the handles the actual conversion +/// * `expr` - DataFusion expression to convert into a Substrait expression +/// * `schema` - DataFusion input schema for looking up columns +pub fn to_substrait_rex( + producer: &mut impl SubstraitProducer, + expr: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + match expr { + Expr::Alias(expr) => producer.handle_alias(expr, schema), + Expr::Column(expr) => producer.handle_column(expr, schema), + Expr::ScalarVariable(_, _) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } + Expr::Literal(expr, _) => producer.handle_literal(expr), + Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema), + Expr::Like(expr) => producer.handle_like(expr, schema), + Expr::SimilarTo(_) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + 
Expr::Not(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotNull(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNull(_) => producer.handle_unary_expr(expr, schema), + Expr::IsTrue(_) => producer.handle_unary_expr(expr, schema), + Expr::IsFalse(_) => producer.handle_unary_expr(expr, schema), + Expr::IsUnknown(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotTrue(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotFalse(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotUnknown(_) => producer.handle_unary_expr(expr, schema), + Expr::Negative(_) => producer.handle_unary_expr(expr, schema), + Expr::Between(expr) => producer.handle_between(expr, schema), + Expr::Case(expr) => producer.handle_case(expr, schema), + Expr::Cast(expr) => producer.handle_cast(expr, schema), + Expr::TryCast(expr) => producer.handle_try_cast(expr, schema), + Expr::ScalarFunction(expr) => producer.handle_scalar_function(expr, schema), + Expr::AggregateFunction(_) => { + internal_err!( + "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" + ) + } + Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema), + Expr::InList(expr) => producer.handle_in_list(expr, schema), + Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), + Expr::ScalarSubquery(expr) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } + #[expect(deprecated)] + Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::OuterReferenceColumn(_, _) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } + Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + } +} + +pub fn from_alias( + producer: &mut impl SubstraitProducer, + alias: &Alias, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + producer.handle_expr(alias.expr.as_ref(), schema) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::from_substrait_extended_expr; + use datafusion::arrow::datatypes::{DataType, Schema}; + use datafusion::common::{DFSchema, DataFusionError, ScalarValue}; + use datafusion::execution::SessionStateBuilder; + + #[tokio::test] + async fn extended_expressions() -> datafusion::common::Result<()> { + let state = SessionStateBuilder::default().build(); + + // One expression, empty input schema + let expr = Expr::Literal(ScalarValue::Int32(Some(42)), None); + let field = Field::new("out", DataType::Int32, false); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let substrait = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, empty_schema); + assert_eq!(roundtrip_expr.exprs.len(), 1); + + let (rt_expr, rt_field) = roundtrip_expr.exprs.first().unwrap(); + assert_eq!(rt_field, &field); + assert_eq!(rt_expr, &expr); + + // Multiple expressions, with column references + let expr1 = Expr::Column("c0".into()); + let expr2 = Expr::Column("c1".into()); + let out1 = Field::new("out1", DataType::Int32, true); + let out2 = Field::new("out2", DataType::Utf8, true); + let input_schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ + Field::new("c0", 
DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + ]))?); + + let substrait = to_substrait_extended_expr( + &[(&expr1, &out1), (&expr2, &out2)], + &input_schema, + &state, + )?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, input_schema); + assert_eq!(roundtrip_expr.exprs.len(), 2); + + let mut exprs = roundtrip_expr.exprs.into_iter(); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out1); + assert_eq!(rt_expr, expr1); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out2); + assert_eq!(rt_expr, expr2); + + Ok(()) + } + + #[tokio::test] + async fn invalid_extended_expression() { + let state = SessionStateBuilder::default().build(); + + // Not ok if input schema is missing field referenced by expr + let expr = Expr::Column("missing".into()); + let field = Field::new("out", DataType::Int32, false); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + + let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state); + + assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs new file mode 100644 index 000000000000..1172c43319c6 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs @@ -0,0 +1,327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
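Not part of the diff: a small sketch of the rewrites that `from_between` below performs, expressed with the DataFusion `Expr` helpers. The Substrait function names in the comments come from `operator_to_name` in this file.

```rust
use datafusion::prelude::{col, lit};

// x BETWEEN 1 AND 10      ->  and(lte(1, x), lte(x, 10))
// x NOT BETWEEN 1 AND 10  ->  or(lt(x, 1), lt(10, x))
fn between_examples() {
    let _inclusive = col("x").between(lit(1), lit(10));
    let _negated = col("x").not_between(lit(1), lit(10));
}
```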
+ +use crate::logical_plan::producer::{to_substrait_literal_expr, SubstraitProducer}; +use datafusion::common::{not_impl_err, DFSchemaRef, ScalarValue}; +use datafusion::logical_expr::{expr, Between, BinaryExpr, Expr, Like, Operator}; +use substrait::proto::expression::{RexType, ScalarFunction}; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument}; + +pub fn from_scalar_function( + producer: &mut impl SubstraitProducer, + fun: &expr::ScalarFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let mut arguments: Vec = vec![]; + for arg in &fun.args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + + let function_anchor = producer.register_function(fun.name().to_string()); + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + options: vec![], + args: vec![], + })), + }) +} + +pub fn from_unary_expr( + producer: &mut impl SubstraitProducer, + expr: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let (fn_name, arg) = match expr { + Expr::Not(arg) => ("not", arg), + Expr::IsNull(arg) => ("is_null", arg), + Expr::IsNotNull(arg) => ("is_not_null", arg), + Expr::IsTrue(arg) => ("is_true", arg), + Expr::IsFalse(arg) => ("is_false", arg), + Expr::IsUnknown(arg) => ("is_unknown", arg), + Expr::IsNotTrue(arg) => ("is_not_true", arg), + Expr::IsNotFalse(arg) => ("is_not_false", arg), + Expr::IsNotUnknown(arg) => ("is_not_unknown", arg), + Expr::Negative(arg) => ("negate", arg), + expr => not_impl_err!("Unsupported expression: {expr:?}")?, + }; + to_substrait_unary_scalar_fn(producer, fn_name, arg, schema) +} + +pub fn from_binary_expr( + producer: &mut impl SubstraitProducer, + expr: &BinaryExpr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let BinaryExpr { left, op, right } = expr; + let l = producer.handle_expr(left, schema)?; + let r = producer.handle_expr(right, schema)?; + Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) +} + +pub fn from_like( + producer: &mut impl SubstraitProducer, + like: &Like, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + } = like; + make_substrait_like_expr( + producer, + *case_insensitive, + *negated, + expr, + pattern, + *escape_char, + schema, + ) +} + +fn make_substrait_like_expr( + producer: &mut impl SubstraitProducer, + ignore_case: bool, + negated: bool, + expr: &Expr, + pattern: &Expr, + escape_char: Option, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let function_anchor = if ignore_case { + producer.register_function("ilike".to_string()) + } else { + producer.register_function("like".to_string()) + }; + let expr = producer.handle_expr(expr, schema)?; + let pattern = producer.handle_expr(pattern, schema)?; + let escape_char = to_substrait_literal_expr( + producer, + &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), + )?; + let arguments = vec![ + FunctionArgument { + arg_type: Some(ArgType::Value(expr)), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(pattern)), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(escape_char)), + }, + ]; + + #[allow(deprecated)] + let substrait_like = Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + args: vec![], 
+ options: vec![], + })), + }; + + if negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_like)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_like) + } +} + +/// Util to generate substrait [RexType::ScalarFunction] with one argument +fn to_substrait_unary_scalar_fn( + producer: &mut impl SubstraitProducer, + fn_name: &str, + arg: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let function_anchor = producer.register_function(fn_name.to_string()); + let substrait_expr = producer.handle_expr(arg, schema)?; + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_expr)), + }], + output_type: None, + options: vec![], + ..Default::default() + })), + }) +} + +/// Return Substrait scalar function with two arguments +pub fn make_binary_op_scalar_func( + producer: &mut impl SubstraitProducer, + lhs: &Expression, + rhs: &Expression, + op: Operator, +) -> Expression { + let function_anchor = producer.register_function(operator_to_name(op).to_string()); + #[allow(deprecated)] + Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![ + FunctionArgument { + arg_type: Some(ArgType::Value(lhs.clone())), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(rhs.clone())), + }, + ], + output_type: None, + args: vec![], + options: vec![], + })), + } +} + +pub fn from_between( + producer: &mut impl SubstraitProducer, + between: &Between, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Between { + expr, + negated, + low, + high, + } = between; + if *negated { + // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) + let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; + let substrait_low = producer.handle_expr(low.as_ref(), schema)?; + let substrait_high = producer.handle_expr(high.as_ref(), schema)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_low, + Operator::Lt, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_high, + &substrait_expr, + Operator::Lt, + ); + + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::Or, + )) + } else { + // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) + let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; + let substrait_low = producer.handle_expr(low.as_ref(), schema)?; + let substrait_high = producer.handle_expr(high.as_ref(), schema)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_low, + &substrait_expr, + Operator::LtEq, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_high, + Operator::LtEq, + ); + + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::And, + )) + } +} + +pub fn operator_to_name(op: Operator) -> &'static str { + match op { + Operator::Eq => "equal", + Operator::NotEq => "not_equal", + Operator::Lt => "lt", + Operator::LtEq => "lte", + Operator::Gt => "gt", + Operator::GtEq => "gte", + Operator::Plus => "add", + Operator::Minus => 
"subtract", + Operator::Multiply => "multiply", + Operator::Divide => "divide", + Operator::Modulo => "modulus", + Operator::And => "and", + Operator::Or => "or", + Operator::IsDistinctFrom => "is_distinct_from", + Operator::IsNotDistinctFrom => "is_not_distinct_from", + Operator::RegexMatch => "regex_match", + Operator::RegexIMatch => "regex_imatch", + Operator::RegexNotMatch => "regex_not_match", + Operator::RegexNotIMatch => "regex_not_imatch", + Operator::LikeMatch => "like_match", + Operator::ILikeMatch => "like_imatch", + Operator::NotLikeMatch => "like_not_match", + Operator::NotILikeMatch => "like_not_imatch", + Operator::BitwiseAnd => "bitwise_and", + Operator::BitwiseOr => "bitwise_or", + Operator::StringConcat => "str_concat", + Operator::AtArrow => "at_arrow", + Operator::ArrowAt => "arrow_at", + Operator::Arrow => "arrow", + Operator::LongArrow => "long_arrow", + Operator::HashArrow => "hash_arrow", + Operator::HashLongArrow => "hash_long_arrow", + Operator::AtAt => "at_at", + Operator::IntegerDivide => "integer_divide", + Operator::HashMinus => "hash_minus", + Operator::AtQuestion => "at_question", + Operator::Question => "question", + Operator::QuestionAnd => "question_and", + Operator::QuestionPipe => "question_pipe", + Operator::BitwiseXor => "bitwise_xor", + Operator::BitwiseShiftRight => "bitwise_shift_right", + Operator::BitwiseShiftLeft => "bitwise_shift_left", + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs b/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs new file mode 100644 index 000000000000..1c0b6dcc154b --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::expr::InList; +use substrait::proto::expression::{RexType, ScalarFunction, SingularOrList}; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument}; + +pub fn from_in_list( + producer: &mut impl SubstraitProducer, + in_list: &InList, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let InList { + expr, + list, + negated, + } = in_list; + let substrait_list = list + .iter() + .map(|x| producer.handle_expr(x, schema)) + .collect::>>()?; + let substrait_expr = producer.handle_expr(expr, schema)?; + + let substrait_or_list = Expression { + rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { + value: Some(Box::new(substrait_expr)), + options: substrait_list, + }))), + }; + + if *negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_or_list)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_or_list) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs new file mode 100644 index 000000000000..c1ee78c68c25 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
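Not part of the diff: a hedged end-to-end sketch for the IN-subquery case handled by `from_in_subquery` below. It assumes the SQL parses in this DataFusion version and that `to_substrait_plan` is reachable under `datafusion_substrait::logical_plan::producer` per the re-exports in this PR.

```rust
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::producer::to_substrait_plan;

// Sketch only: the predicate's left-hand side ends up in InPredicate.needles and
// the serialized subquery plan in InPredicate.haystack.
async fn in_subquery_example() -> datafusion::common::Result<()> {
    let ctx = SessionContext::new();
    let plan = ctx
        .sql(
            "SELECT t.column1 FROM (VALUES (1), (2)) AS t \
             WHERE t.column1 IN (SELECT u.column1 FROM (VALUES (2), (3)) AS u)",
        )
        .await?
        .into_unoptimized_plan();

    let _substrait = to_substrait_plan(&plan, &ctx.state())?;
    Ok(())
}
```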
+ +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::expr::InSubquery; +use substrait::proto::expression::subquery::InPredicate; +use substrait::proto::expression::{RexType, ScalarFunction}; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument}; + +pub fn from_in_subquery( + producer: &mut impl SubstraitProducer, + subquery: &InSubquery, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let InSubquery { + expr, + subquery, + negated, + } = subquery; + let substrait_expr = producer.handle_expr(expr, schema)?; + + let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?; + + let substrait_subquery = Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::InPredicate( + Box::new(InPredicate { + needles: (vec![substrait_expr]), + haystack: Some(subquery_plan), + }), + ), + ), + }, + ))), + }; + if *negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_subquery)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_subquery) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs new file mode 100644 index 000000000000..17e71f2d7c14 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
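Not part of the diff: a sketch of the frame-bound mapping implemented below. `WindowFrame::new_bounds` is assumed to be the available constructor in this DataFusion version; any other way of building the same frame illustrates the same point.

```rust
use datafusion::common::ScalarValue;
use datafusion::logical_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};

fn frame_example() {
    // ROWS BETWEEN 1 PRECEDING AND CURRENT ROW
    //   start bound -> BoundKind::Preceding { offset: 1 }
    //   end bound   -> BoundKind::CurrentRow
    // (constructor name assumed; see lead-in)
    let _frame = WindowFrame::new_bounds(
        WindowFrameUnits::Rows,
        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
        WindowFrameBound::CurrentRow,
    );

    // Any bound whose offset is not a concrete integer (e.g. a null ScalarValue)
    // falls through to_substrait_bound_offset and becomes BoundKind::Unbounded.
    let _unbounded = WindowFrameBound::Preceding(ScalarValue::UInt64(None));
}
```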
+ +use crate::logical_plan::producer::utils::substrait_sort_field; +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::{not_impl_err, DFSchemaRef, ScalarValue}; +use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams}; +use datafusion::logical_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; +use substrait::proto::expression::window_function::bound as SubstraitBound; +use substrait::proto::expression::window_function::bound::Kind as BoundKind; +use substrait::proto::expression::window_function::{Bound, BoundsType}; +use substrait::proto::expression::RexType; +use substrait::proto::expression::WindowFunction as SubstraitWindowFunction; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument, SortField}; + +pub fn from_window_function( + producer: &mut impl SubstraitProducer, + window_fn: &WindowFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + }, + } = window_fn; + // function reference + let function_anchor = producer.register_function(fun.to_string()); + // arguments + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + // partition by expressions + let partition_by = partition_by + .iter() + .map(|e| producer.handle_expr(e, schema)) + .collect::>>()?; + // order by expressions + let order_by = order_by + .iter() + .map(|e| substrait_sort_field(producer, e, schema)) + .collect::>>()?; + // window frame + let bounds = to_substrait_bounds(window_frame)?; + let bound_type = to_substrait_bound_type(window_frame)?; + Ok(make_substrait_window_function( + function_anchor, + arguments, + partition_by, + order_by, + bounds, + bound_type, + )) +} + +fn make_substrait_window_function( + function_reference: u32, + arguments: Vec, + partitions: Vec, + sorts: Vec, + bounds: (Bound, Bound), + bounds_type: BoundsType, +) -> Expression { + #[allow(deprecated)] + Expression { + rex_type: Some(RexType::WindowFunction(SubstraitWindowFunction { + function_reference, + arguments, + partitions, + sorts, + options: vec![], + output_type: None, + phase: 0, // default to AGGREGATION_PHASE_UNSPECIFIED + invocation: 0, // TODO: fix + lower_bound: Some(bounds.0), + upper_bound: Some(bounds.1), + args: vec![], + bounds_type: bounds_type as i32, + })), + } +} + +fn to_substrait_bound_type( + window_frame: &WindowFrame, +) -> datafusion::common::Result { + match window_frame.units { + WindowFrameUnits::Rows => Ok(BoundsType::Rows), // ROWS + WindowFrameUnits::Range => Ok(BoundsType::Range), // RANGE + // TODO: Support GROUPS + unit => not_impl_err!("Unsupported window frame unit: {unit:?}"), + } +} + +fn to_substrait_bounds( + window_frame: &WindowFrame, +) -> datafusion::common::Result<(Bound, Bound)> { + Ok(( + to_substrait_bound(&window_frame.start_bound), + to_substrait_bound(&window_frame.end_bound), + )) +} + +fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { + match bound { + WindowFrameBound::CurrentRow => Bound { + kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})), + }, + WindowFrameBound::Preceding(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })), + }, + None => Bound { + kind: 
Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), + }, + }, + WindowFrameBound::Following(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { offset })), + }, + None => Bound { + kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), + }, + }, + } +} + +fn to_substrait_bound_offset(value: &ScalarValue) -> Option { + match value { + ScalarValue::UInt8(Some(v)) => Some(*v as i64), + ScalarValue::UInt16(Some(v)) => Some(*v as i64), + ScalarValue::UInt32(Some(v)) => Some(*v as i64), + ScalarValue::UInt64(Some(v)) => Some(*v as i64), + ScalarValue::Int8(Some(v)) => Some(*v as i64), + ScalarValue::Int16(Some(v)) => Some(*v as i64), + ScalarValue::Int32(Some(v)) => Some(*v as i64), + ScalarValue::Int64(Some(v)) => Some(*v), + _ => None, + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/mod.rs b/datafusion/substrait/src/logical_plan/producer/mod.rs new file mode 100644 index 000000000000..fc4af94a25fe --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/mod.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod expr; +mod plan; +mod rel; +mod substrait_producer; +mod types; +mod utils; + +pub use expr::*; +pub use plan::*; +pub use rel::*; +pub use substrait_producer::*; +pub(crate) use types::*; +pub(crate) use utils::*; diff --git a/datafusion/substrait/src/logical_plan/producer/plan.rs b/datafusion/substrait/src/logical_plan/producer/plan.rs new file mode 100644 index 000000000000..7d5b7754122d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/plan.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
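Not part of the diff: a hedged usage sketch for `to_substrait_plan`, the entry point defined just below. The import path assumes the producer re-exports introduced in this PR; the SQL is only a stand-in for any `LogicalPlan`.

```rust
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::producer::to_substrait_plan;

async fn plan_example() -> datafusion::common::Result<()> {
    let ctx = SessionContext::new();
    let plan = ctx.sql("SELECT 1 AS x").await?.into_optimized_plan()?;

    // Only one relation tree is supported, so the plan carries a single RelRoot
    // whose output names come from the plan's schema.
    let substrait = to_substrait_plan(&plan, &ctx.state())?;
    assert_eq!(substrait.relations.len(), 1);
    Ok(())
}
```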
+
+use crate::logical_plan::producer::{
+    to_substrait_named_struct, DefaultSubstraitProducer, SubstraitProducer,
+};
+use datafusion::execution::SessionState;
+use datafusion::logical_expr::{LogicalPlan, SubqueryAlias};
+use substrait::proto::{plan_rel, Plan, PlanRel, Rel, RelRoot};
+use substrait::version;
+
+/// Convert DataFusion LogicalPlan to Substrait Plan
+pub fn to_substrait_plan(
+    plan: &LogicalPlan,
+    state: &SessionState,
+) -> datafusion::common::Result<Box<Plan>> {
+    // Parse relation nodes
+    // Generate PlanRel(s)
+    // Note: Only 1 relation tree is currently supported
+
+    let mut producer: DefaultSubstraitProducer = DefaultSubstraitProducer::new(state);
+    let plan_rels = vec![PlanRel {
+        rel_type: Some(plan_rel::RelType::Root(RelRoot {
+            input: Some(*producer.handle_plan(plan)?),
+            names: to_substrait_named_struct(plan.schema())?.names,
+        })),
+    }];
+
+    // Return parsed plan
+    let extensions = producer.get_extensions();
+    Ok(Box::new(Plan {
+        version: Some(version::version_with_producer("datafusion")),
+        extension_uris: vec![],
+        extensions: extensions.into(),
+        relations: plan_rels,
+        advanced_extensions: None,
+        expected_type_urls: vec![],
+        parameter_bindings: vec![],
+    }))
+}
+
+pub fn from_subquery_alias(
+    producer: &mut impl SubstraitProducer,
+    alias: &SubqueryAlias,
+) -> datafusion::common::Result<Box<Rel>> {
+    // Do nothing when encountering a SubqueryAlias,
+    // since there is no corresponding relation type in Substrait
+    producer.handle_plan(alias.input.as_ref())
+}
diff --git a/datafusion/substrait/src/logical_plan/producer/rel/aggregate_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/aggregate_rel.rs
new file mode 100644
index 000000000000..4abd283a7ee0
--- /dev/null
+++ b/datafusion/substrait/src/logical_plan/producer/rel/aggregate_rel.rs
@@ -0,0 +1,182 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
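+
+// Converts DataFusion `Aggregate` and `Distinct` plans into Substrait `AggregateRel`s.
+// Schematic mapping (illustrative only, not literal output of this module):
+//
+//     SELECT a, sum(b) FROM t GROUP BY a
+//         -> AggregateRel { groupings: [[ref(a)]], measures: [sum(ref(b))] }
+//     SELECT DISTINCT a, b FROM t
+//         -> AggregateRel { groupings: [[ref(a), ref(b)]], measures: [] }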
+ +use crate::logical_plan::producer::{ + from_aggregate_function, substrait_field_ref, SubstraitProducer, +}; +use datafusion::common::{internal_err, not_impl_err, DFSchemaRef, DataFusionError}; +use datafusion::logical_expr::expr::Alias; +use datafusion::logical_expr::{Aggregate, Distinct, Expr, GroupingSet}; +use substrait::proto::aggregate_rel::{Grouping, Measure}; +use substrait::proto::rel::RelType; +use substrait::proto::{AggregateRel, Expression, Rel}; + +pub fn from_aggregate( + producer: &mut impl SubstraitProducer, + agg: &Aggregate, +) -> datafusion::common::Result> { + let input = producer.handle_plan(agg.input.as_ref())?; + let (grouping_expressions, groupings) = + to_substrait_groupings(producer, &agg.group_expr, agg.input.schema())?; + let measures = agg + .aggr_expr + .iter() + .map(|e| to_substrait_agg_measure(producer, e, agg.input.schema())) + .collect::>>()?; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { + common: None, + input: Some(input), + grouping_expressions, + groupings, + measures, + advanced_extension: None, + }))), + })) +} + +pub fn from_distinct( + producer: &mut impl SubstraitProducer, + distinct: &Distinct, +) -> datafusion::common::Result> { + match distinct { + Distinct::All(plan) => { + // Use Substrait's AggregateRel with empty measures to represent `select distinct` + let input = producer.handle_plan(plan.as_ref())?; + // Get grouping keys from the input relation's number of output fields + let grouping = (0..plan.schema().fields().len()) + .map(substrait_field_ref) + .collect::>>()?; + + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { + common: None, + input: Some(input), + grouping_expressions: vec![], + groupings: vec![Grouping { + grouping_expressions: grouping, + expression_references: vec![], + }], + measures: vec![], + advanced_extension: None, + }))), + })) + } + Distinct::On(_) => not_impl_err!("Cannot convert Distinct::On"), + } +} + +pub fn to_substrait_groupings( + producer: &mut impl SubstraitProducer, + exprs: &[Expr], + schema: &DFSchemaRef, +) -> datafusion::common::Result<(Vec, Vec)> { + let mut ref_group_exprs = vec![]; + let groupings = match exprs.len() { + 1 => match &exprs[0] { + Expr::GroupingSet(gs) => match gs { + GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented( + "GroupingSet CUBE is not yet supported".to_string(), + )), + GroupingSet::GroupingSets(sets) => Ok(sets + .iter() + .map(|set| { + parse_flat_grouping_exprs( + producer, + set, + schema, + &mut ref_group_exprs, + ) + }) + .collect::>>()?), + GroupingSet::Rollup(set) => { + let mut sets: Vec> = vec![vec![]]; + for i in 0..set.len() { + sets.push(set[..=i].to_vec()); + } + Ok(sets + .iter() + .rev() + .map(|set| { + parse_flat_grouping_exprs( + producer, + set, + schema, + &mut ref_group_exprs, + ) + }) + .collect::>>()?) 
+ } + }, + _ => Ok(vec![parse_flat_grouping_exprs( + producer, + exprs, + schema, + &mut ref_group_exprs, + )?]), + }, + _ => Ok(vec![parse_flat_grouping_exprs( + producer, + exprs, + schema, + &mut ref_group_exprs, + )?]), + }?; + Ok((ref_group_exprs, groupings)) +} + +pub fn parse_flat_grouping_exprs( + producer: &mut impl SubstraitProducer, + exprs: &[Expr], + schema: &DFSchemaRef, + ref_group_exprs: &mut Vec, +) -> datafusion::common::Result { + let mut expression_references = vec![]; + let mut grouping_expressions = vec![]; + + for e in exprs { + let rex = producer.handle_expr(e, schema)?; + grouping_expressions.push(rex.clone()); + ref_group_exprs.push(rex); + expression_references.push((ref_group_exprs.len() - 1) as u32); + } + #[allow(deprecated)] + Ok(Grouping { + grouping_expressions, + expression_references, + }) +} + +pub fn to_substrait_agg_measure( + producer: &mut impl SubstraitProducer, + expr: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + match expr { + Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema), + Expr::Alias(Alias { expr, .. }) => { + to_substrait_agg_measure(producer, expr, schema) + } + _ => internal_err!( + "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", + expr, + expr.variant_name() + ), + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/exchange_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/exchange_rel.rs new file mode 100644 index 000000000000..9e0ef8905f43 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/exchange_rel.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
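+
+// Converts a DataFusion `Repartition` into a Substrait `ExchangeRel`:
+// `RoundRobinBatch(n)` becomes a round-robin exchange, `Hash(exprs, n)` becomes a
+// scatter-by-fields exchange over the hashed columns, and `DistributeBy` is rejected
+// as unsupported.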
+ +use crate::logical_plan::producer::{ + try_to_substrait_field_reference, SubstraitProducer, +}; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::{Partitioning, Repartition}; +use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::rel::RelType; +use substrait::proto::{ExchangeRel, Rel}; + +pub fn from_repartition( + producer: &mut impl SubstraitProducer, + repartition: &Repartition, +) -> datafusion::common::Result> { + let input = producer.handle_plan(repartition.input.as_ref())?; + let partition_count = match repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(num) => num, + Partitioning::Hash(_, num) => num, + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let exchange_kind = match &repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(_) => { + ExchangeKind::RoundRobin(RoundRobin::default()) + } + Partitioning::Hash(exprs, _) => { + let fields = exprs + .iter() + .map(|e| try_to_substrait_field_reference(e, repartition.input.schema())) + .collect::>>()?; + ExchangeKind::ScatterByFields(ScatterFields { fields }) + } + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + let exchange_rel = ExchangeRel { + common: None, + input: Some(input), + exchange_kind: Some(exchange_kind), + advanced_extension: None, + partition_count: partition_count as i32, + targets: vec![], + }; + Ok(Box::new(Rel { + rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/fetch_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/fetch_rel.rs new file mode 100644 index 000000000000..4706401d558e --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/fetch_rel.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchema; +use datafusion::logical_expr::Limit; +use std::sync::Arc; +use substrait::proto::rel::RelType; +use substrait::proto::{fetch_rel, FetchRel, Rel}; + +pub fn from_limit( + producer: &mut impl SubstraitProducer, + limit: &Limit, +) -> datafusion::common::Result> { + let input = producer.handle_plan(limit.input.as_ref())?; + let empty_schema = Arc::new(DFSchema::empty()); + let offset_mode = limit + .skip + .as_ref() + .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) + .transpose()? 
+ .map(Box::new) + .map(fetch_rel::OffsetMode::OffsetExpr); + let count_mode = limit + .fetch + .as_ref() + .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) + .transpose()? + .map(Box::new) + .map(fetch_rel::CountMode::CountExpr); + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(input), + offset_mode, + count_mode, + advanced_extension: None, + }))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/filter_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/filter_rel.rs new file mode 100644 index 000000000000..770696dfe1a9 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/filter_rel.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::logical_expr::Filter; +use substrait::proto::rel::RelType; +use substrait::proto::{FilterRel, Rel}; + +pub fn from_filter( + producer: &mut impl SubstraitProducer, + filter: &Filter, +) -> datafusion::common::Result> { + let input = producer.handle_plan(filter.input.as_ref())?; + let filter_expr = producer.handle_expr(&filter.predicate, filter.input.schema())?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Filter(Box::new(FilterRel { + common: None, + input: Some(input), + condition: Some(Box::new(filter_expr)), + advanced_extension: None, + }))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/join.rs b/datafusion/substrait/src/logical_plan/producer/rel/join.rs new file mode 100644 index 000000000000..79564ad5daf1 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/join.rs @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
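+
+// Converts a DataFusion `Join` into a Substrait `JoinRel`. Substrait carries the whole join
+// condition as a single expression, so the equi-join keys and any residual filter are folded
+// into one conjunction, schematically:
+//
+//     ON l.a = r.a AND l.b = r.b, filter l.c > 5
+//         -> expression: (l.a = r.a AND l.b = r.b) AND l.c > 5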
+
+use crate::logical_plan::producer::{make_binary_op_scalar_func, SubstraitProducer};
+use datafusion::common::{not_impl_err, DFSchemaRef, JoinConstraint, JoinType};
+use datafusion::logical_expr::{Expr, Join, Operator};
+use std::sync::Arc;
+use substrait::proto::rel::RelType;
+use substrait::proto::{join_rel, Expression, JoinRel, Rel};
+
+pub fn from_join(
+    producer: &mut impl SubstraitProducer,
+    join: &Join,
+) -> datafusion::common::Result<Box<Rel>> {
+    let left = producer.handle_plan(join.left.as_ref())?;
+    let right = producer.handle_plan(join.right.as_ref())?;
+    let join_type = to_substrait_jointype(join.join_type);
+    // we only support basic joins so return an error for anything not yet supported
+    match join.join_constraint {
+        JoinConstraint::On => {}
+        JoinConstraint::Using => return not_impl_err!("join constraint: `using`"),
+    }
+    let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?);
+
+    // convert filter if present
+    let join_filter = match &join.filter {
+        Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?),
+        None => None,
+    };
+
+    // map the left and right columns to binary expressions in the form `l = r`
+    // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
+    let eq_op = if join.null_equals_null {
+        Operator::IsNotDistinctFrom
+    } else {
+        Operator::Eq
+    };
+    let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?;
+
+    // create conjunction between `join_on` and `join_filter` to embed all join conditions,
+    // whether equal or non-equal in a single expression
+    let join_expr = match &join_on {
+        Some(on_expr) => match &join_filter {
+            Some(filter) => Some(Box::new(make_binary_op_scalar_func(
+                producer,
+                on_expr,
+                filter,
+                Operator::And,
+            ))),
+            None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist
+        },
+        None => match &join_filter {
+            Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist
+            None => None,
+        },
+    };
+
+    Ok(Box::new(Rel {
+        rel_type: Some(RelType::Join(Box::new(JoinRel {
+            common: None,
+            left: Some(left),
+            right: Some(right),
+            r#type: join_type as i32,
+            expression: join_expr,
+            post_join_filter: None,
+            advanced_extension: None,
+        }))),
+    }))
+}
+
+fn to_substrait_join_expr(
+    producer: &mut impl SubstraitProducer,
+    join_conditions: &Vec<(Expr, Expr)>,
+    eq_op: Operator,
+    join_schema: &DFSchemaRef,
+) -> datafusion::common::Result<Option<Expression>> {
+    // Only support AND conjunction for each binary expression in join conditions
+    let mut exprs: Vec<Expression> = vec![];
+    for (left, right) in join_conditions {
+        let l = producer.handle_expr(left, join_schema)?;
+        let r = producer.handle_expr(right, join_schema)?;
+        // AND with existing expression
+        exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op));
+    }
+
+    let join_expr: Option<Expression> =
+        exprs.into_iter().reduce(|acc: Expression, e: Expression| {
+            make_binary_op_scalar_func(producer, &acc, &e, Operator::And)
+        });
+    Ok(join_expr)
+}
+
+fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {
+    match join_type {
+        JoinType::Inner => join_rel::JoinType::Inner,
+        JoinType::Left => join_rel::JoinType::Left,
+        JoinType::Right => join_rel::JoinType::Right,
+        JoinType::Full => join_rel::JoinType::Outer,
+        JoinType::LeftAnti => join_rel::JoinType::LeftAnti,
+        JoinType::LeftSemi => join_rel::JoinType::LeftSemi,
+        JoinType::LeftMark => join_rel::JoinType::LeftMark,
+        JoinType::RightAnti |
JoinType::RightSemi => { + unimplemented!() + } + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs new file mode 100644 index 000000000000..c3599a2635ff --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_rel; +mod exchange_rel; +mod fetch_rel; +mod filter_rel; +mod join; +mod project_rel; +mod read_rel; +mod set_rel; +mod sort_rel; + +pub use aggregate_rel::*; +pub use exchange_rel::*; +pub use fetch_rel::*; +pub use filter_rel::*; +pub use join::*; +pub use project_rel::*; +pub use read_rel::*; +pub use set_rel::*; +pub use sort_rel::*; + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::LogicalPlan; +use substrait::proto::Rel; + +pub fn to_substrait_rel( + producer: &mut impl SubstraitProducer, + plan: &LogicalPlan, +) -> datafusion::common::Result> { + match plan { + LogicalPlan::Projection(plan) => producer.handle_projection(plan), + LogicalPlan::Filter(plan) => producer.handle_filter(plan), + LogicalPlan::Window(plan) => producer.handle_window(plan), + LogicalPlan::Aggregate(plan) => producer.handle_aggregate(plan), + LogicalPlan::Sort(plan) => producer.handle_sort(plan), + LogicalPlan::Join(plan) => producer.handle_join(plan), + LogicalPlan::Repartition(plan) => producer.handle_repartition(plan), + LogicalPlan::Union(plan) => producer.handle_union(plan), + LogicalPlan::TableScan(plan) => producer.handle_table_scan(plan), + LogicalPlan::EmptyRelation(plan) => producer.handle_empty_relation(plan), + LogicalPlan::Subquery(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::SubqueryAlias(plan) => producer.handle_subquery_alias(plan), + LogicalPlan::Limit(plan) => producer.handle_limit(plan), + LogicalPlan::Statement(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Values(plan) => producer.handle_values(plan), + LogicalPlan::Explain(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Analyze(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Extension(plan) => producer.handle_extension(plan), + LogicalPlan::Distinct(plan) => producer.handle_distinct(plan), + LogicalPlan::Dml(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Ddl(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Copy(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::DescribeTable(plan) => { + not_impl_err!("Unsupported plan type: {plan:?}")? 
+ } + LogicalPlan::Unnest(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::RecursiveQuery(plan) => { + not_impl_err!("Unsupported plan type: {plan:?}")? + } + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs new file mode 100644 index 000000000000..0190dca12bf5 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{substrait_field_ref, SubstraitProducer}; +use datafusion::logical_expr::{Projection, Window}; +use substrait::proto::rel::RelType; +use substrait::proto::rel_common::EmitKind; +use substrait::proto::rel_common::EmitKind::Emit; +use substrait::proto::{rel_common, ProjectRel, Rel, RelCommon}; + +pub fn from_projection( + producer: &mut impl SubstraitProducer, + p: &Projection, +) -> datafusion::common::Result> { + let expressions = p + .expr + .iter() + .map(|e| producer.handle_expr(e, p.input.schema())) + .collect::>>()?; + + let emit_kind = create_project_remapping( + expressions.len(), + p.input.as_ref().schema().fields().len(), + ); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(Box::new(ProjectRel { + common: Some(common), + input: Some(producer.handle_plan(p.input.as_ref())?), + expressions, + advanced_extension: None, + }))), + })) +} + +pub fn from_window( + producer: &mut impl SubstraitProducer, + window: &Window, +) -> datafusion::common::Result> { + let input = producer.handle_plan(window.input.as_ref())?; + + // create a field reference for each input field + let mut expressions = (0..window.input.schema().fields().len()) + .map(substrait_field_ref) + .collect::>>()?; + + // process and add each window function expression + for expr in &window.window_expr { + expressions.push(producer.handle_expr(expr, window.input.schema())?); + } + + let emit_kind = + create_project_remapping(expressions.len(), window.input.schema().fields().len()); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + let project_rel = Box::new(ProjectRel { + common: Some(common), + input: Some(input), + expressions, + advanced_extension: None, + }); + + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(project_rel)), + })) +} + +/// By default, a Substrait Project outputs all input fields followed by all expressions. +/// A DataFusion Projection only outputs expressions. 
In order to keep the Substrait +/// plan consistent with DataFusion, we must apply an output mapping that skips the input +/// fields so that the Substrait Project will only output the expression fields. +fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind { + let expression_field_start = input_field_count; + let expression_field_end = expression_field_start + expr_count; + let output_mapping = (expression_field_start..expression_field_end) + .map(|i| i as i32) + .collect(); + Emit(rel_common::Emit { output_mapping }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs new file mode 100644 index 000000000000..212874e7913b --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{ + to_substrait_literal, to_substrait_named_struct, SubstraitProducer, +}; +use datafusion::common::{not_impl_err, substrait_datafusion_err, DFSchema, ToDFSchema}; +use datafusion::logical_expr::utils::conjunction; +use datafusion::logical_expr::{EmptyRelation, Expr, TableScan, Values}; +use std::sync::Arc; +use substrait::proto::expression::literal::Struct; +use substrait::proto::expression::mask_expression::{StructItem, StructSelect}; +use substrait::proto::expression::MaskExpression; +use substrait::proto::read_rel::{NamedTable, ReadType, VirtualTable}; +use substrait::proto::rel::RelType; +use substrait::proto::{ReadRel, Rel}; + +pub fn from_table_scan( + producer: &mut impl SubstraitProducer, + scan: &TableScan, +) -> datafusion::common::Result> { + let projection = scan.projection.as_ref().map(|p| { + p.iter() + .map(|i| StructItem { + field: *i as i32, + child: None, + }) + .collect() + }); + + let projection = projection.map(|struct_items| MaskExpression { + select: Some(StructSelect { struct_items }), + maintain_singular_struct: false, + }); + + let table_schema = scan.source.schema().to_dfschema_ref()?; + let base_schema = to_substrait_named_struct(&table_schema)?; + + let filter_option = if scan.filters.is_empty() { + None + } else { + let table_schema_qualified = Arc::new( + DFSchema::try_from_qualified_schema( + scan.table_name.clone(), + &(scan.source.schema()), + ) + .unwrap(), + ); + + let combined_expr = conjunction(scan.filters.clone()).unwrap(); + let filter_expr = + producer.handle_expr(&combined_expr, &table_schema_qualified)?; + Some(Box::new(filter_expr)) + }; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(base_schema), + filter: filter_option, + best_effort_filter: None, + projection, + advanced_extension: None, + read_type: 
Some(ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + })), + }))), + })) +} + +pub fn from_empty_relation(e: &EmptyRelation) -> datafusion::common::Result> { + if e.produce_one_row { + return not_impl_err!("Producing a row from empty relation is unsupported"); + } + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&e.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values: vec![], + expressions: vec![], + })), + }))), + })) +} + +pub fn from_values( + producer: &mut impl SubstraitProducer, + v: &Values, +) -> datafusion::common::Result> { + let values = v + .values + .iter() + .map(|row| { + let fields = row + .iter() + .map(|v| match v { + Expr::Literal(sv, _) => to_substrait_literal(producer, sv), + Expr::Alias(alias) => match alias.expr.as_ref() { + // The schema gives us the names, so we can skip aliases + Expr::Literal(sv, _) => to_substrait_literal(producer, sv), + _ => Err(substrait_datafusion_err!( + "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() + )), + }, + _ => Err(substrait_datafusion_err!( + "Only literal types and aliases are supported in Virtual Tables, got: {}", v.variant_name() + )), + }) + .collect::>()?; + Ok(Struct { fields }) + }) + .collect::>()?; + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&v.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values, + expressions: vec![], + })), + }))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/set_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/set_rel.rs new file mode 100644 index 000000000000..58ddfca3617a --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/set_rel.rs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::logical_expr::Union; +use substrait::proto::rel::RelType; +use substrait::proto::{set_rel, Rel, SetRel}; + +pub fn from_union( + producer: &mut impl SubstraitProducer, + union: &Union, +) -> datafusion::common::Result> { + let input_rels = union + .inputs + .iter() + .map(|input| producer.handle_plan(input.as_ref())) + .collect::>>()? 
+ .into_iter() + .map(|ptr| *ptr) + .collect(); + Ok(Box::new(Rel { + rel_type: Some(RelType::Set(SetRel { + common: None, + inputs: input_rels, + op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL + advanced_extension: None, + })), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/sort_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/sort_rel.rs new file mode 100644 index 000000000000..aaa8be163560 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/sort_rel.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{substrait_sort_field, SubstraitProducer}; +use crate::variation_const::DEFAULT_TYPE_VARIATION_REF; +use datafusion::logical_expr::Sort; +use substrait::proto::expression::literal::LiteralType; +use substrait::proto::expression::{Literal, RexType}; +use substrait::proto::rel::RelType; +use substrait::proto::{fetch_rel, Expression, FetchRel, Rel, SortRel}; + +pub fn from_sort( + producer: &mut impl SubstraitProducer, + sort: &Sort, +) -> datafusion::common::Result> { + let Sort { expr, input, fetch } = sort; + let sort_fields = expr + .iter() + .map(|e| substrait_sort_field(producer, e, input.schema())) + .collect::>>()?; + + let input = producer.handle_plan(input.as_ref())?; + + let sort_rel = Box::new(Rel { + rel_type: Some(RelType::Sort(Box::new(SortRel { + common: None, + input: Some(input), + sorts: sort_fields, + advanced_extension: None, + }))), + }); + + match fetch { + Some(amount) => { + let count_mode = + Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::I64(*amount as i64)), + })), + }))); + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(sort_rel), + offset_mode: None, + count_mode, + advanced_extension: None, + }))), + })) + } + None => Ok(sort_rel), + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs new file mode 100644 index 000000000000..56edfac5769c --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -0,0 +1,411 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::extensions::Extensions; +use crate::logical_plan::producer::{ + from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, + from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, + from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, + from_projection, from_repartition, from_scalar_function, from_sort, + from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, + from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, +}; +use datafusion::common::{substrait_err, Column, DFSchemaRef, ScalarValue}; +use datafusion::execution::registry::SerializerRegistry; +use datafusion::execution::SessionState; +use datafusion::logical_expr::expr::{Alias, InList, InSubquery, WindowFunction}; +use datafusion::logical_expr::{ + expr, Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, + Extension, Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, + SubqueryAlias, TableScan, TryCast, Union, Values, Window, +}; +use pbjson_types::Any as ProtoAny; +use substrait::proto::aggregate_rel::Measure; +use substrait::proto::rel::RelType; +use substrait::proto::{ + Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, Rel, +}; + +/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully +/// customizable Substrait serde. 
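+///
+/// Most `handle_*` methods have default implementations that delegate to the free `from_*`
+/// functions in this module (the notable exception is `handle_extension`, which must be
+/// overridden to support user-defined extension nodes), so an implementation only needs to
+/// override the cases it wants to customize, as `handle_projection` does in the example below.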
+/// +/// # Example Usage +/// +/// ``` +/// # use std::sync::Arc; +/// # use substrait::proto::{Expression, Rel}; +/// # use substrait::proto::rel::RelType; +/// # use datafusion::common::DFSchemaRef; +/// # use datafusion::error::Result; +/// # use datafusion::execution::SessionState; +/// # use datafusion::logical_expr::{Between, Extension, Projection}; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; +/// +/// struct CustomSubstraitProducer { +/// extensions: Extensions, +/// state: Arc, +/// } +/// +/// impl SubstraitProducer for CustomSubstraitProducer { +/// +/// fn register_function(&mut self, signature: String) -> u32 { +/// self.extensions.register_function(signature) +/// } +/// +/// fn get_extensions(self) -> Extensions { +/// self.extensions +/// } +/// +/// // You can set additional metadata on the Rels you produce +/// fn handle_projection(&mut self, plan: &Projection) -> Result> { +/// let mut rel = from_projection(self, plan)?; +/// match rel.rel_type { +/// Some(RelType::Project(mut project)) => { +/// let mut project = project.clone(); +/// // set common metadata or advanced extension +/// project.common = None; +/// project.advanced_extension = None; +/// Ok(Box::new(Rel { +/// rel_type: Some(RelType::Project(project)), +/// })) +/// } +/// rel_type => Ok(Box::new(Rel { rel_type })), +/// } +/// } +/// +/// // You can tweak how you convert expressions for your target system +/// fn handle_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { +/// // add your own encoding for Between +/// todo!() +/// } +/// +/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait +/// fn handle_extension(&mut self, _plan: &Extension) -> Result> { +/// // implement your own serializer into Substrait +/// todo!() +/// } +/// } +/// ``` +pub trait SubstraitProducer: Send + Sync + Sized { + /// Within a Substrait plan, functions are referenced using function anchors that are stored at + /// the top level of the [Plan](substrait::proto::Plan) within + /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) + /// messages. + /// + /// When given a function signature, this method should return the existing anchor for it if + /// there is one. Otherwise, it should generate a new anchor. + fn register_function(&mut self, signature: String) -> u32; + + /// Consume the producer to generate the [Extensions] for the Substrait plan based on the + /// functions that have been registered + fn get_extensions(self) -> Extensions; + + // Logical Plan Methods + // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. 
+ + fn handle_plan( + &mut self, + plan: &LogicalPlan, + ) -> datafusion::common::Result> { + to_substrait_rel(self, plan) + } + + fn handle_projection( + &mut self, + plan: &Projection, + ) -> datafusion::common::Result> { + from_projection(self, plan) + } + + fn handle_filter(&mut self, plan: &Filter) -> datafusion::common::Result> { + from_filter(self, plan) + } + + fn handle_window(&mut self, plan: &Window) -> datafusion::common::Result> { + from_window(self, plan) + } + + fn handle_aggregate( + &mut self, + plan: &Aggregate, + ) -> datafusion::common::Result> { + from_aggregate(self, plan) + } + + fn handle_sort(&mut self, plan: &Sort) -> datafusion::common::Result> { + from_sort(self, plan) + } + + fn handle_join(&mut self, plan: &Join) -> datafusion::common::Result> { + from_join(self, plan) + } + + fn handle_repartition( + &mut self, + plan: &Repartition, + ) -> datafusion::common::Result> { + from_repartition(self, plan) + } + + fn handle_union(&mut self, plan: &Union) -> datafusion::common::Result> { + from_union(self, plan) + } + + fn handle_table_scan( + &mut self, + plan: &TableScan, + ) -> datafusion::common::Result> { + from_table_scan(self, plan) + } + + fn handle_empty_relation( + &mut self, + plan: &EmptyRelation, + ) -> datafusion::common::Result> { + from_empty_relation(plan) + } + + fn handle_subquery_alias( + &mut self, + plan: &SubqueryAlias, + ) -> datafusion::common::Result> { + from_subquery_alias(self, plan) + } + + fn handle_limit(&mut self, plan: &Limit) -> datafusion::common::Result> { + from_limit(self, plan) + } + + fn handle_values(&mut self, plan: &Values) -> datafusion::common::Result> { + from_values(self, plan) + } + + fn handle_distinct( + &mut self, + plan: &Distinct, + ) -> datafusion::common::Result> { + from_distinct(self, plan) + } + + fn handle_extension( + &mut self, + _plan: &Extension, + ) -> datafusion::common::Result> { + substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") + } + + // Expression Methods + // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. 
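+    //
+    // For example, a producer targeting an engine with a custom `BETWEEN` encoding could
+    // override only `handle_between`; `handle_expr` keeps routing every other expression
+    // variant through the default `to_substrait_rex` conversion.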
+ + fn handle_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + to_substrait_rex(self, expr, schema) + } + + fn handle_alias( + &mut self, + alias: &Alias, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_alias(self, alias, schema) + } + + fn handle_column( + &mut self, + column: &Column, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_column(column, schema) + } + + fn handle_literal( + &mut self, + value: &ScalarValue, + ) -> datafusion::common::Result { + from_literal(self, value) + } + + fn handle_binary_expr( + &mut self, + expr: &BinaryExpr, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_binary_expr(self, expr, schema) + } + + fn handle_like( + &mut self, + like: &Like, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_like(self, like, schema) + } + + /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative + fn handle_unary_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_unary_expr(self, expr, schema) + } + + fn handle_between( + &mut self, + between: &Between, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_between(self, between, schema) + } + + fn handle_case( + &mut self, + case: &Case, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_case(self, case, schema) + } + + fn handle_cast( + &mut self, + cast: &Cast, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_cast(self, cast, schema) + } + + fn handle_try_cast( + &mut self, + cast: &TryCast, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_try_cast(self, cast, schema) + } + + fn handle_scalar_function( + &mut self, + scalar_fn: &expr::ScalarFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_scalar_function(self, scalar_fn, schema) + } + + fn handle_aggregate_function( + &mut self, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_aggregate_function(self, agg_fn, schema) + } + + fn handle_window_function( + &mut self, + window_fn: &WindowFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_window_function(self, window_fn, schema) + } + + fn handle_in_list( + &mut self, + in_list: &InList, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_in_list(self, in_list, schema) + } + + fn handle_in_subquery( + &mut self, + in_subquery: &InSubquery, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_in_subquery(self, in_subquery, schema) + } +} + +pub struct DefaultSubstraitProducer<'a> { + extensions: Extensions, + serializer_registry: &'a dyn SerializerRegistry, +} + +impl<'a> DefaultSubstraitProducer<'a> { + pub fn new(state: &'a SessionState) -> Self { + DefaultSubstraitProducer { + extensions: Extensions::default(), + serializer_registry: state.serializer_registry().as_ref(), + } + } +} + +impl SubstraitProducer for DefaultSubstraitProducer<'_> { + fn register_function(&mut self, fn_name: String) -> u32 { + self.extensions.register_function(fn_name) + } + + fn get_extensions(self) -> Extensions { + self.extensions + } + + fn handle_extension( + &mut self, + plan: &Extension, + ) -> datafusion::common::Result> { + let extension_bytes = self + .serializer_registry + .serialize_logical_plan(plan.node.as_ref())?; + let detail = ProtoAny { + type_url: plan.node.name().to_string(), + value: 
extension_bytes.into(), + }; + let mut inputs_rel = plan + .node + .inputs() + .into_iter() + .map(|plan| self.handle_plan(plan)) + .collect::>>()?; + let rel_type = match inputs_rel.len() { + 0 => RelType::ExtensionLeaf(ExtensionLeafRel { + common: None, + detail: Some(detail), + }), + 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel { + common: None, + detail: Some(detail), + input: Some(inputs_rel.pop().unwrap()), + })), + _ => RelType::ExtensionMulti(ExtensionMultiRel { + common: None, + detail: Some(detail), + inputs: inputs_rel.into_iter().map(|r| *r).collect(), + }), + }; + Ok(Box::new(Rel { + rel_type: Some(rel_type), + })) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs new file mode 100644 index 000000000000..61b7a79095d5 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -0,0 +1,436 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::utils::flatten_names; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; +use datafusion::common::{internal_err, not_impl_err, plan_err, DFSchemaRef}; +use substrait::proto::{r#type, NamedStruct}; + +pub(crate) fn to_substrait_type( + dt: &DataType, + nullable: bool, +) -> datafusion::common::Result { + let nullability = if nullable { + r#type::Nullability::Nullable as i32 + } else { + r#type::Nullability::Required as i32 + }; + match dt { + DataType::Null => internal_err!("Null cast is not valid"), + DataType::Boolean => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Bool(r#type::Boolean { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Int8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I8(r#type::I8 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I8(r#type::I8 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Int16 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I16(r#type::I16 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt16 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I16(r#type::I16 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), 
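+        // (The unsigned integer arms above reuse the signed Substrait kinds and are
+        // distinguished only by UNSIGNED_INTEGER_TYPE_VARIATION_REF; UInt32 and UInt64
+        // below are handled the same way.)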
+ DataType::Int32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I32(r#type::I32 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I32(r#type::I32 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Int64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I64(r#type::I64 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I64(r#type::I64 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), + // Float16 is not supported in Substrait + DataType::Float32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Fp32(r#type::Fp32 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Float64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Fp64(r#type::Fp64 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Timestamp(unit, tz) => { + let precision = match unit { + TimeUnit::Second => 0, + TimeUnit::Millisecond => 3, + TimeUnit::Microsecond => 6, + TimeUnit::Nanosecond => 9, + }; + let kind = match tz { + None => r#type::Kind::PrecisionTimestamp(r#type::PrecisionTimestamp { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision, + }), + Some(_) => { + // If timezone is present, no matter what the actual tz value is, it indicates the + // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. + // As the timezone is lost, this conversion may be lossy for downstream use of the value. 
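+                    // For example, both `Timestamp(Nanosecond, Some("UTC"))` and
+                    // `Timestamp(Nanosecond, Some("+05:30"))` map to PrecisionTimestampTz
+                    // with precision 9; the original zone string is not round-tripped.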
+ r#type::Kind::PrecisionTimestampTz(r#type::PrecisionTimestampTz { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision, + }) + } + }; + Ok(substrait::proto::Type { kind: Some(kind) }) + } + DataType::Date32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Date(r#type::Date { + type_variation_reference: DATE_32_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Date64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Date(r#type::Date { + type_variation_reference: DATE_64_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Interval(interval_unit) => { + match interval_unit { + IntervalUnit::YearMonth => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalYear(r#type::IntervalYear { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + IntervalUnit::DayTime => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision: Some(3), // DayTime precision is always milliseconds + })), + }), + IntervalUnit::MonthDayNano => { + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalCompound( + r#type::IntervalCompound { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision: 9, // nanos + }, + )), + }) + } + } + } + DataType::Binary => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary { + length: *length, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::LargeBinary => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::BinaryView => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Utf8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::LargeUtf8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Utf8View => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::List(inner) => { + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::List(Box::new(r#type::List { + r#type: Some(Box::new(inner_type)), + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + DataType::LargeList(inner) => { + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::List(Box::new(r#type::List { + r#type: Some(Box::new(inner_type)), + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + DataType::Map(inner, _) => match inner.data_type() { + DataType::Struct(key_and_value) if 
key_and_value.len() == 2 => { + let key_type = to_substrait_type( + key_and_value[0].data_type(), + key_and_value[0].is_nullable(), + )?; + let value_type = to_substrait_type( + key_and_value[1].data_type(), + key_and_value[1].is_nullable(), + )?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Map(Box::new(r#type::Map { + key: Some(Box::new(key_type)), + value: Some(Box::new(value_type)), + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + }, + DataType::Struct(fields) => { + let field_types = fields + .iter() + .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) + .collect::>>()?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Struct(r#type::Struct { + types: field_types, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }) + } + DataType::Decimal128(p, s) => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Decimal(r#type::Decimal { + type_variation_reference: DECIMAL_128_TYPE_VARIATION_REF, + nullability, + scale: *s as i32, + precision: *p as i32, + })), + }), + DataType::Decimal256(p, s) => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Decimal(r#type::Decimal { + type_variation_reference: DECIMAL_256_TYPE_VARIATION_REF, + nullability, + scale: *s as i32, + precision: *p as i32, + })), + }), + _ => not_impl_err!("Unsupported cast type: {dt:?}"), + } +} + +pub(crate) fn to_substrait_named_struct( + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let mut names = Vec::with_capacity(schema.fields().len()); + for field in schema.fields() { + flatten_names(field, false, &mut names)?; + } + + let field_types = r#type::Struct { + types: schema + .fields() + .iter() + .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) + .collect::>()?, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability: r#type::Nullability::Required as i32, + }; + + Ok(NamedStruct { + names, + r#struct: Some(field_types), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::tests::test_consumer; + use crate::logical_plan::consumer::{ + from_substrait_named_struct, from_substrait_type_without_names, + }; + use datafusion::arrow::datatypes::{Field, Fields, Schema}; + use datafusion::common::{DFSchema, Result}; + use std::sync::Arc; + + #[test] + fn round_trip_types() -> Result<()> { + round_trip_type(DataType::Boolean)?; + round_trip_type(DataType::Int8)?; + round_trip_type(DataType::UInt8)?; + round_trip_type(DataType::Int16)?; + round_trip_type(DataType::UInt16)?; + round_trip_type(DataType::Int32)?; + round_trip_type(DataType::UInt32)?; + round_trip_type(DataType::Int64)?; + round_trip_type(DataType::UInt64)?; + round_trip_type(DataType::Float32)?; + round_trip_type(DataType::Float64)?; + + for tz in [None, Some("UTC".into())] { + round_trip_type(DataType::Timestamp(TimeUnit::Second, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, tz))?; + } + + round_trip_type(DataType::Date32)?; + round_trip_type(DataType::Date64)?; + round_trip_type(DataType::Binary)?; + round_trip_type(DataType::FixedSizeBinary(10))?; + round_trip_type(DataType::LargeBinary)?; + round_trip_type(DataType::BinaryView)?; + round_trip_type(DataType::Utf8)?; + round_trip_type(DataType::LargeUtf8)?; + 
round_trip_type(DataType::Utf8View)?; + round_trip_type(DataType::Decimal128(10, 2))?; + round_trip_type(DataType::Decimal256(30, 2))?; + + round_trip_type(DataType::List( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + round_trip_type(DataType::LargeList( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + + round_trip_type(DataType::Map( + Field::new_struct( + "entries", + [ + Field::new("key", DataType::Utf8, false).into(), + Field::new("value", DataType::Int32, true).into(), + ], + false, + ) + .into(), + false, + ))?; + + round_trip_type(DataType::Struct( + vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + ] + .into(), + ))?; + + round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; + round_trip_type(DataType::Interval(IntervalUnit::MonthDayNano))?; + round_trip_type(DataType::Interval(IntervalUnit::DayTime))?; + + Ok(()) + } + + fn round_trip_type(dt: DataType) -> Result<()> { + println!("Checking round trip of {dt:?}"); + + // As DataFusion doesn't consider nullability as a property of the type, but field, + // it doesn't matter if we set nullability to true or false here. + let substrait = to_substrait_type(&dt, true)?; + let consumer = test_consumer(); + let roundtrip_dt = from_substrait_type_without_names(&consumer, &substrait)?; + assert_eq!(dt, roundtrip_dt); + Ok(()) + } + + #[test] + fn named_struct_names() -> Result<()> { + let schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ + Field::new("int", DataType::Int32, true), + Field::new( + "struct", + DataType::Struct(Fields::from(vec![Field::new( + "inner", + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), + true, + )])), + true, + ), + Field::new("trailer", DataType::Float64, true), + ]))?); + + let named_struct = to_substrait_named_struct(&schema)?; + + // Struct field names should be flattened DFS style + // List field names should be omitted + assert_eq!( + named_struct.names, + vec!["int", "struct", "inner", "trailer"] + ); + + let roundtrip_schema = + from_substrait_named_struct(&test_consumer(), &named_struct)?; + assert_eq!(schema.as_ref(), &roundtrip_schema); + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/utils.rs b/datafusion/substrait/src/logical_plan/producer/utils.rs new file mode 100644 index 000000000000..5429e4a1ad88 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/utils.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::common::{plan_err, DFSchemaRef}; +use datafusion::logical_expr::SortExpr; +use substrait::proto::sort_field::{SortDirection, SortKind}; +use substrait::proto::SortField; + +// Substrait wants a list of all field names, including nested fields from structs, +// also from within e.g. lists and maps. However, it does not want the list and map field names +// themselves - only proper structs fields are considered to have useful names. +pub(crate) fn flatten_names( + field: &Field, + skip_self: bool, + names: &mut Vec<String>, +) -> datafusion::common::Result<()> { + if !skip_self { + names.push(field.name().to_string()); + } + match field.data_type() { + DataType::Struct(fields) => { + for field in fields { + flatten_names(field, false, names)?; + } + Ok(()) + } + DataType::List(l) => flatten_names(l, true, names), + DataType::LargeList(l) => flatten_names(l, true, names), + DataType::Map(m, _) => match m.data_type() { + DataType::Struct(key_and_value) if key_and_value.len() == 2 => { + flatten_names(&key_and_value[0], true, names)?; + flatten_names(&key_and_value[1], true, names) + } + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + }, + _ => Ok(()), + }?; + Ok(()) +} + +pub(crate) fn substrait_sort_field( + producer: &mut impl SubstraitProducer, + sort: &SortExpr, + schema: &DFSchemaRef, +) -> datafusion::common::Result<SortField> { + let SortExpr { + expr, + asc, + nulls_first, + } = sort; + let e = producer.handle_expr(expr, schema)?; + let d = match (asc, nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: Some(e), + sort_kind: Some(SortKind::Direction(d as i32)), + }) +} diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index bdeeeb585c0c..4a121e41d27e 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -44,7 +44,7 @@ mod tests { let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; let plan = from_substrait_plan(&ctx.state(), &proto).await?; ctx.state().create_physical_plan(&plan).await?; - Ok(format!("{}", plan)) + Ok(format!("{plan}")) } #[tokio::test] @@ -501,7 +501,7 @@ mod tests { let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; let plan = from_substrait_plan(&ctx.state(), &proto).await?; ctx.state().create_physical_plan(&plan).await?; - Ok(format!("{}", plan)) + Ok(format!("{plan}")) } #[tokio::test] @@ -560,4 +560,28 @@ mod tests { ); Ok(()) } + + #[tokio::test] + async fn test_multiple_unions() -> Result<()> { + let plan_str = test_plan_to_string("multiple_unions.json").await?; + assert_snapshot!( + plan_str, + @r#" + Projection: Utf8("people") AS product_category, Utf8("people")__temp__0 AS product_type, product_key + Union + Projection: Utf8("people"), Utf8("people") AS Utf8("people")__temp__0, sales.product_key + Left Join: sales.product_key = food.@food_id + TableScan: sales + TableScan: food + Union + Projection: people.$f3, people.$f5, people.product_key0 + Left Join: people.product_key0 = food.@food_id + TableScan: people + TableScan: food + TableScan: more_products + "# + ); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs
b/datafusion/substrait/tests/cases/emit_kind_tests.rs index 88db2bc34d7f..e916b4cb0e1a 100644 --- a/datafusion/substrait/tests/cases/emit_kind_tests.rs +++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs @@ -126,8 +126,8 @@ mod tests { let plan1str = format!("{plan}"); let plan2str = format!("{plan2}"); - println!("{}", plan1str); - println!("{}", plan2str); + println!("{plan1str}"); + println!("{plan2str}"); assert_eq!(plan1str, plan2str); Ok(()) diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 9a85f3e6c4dc..7a5cfeb39836 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -763,7 +763,7 @@ async fn simple_intersect() -> Result<()> { let expected_plan_str = format!( "Projection: count(Int64(1)) AS {syntax}\ \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: data.a = data2.a\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ \n TableScan: data projection=[a]\ @@ -780,7 +780,7 @@ async fn simple_intersect() -> Result<()> { async fn check_constant(sql_syntax: &str, plan_expr: &str) -> Result<()> { let expected_plan_str = format!( "Aggregate: groupBy=[[]], aggr=[[{plan_expr}]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: data.a = data2.a\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ \n TableScan: data projection=[a]\ @@ -854,6 +854,22 @@ async fn aggregate_wo_projection_sorted_consume() -> Result<()> { Ok(()) } +#[tokio::test] +async fn aggregate_identical_grouping_expressions() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json"); + + let plan = generate_plan_from_substrait(proto_plan).await?; + assert_snapshot!( + plan, + @r#" + Aggregate: groupBy=[[Int32(1) AS grouping_col_1, Int32(1) AS grouping_col_2]], aggr=[[]] + TableScan: data projection=[] + "# + ); + Ok(()) +} + #[tokio::test] async fn simple_intersect_consume() -> Result<()> { let proto_plan = read_json("tests/testdata/test_plans/intersect.substrait.json"); @@ -942,7 +958,7 @@ async fn simple_intersect_table_reuse() -> Result<()> { let expected_plan_str = format!( "Projection: count(Int64(1)) AS {syntax}\ \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: left.a = right.a\ \n SubqueryAlias: left\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ @@ -961,7 +977,7 @@ async fn simple_intersect_table_reuse() -> Result<()> { async fn check_constant(sql_syntax: &str, plan_expr: &str) -> Result<()> { let expected_plan_str = format!( "Aggregate: groupBy=[[]], aggr=[[{plan_expr}]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: left.a = right.a\ \n SubqueryAlias: left\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json new file mode 100644 index 000000000000..15c0b0505fa6 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json @@ -0,0 +1,53 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [], + "struct": { + "types": [], + 
"nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": ["data"] + } + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "literal": { + "i32": 1 + } + }, + { + "literal": { + "i32": 1 + } + } + ] + } + ], + "measures": [] + } + }, + "names": ["grouping_col_1", "grouping_col_2"] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "manual" + } +} diff --git a/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json b/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json new file mode 100644 index 000000000000..8b82d6eec755 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json @@ -0,0 +1,328 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "equal:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "set": { + "common": { + "direct": { + } + }, + "inputs": [{ + "project": { + "common": { + "emit": { + "outputMapping": [2, 3, 4] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["product_key"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "sales" + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["@food_id"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "food" + ] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "literal": { + "string": "people" + } + }, { + "literal": { + "string": "people" + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }] + } + }, { + "set": { + "common": { + "direct": { + } + }, + "inputs": [{ + "project": { + "common": { + "emit": { + "outputMapping": [4, 5, 6] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["$f3", "$f5", "product_key0"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "people" + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["@food_id"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "food" + ] + } + + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": 
{ + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + } + }, { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["$f1000", "$f2000", "more_products_key0000"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "more_products" + ] + } + + } + }], + "op": "SET_OP_UNION_ALL" + } + }], + "op": "SET_OP_UNION_ALL" + } + }, + "names": ["product_category", "product_type", "product_key"] + } + }] +} \ No newline at end of file diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 10eab025734c..b43c34f19760 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -52,14 +52,13 @@ datafusion-expr = { workspace = true } datafusion-optimizer = { workspace = true, default-features = true } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } -# getrandom must be compiled with js feature -getrandom = { version = "0.2.8", features = ["js"] } - +getrandom = { version = "0.3", features = ["wasm_js"] } wasm-bindgen = "0.2.99" [dev-dependencies] insta = { workspace = true } object_store = { workspace = true } +# needs to be compiled tokio = { workspace = true } url = { workspace = true } wasm-bindgen-test = "0.3.49" diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 65d8bdbb5e93..c018e779fcbf 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -2007,10 +2007,11 @@ } }, "node_modules/http-proxy-middleware": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.6.tgz", - "integrity": "sha512-ya/UeJ6HVBYxrgYotAZo1KvPWlgB48kUJLDePFeneHsVujFaW5WNj2NgWCAE//B1Dl02BIfYlpNgBy8Kf8Rjmw==", + "version": "2.0.9", + "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.9.tgz", + "integrity": "sha512-c1IyJYLYppU574+YI7R4QyX2ystMtVXZwIdzazUIPIJsHuWNd+mho2j+bKoHftndicGj9yh+xjd+l0yj7VeT1Q==", "dev": true, + "license": "MIT", "dependencies": { "@types/http-proxy": "^1.17.8", "http-proxy": "^1.18.1", @@ -5562,9 +5563,9 @@ } }, "http-proxy-middleware": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.6.tgz", - "integrity": "sha512-ya/UeJ6HVBYxrgYotAZo1KvPWlgB48kUJLDePFeneHsVujFaW5WNj2NgWCAE//B1Dl02BIfYlpNgBy8Kf8Rjmw==", + "version": "2.0.9", + "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.9.tgz", + 
"integrity": "sha512-c1IyJYLYppU574+YI7R4QyX2ystMtVXZwIdzazUIPIJsHuWNd+mho2j+bKoHftndicGj9yh+xjd+l0yj7VeT1Q==", "dev": true, "requires": { "@types/http-proxy": "^1.17.8", diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs index 0a7e546b4b18..e30a1046ab27 100644 --- a/datafusion/wasmtest/src/lib.rs +++ b/datafusion/wasmtest/src/lib.rs @@ -92,7 +92,8 @@ mod test { }; use datafusion_common::test_util::batches_to_string; use datafusion_execution::{ - config::SessionConfig, disk_manager::DiskManagerConfig, + config::SessionConfig, + disk_manager::{DiskManagerBuilder, DiskManagerMode}, runtime_env::RuntimeEnvBuilder, }; use datafusion_physical_plan::collect; @@ -112,7 +113,9 @@ mod test { fn get_ctx() -> Arc { let rt = RuntimeEnvBuilder::new() - .with_disk_manager(DiskManagerConfig::Disabled) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), + ) .build_arc() .unwrap(); let session_config = SessionConfig::new().with_target_partitions(1); diff --git a/dev/changelog/48.0.0.md b/dev/changelog/48.0.0.md new file mode 100644 index 000000000000..42f128bcb7b5 --- /dev/null +++ b/dev/changelog/48.0.0.md @@ -0,0 +1,407 @@ + + +# Apache DataFusion 48.0.0 Changelog + +This release consists of 269 commits from 89 contributors. See credits at the end of this changelog for more information. + +**Breaking changes:** + +- Attach Diagnostic to syntax errors [#15680](https://github.com/apache/datafusion/pull/15680) (logan-keede) +- Change `flatten` so it does only a level, not recursively [#15160](https://github.com/apache/datafusion/pull/15160) (delamarch3) +- Improve `simplify_expressions` rule [#15735](https://github.com/apache/datafusion/pull/15735) (xudong963) +- Support WITHIN GROUP syntax to standardize certain existing aggregate functions [#13511](https://github.com/apache/datafusion/pull/13511) (Garamda) +- Add Extension Type / Metadata support for Scalar UDFs [#15646](https://github.com/apache/datafusion/pull/15646) (timsaucer) +- chore: fix clippy::large_enum_variant for DataFusionError [#15861](https://github.com/apache/datafusion/pull/15861) (rroelke) +- Feat: introduce `ExecutionPlan::partition_statistics` API [#15852](https://github.com/apache/datafusion/pull/15852) (xudong963) +- refactor: remove deprecated `ParquetExec` [#15973](https://github.com/apache/datafusion/pull/15973) (miroim) +- refactor: remove deprecated `ArrowExec` [#16006](https://github.com/apache/datafusion/pull/16006) (miroim) +- refactor: remove deprecated `MemoryExec` [#16007](https://github.com/apache/datafusion/pull/16007) (miroim) +- refactor: remove deprecated `JsonExec` [#16005](https://github.com/apache/datafusion/pull/16005) (miroim) +- feat: metadata handling for aggregates and window functions [#15911](https://github.com/apache/datafusion/pull/15911) (timsaucer) +- Remove `Filter::having` field [#16154](https://github.com/apache/datafusion/pull/16154) (findepi) +- Shift from Field to FieldRef for all user defined functions [#16122](https://github.com/apache/datafusion/pull/16122) (timsaucer) +- Change default SQL mapping for `VARCAHR` from `Utf8` to `Utf8View` [#16142](https://github.com/apache/datafusion/pull/16142) (zhuqi-lucas) +- Minor: remove unused IPCWriter [#16215](https://github.com/apache/datafusion/pull/16215) (alamb) +- Reduce size of `Expr` struct [#16207](https://github.com/apache/datafusion/pull/16207) (hendrikmakait) + +**Performance related:** + +- Apply pre-selection and computation skipping to short-circuit optimization 
[#15694](https://github.com/apache/datafusion/pull/15694) (acking-you) +- Add a fast path for `optimize_projection` [#15746](https://github.com/apache/datafusion/pull/15746) (xudong963) +- Speed up `optimize_projection` by improving `is_projection_unnecessary` [#15761](https://github.com/apache/datafusion/pull/15761) (xudong963) +- Speed up `optimize_projection` [#15787](https://github.com/apache/datafusion/pull/15787) (xudong963) +- Support `GroupsAccumulator` for Avg duration [#15748](https://github.com/apache/datafusion/pull/15748) (shruti2522) +- Optimize performance of `string::ascii` function [#16087](https://github.com/apache/datafusion/pull/16087) (tlm365) + +**Implemented enhancements:** + +- Set DataFusion runtime configurations through SQL interface [#15594](https://github.com/apache/datafusion/pull/15594) (kumarlokesh) +- feat: Add option to adjust writer buffer size for query output [#15747](https://github.com/apache/datafusion/pull/15747) (m09526) +- feat: Add `datafusion-spark` crate [#15168](https://github.com/apache/datafusion/pull/15168) (shehabgamin) +- feat: create helpers to set the max_temp_directory_size [#15919](https://github.com/apache/datafusion/pull/15919) (jdrouet) +- feat: ORDER BY ALL [#15772](https://github.com/apache/datafusion/pull/15772) (PokIsemaine) +- feat: support min/max for struct [#15667](https://github.com/apache/datafusion/pull/15667) (chenkovsky) +- feat(proto): udf decoding fallback [#15997](https://github.com/apache/datafusion/pull/15997) (leoyvens) +- feat: make error handling in indent explain consistent with that in tree [#16097](https://github.com/apache/datafusion/pull/16097) (chenkovsky) +- feat: coerce to/from fixed size binary to binary view [#16110](https://github.com/apache/datafusion/pull/16110) (chenkovsky) +- feat: array_length for fixed size list [#16167](https://github.com/apache/datafusion/pull/16167) (chenkovsky) +- feat: ADD sha2 spark function [#16168](https://github.com/apache/datafusion/pull/16168) (getChan) +- feat: create builder for disk manager [#16191](https://github.com/apache/datafusion/pull/16191) (jdrouet) +- feat: Add Aggregate UDF to FFI crate [#14775](https://github.com/apache/datafusion/pull/14775) (timsaucer) +- feat(small): Add `BaselineMetrics` to `generate_series()` table function [#16255](https://github.com/apache/datafusion/pull/16255) (2010YOUY01) +- feat: Add Window UDFs to FFI Crate [#16261](https://github.com/apache/datafusion/pull/16261) (timsaucer) + +**Fixed bugs:** + +- fix: serialize listing table without partition column [#15737](https://github.com/apache/datafusion/pull/15737) (chenkovsky) +- fix: describe Parquet schema with coerce_int96 [#15750](https://github.com/apache/datafusion/pull/15750) (chenkovsky) +- fix: clickbench type err [#15773](https://github.com/apache/datafusion/pull/15773) (chenkovsky) +- Fix: fetch is missing in `replace_order_preserving_variants` method during `EnforceDistribution` optimizer [#15808](https://github.com/apache/datafusion/pull/15808) (xudong963) +- Fix: fetch is missing in `EnforceSorting` optimizer (two places) [#15822](https://github.com/apache/datafusion/pull/15822) (xudong963) +- fix: Avoid mistaken ILike to string equality optimization [#15836](https://github.com/apache/datafusion/pull/15836) (srh) +- Map file-level column statistics to the table-level [#15865](https://github.com/apache/datafusion/pull/15865) (xudong963) +- fix(avro): Respect projection order in Avro reader [#15840](https://github.com/apache/datafusion/pull/15840) (nantunes) +- fix: 
correctly specify the nullability of `map_values` return type [#15901](https://github.com/apache/datafusion/pull/15901) (rluvaton) +- Fix CI in main [#15917](https://github.com/apache/datafusion/pull/15917) (blaginin) +- fix: sqllogictest on Windows [#15932](https://github.com/apache/datafusion/pull/15932) (nuno-faria) +- fix: fold cast null to substrait typed null [#15854](https://github.com/apache/datafusion/pull/15854) (discord9) +- Fix: `build_predicate_expression` method doesn't process `false` expr correctly [#15995](https://github.com/apache/datafusion/pull/15995) (xudong963) +- fix: add an "expr_planners" method to SessionState [#15119](https://github.com/apache/datafusion/pull/15119) (niebayes) +- fix: overcounting of memory in first/last. [#15924](https://github.com/apache/datafusion/pull/15924) (ashdnazg) +- fix: track timing for coalescer's in execution time [#16048](https://github.com/apache/datafusion/pull/16048) (waynexia) +- fix: stack overflow for substrait functions with large argument lists that translate to DataFusion binary operators [#16031](https://github.com/apache/datafusion/pull/16031) (fmonjalet) +- fix: coerce int96 resolution inside of list, struct, and map types [#16058](https://github.com/apache/datafusion/pull/16058) (mbutrovich) +- fix: Add coercion rules for Float16 types [#15816](https://github.com/apache/datafusion/pull/15816) (etseidl) +- fix: describe escaped quoted identifiers [#16082](https://github.com/apache/datafusion/pull/16082) (jfahne) +- fix: Remove trailing whitespace in `Display` for `LogicalPlan::Projection` [#16164](https://github.com/apache/datafusion/pull/16164) (atahanyorganci) +- fix: metadata of join schema [#16221](https://github.com/apache/datafusion/pull/16221) (chenkovsky) +- fix: add missing row count limits to TPC-H queries [#16230](https://github.com/apache/datafusion/pull/16230) (0ax1) +- fix: NaN semantics in GROUP BY [#16256](https://github.com/apache/datafusion/pull/16256) (chenkovsky) +- fix: [branch-48] Revert "Improve performance of constant aggregate window expression" [#16307](https://github.com/apache/datafusion/pull/16307) (andygrove) + +**Documentation updates:** + +- Add DataFusion 47.0.0 Upgrade Guide [#15749](https://github.com/apache/datafusion/pull/15749) (alamb) +- Improve documentation for format `OPTIONS` clause [#15708](https://github.com/apache/datafusion/pull/15708) (marvelshan) +- doc: Adding Feldera as known user [#15799](https://github.com/apache/datafusion/pull/15799) (comphead) +- docs: add ArkFlow [#15826](https://github.com/apache/datafusion/pull/15826) (chenquan) +- Fix `from_unixtime` function documentation [#15844](https://github.com/apache/datafusion/pull/15844) (Viicos) +- Upgrade-guide: Downgrade "FileScanConfig –> FileScanConfigBuilder" headline [#15883](https://github.com/apache/datafusion/pull/15883) (simonvandel) +- doc: Update known users docs [#15895](https://github.com/apache/datafusion/pull/15895) (comphead) +- Add `union_tag` scalar function [#14687](https://github.com/apache/datafusion/pull/14687) (gstvg) +- Fix typo in introduction.md [#15910](https://github.com/apache/datafusion/pull/15910) (tom-mont) +- Add `FormatOptions` to Config [#15793](https://github.com/apache/datafusion/pull/15793) (blaginin) +- docs: Label `bloom_filter_on_read` as a reading config [#15933](https://github.com/apache/datafusion/pull/15933) (nuno-faria) +- Implement Parquet filter pushdown via new filter pushdown APIs [#15769](https://github.com/apache/datafusion/pull/15769) (adriangb) +- Enable 
repartitioning on MemTable. [#15409](https://github.com/apache/datafusion/pull/15409) (wiedld) +- Updated extending operators documentation [#15612](https://github.com/apache/datafusion/pull/15612) (the0ninjas) +- chore: Replace MSRV link on main page with Github badge [#16020](https://github.com/apache/datafusion/pull/16020) (comphead) +- Add note to upgrade guide for removal of `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` [#16034](https://github.com/apache/datafusion/pull/16034) (alamb) +- docs: Clarify that it is only the name of the field that is ignored [#16052](https://github.com/apache/datafusion/pull/16052) (alamb) +- [Docs]: Added SQL example for all window functions [#16074](https://github.com/apache/datafusion/pull/16074) (Adez017) +- Fix CI on main: Add window function examples in code [#16102](https://github.com/apache/datafusion/pull/16102) (alamb) +- chore: Remove SMJ experimental status in docs [#16072](https://github.com/apache/datafusion/pull/16072) (comphead) +- doc: fix indent format explain [#16085](https://github.com/apache/datafusion/pull/16085) (chenkovsky) +- Update documentation for `datafusion.execution.collect_statistics` [#16100](https://github.com/apache/datafusion/pull/16100) (alamb) +- Make `SessionContext::register_parquet` obey `collect_statistics` config [#16080](https://github.com/apache/datafusion/pull/16080) (adriangb) +- Improve the DML / DDL Documentation [#16115](https://github.com/apache/datafusion/pull/16115) (alamb) +- docs: Fix typos and minor grammatical issues in Architecture docs [#16119](https://github.com/apache/datafusion/pull/16119) (patrickcsullivan) +- Set `TrackConsumersPool` as default in datafusion-cli [#16081](https://github.com/apache/datafusion/pull/16081) (ding-young) +- Minor: Fix links in substrait readme [#16156](https://github.com/apache/datafusion/pull/16156) (alamb) +- Add macro for creating DataFrame (#16090) [#16104](https://github.com/apache/datafusion/pull/16104) (cj-zhukov) +- doc: Move `dataframe!` example into dedicated example [#16197](https://github.com/apache/datafusion/pull/16197) (comphead) +- doc: add diagram to describe how DataSource, FileSource, and DataSourceExec are related [#16181](https://github.com/apache/datafusion/pull/16181) (onlyjackfrost) +- Clarify documentation about gathering statistics for parquet files [#16157](https://github.com/apache/datafusion/pull/16157) (alamb) +- Add change to VARCHAR in the upgrade guide [#16216](https://github.com/apache/datafusion/pull/16216) (alamb) +- Add iceberg-rust to user list [#16246](https://github.com/apache/datafusion/pull/16246) (jonathanc-n) +- Prepare for 48.0.0 release: Version and Changelog [#16238](https://github.com/apache/datafusion/pull/16238) (xudong963) + +**Other:** + +- Enable setting default values for target_partitions and planning_concurrency [#15712](https://github.com/apache/datafusion/pull/15712) (nuno-faria) +- minor: fix doc comment [#15733](https://github.com/apache/datafusion/pull/15733) (niebayes) +- chore(deps-dev): bump http-proxy-middleware from 2.0.6 to 2.0.9 in /datafusion/wasmtest/datafusion-wasm-app [#15738](https://github.com/apache/datafusion/pull/15738) (dependabot[bot]) +- Avoid computing unnecessary statstics [#15729](https://github.com/apache/datafusion/pull/15729) (xudong963) +- chore(deps): bump libc from 0.2.171 to 0.2.172 [#15745](https://github.com/apache/datafusion/pull/15745) (dependabot[bot]) +- Final release note touchups [#15741](https://github.com/apache/datafusion/pull/15741) (alamb) +- Refactor regexp 
slt tests [#15709](https://github.com/apache/datafusion/pull/15709) (kumarlokesh) +- ExecutionPlan: add APIs for filter pushdown & optimizer rule to apply them [#15566](https://github.com/apache/datafusion/pull/15566) (adriangb) +- Coerce and simplify FixedSizeBinary equality to literal binary [#15726](https://github.com/apache/datafusion/pull/15726) (leoyvens) +- Minor: simplify code in datafusion-proto [#15752](https://github.com/apache/datafusion/pull/15752) (alamb) +- chore(deps): bump clap from 4.5.35 to 4.5.36 [#15759](https://github.com/apache/datafusion/pull/15759) (dependabot[bot]) +- Support `Accumulator` for avg duration [#15468](https://github.com/apache/datafusion/pull/15468) (shruti2522) +- Show current SQL recursion limit in RecursionLimitExceeded error message [#15644](https://github.com/apache/datafusion/pull/15644) (kumarlokesh) +- Minor: fix flaky test in `aggregate.slt` [#15786](https://github.com/apache/datafusion/pull/15786) (xudong963) +- Minor: remove unused logic for limit pushdown [#15730](https://github.com/apache/datafusion/pull/15730) (zhuqi-lucas) +- chore(deps): bump sqllogictest from 0.28.0 to 0.28.1 [#15788](https://github.com/apache/datafusion/pull/15788) (dependabot[bot]) +- Add try_new for LogicalPlan::Join [#15757](https://github.com/apache/datafusion/pull/15757) (kumarlokesh) +- Minor: eliminate unnecessary struct creation in session state build [#15800](https://github.com/apache/datafusion/pull/15800) (Rachelint) +- chore(deps): bump half from 2.5.0 to 2.6.0 [#15806](https://github.com/apache/datafusion/pull/15806) (dependabot[bot]) +- Add `or_fun_call` and `unnecessary_lazy_evaluations` lints on `core` [#15807](https://github.com/apache/datafusion/pull/15807) (Rachelint) +- chore(deps): bump env_logger from 0.11.7 to 0.11.8 [#15823](https://github.com/apache/datafusion/pull/15823) (dependabot[bot]) +- Support unparsing `UNION` for distinct results [#15814](https://github.com/apache/datafusion/pull/15814) (phillipleblanc) +- Add `MemoryPool::memory_limit` to expose setting memory usage limit [#15828](https://github.com/apache/datafusion/pull/15828) (Rachelint) +- Preserve projection for inline scan [#15825](https://github.com/apache/datafusion/pull/15825) (jayzhan211) +- Minor: cleanup hash table after emit all [#15834](https://github.com/apache/datafusion/pull/15834) (jayzhan211) +- chore(deps): bump pyo3 from 0.24.1 to 0.24.2 [#15838](https://github.com/apache/datafusion/pull/15838) (dependabot[bot]) +- Minor: fix potential flaky test in aggregate.slt [#15829](https://github.com/apache/datafusion/pull/15829) (bikbov) +- Fix `ILIKE` expression support in SQL unparser [#15820](https://github.com/apache/datafusion/pull/15820) (ewgenius) +- Make `Diagnostic` easy/convinient to attach by using macro and avoiding `map_err` [#15796](https://github.com/apache/datafusion/pull/15796) (logan-keede) +- Feature/benchmark config from env [#15782](https://github.com/apache/datafusion/pull/15782) (ctsk) +- predicate pruning: support cast and try_cast for more types [#15764](https://github.com/apache/datafusion/pull/15764) (adriangb) +- Fix: fetch is missing in `plan_with_order_breaking_variants` method [#15842](https://github.com/apache/datafusion/pull/15842) (xudong963) +- Fix `CoalescePartitionsExec` proto serialization [#15824](https://github.com/apache/datafusion/pull/15824) (lewiszlw) +- Fix build failure caused by new `CoalescePartitionsExec::with_fetch` method [#15849](https://github.com/apache/datafusion/pull/15849) (lewiszlw) +- Fix ScalarValue::List 
comparison when the compared lists have different lengths [#15856](https://github.com/apache/datafusion/pull/15856) (gabotechs) +- chore: More details to `No UDF registered` error [#15843](https://github.com/apache/datafusion/pull/15843) (comphead) +- chore(deps): bump clap from 4.5.36 to 4.5.37 [#15853](https://github.com/apache/datafusion/pull/15853) (dependabot[bot]) +- Remove usage of `dbg!` [#15858](https://github.com/apache/datafusion/pull/15858) (phillipleblanc) +- Minor: Interval singleton [#15859](https://github.com/apache/datafusion/pull/15859) (jayzhan211) +- Make aggr fuzzer query builder more configurable [#15851](https://github.com/apache/datafusion/pull/15851) (Rachelint) +- chore(deps): bump aws-config from 1.6.1 to 1.6.2 [#15874](https://github.com/apache/datafusion/pull/15874) (dependabot[bot]) +- Add slt tests for `datafusion.execution.parquet.coerce_int96` setting [#15723](https://github.com/apache/datafusion/pull/15723) (alamb) +- Improve `ListingTable` / `ListingTableOptions` docs [#15767](https://github.com/apache/datafusion/pull/15767) (alamb) +- Migrate Optimizer tests to insta, part2 [#15884](https://github.com/apache/datafusion/pull/15884) (qstommyshu) +- Improve documentation for `FileSource`, `DataSource` and `DataSourceExec` [#15766](https://github.com/apache/datafusion/pull/15766) (alamb) +- Implement min max for dictionary types [#15827](https://github.com/apache/datafusion/pull/15827) (XiangpengHao) +- chore(deps): bump blake3 from 1.8.1 to 1.8.2 [#15890](https://github.com/apache/datafusion/pull/15890) (dependabot[bot]) +- Respect ignore_nulls in array_agg [#15544](https://github.com/apache/datafusion/pull/15544) (joroKr21) +- Set HashJoin seed [#15783](https://github.com/apache/datafusion/pull/15783) (ctsk) +- Saner handling of nulls inside arrays [#15149](https://github.com/apache/datafusion/pull/15149) (joroKr21) +- Keeping pull request in sync with the base branch [#15894](https://github.com/apache/datafusion/pull/15894) (xudong963) +- Fix `flatten` scalar function when inner list is `FixedSizeList` [#15898](https://github.com/apache/datafusion/pull/15898) (gstvg) +- support OR operator in binary `evaluate_bounds` [#15716](https://github.com/apache/datafusion/pull/15716) (davidhewitt) +- infer placeholder datatype for IN lists [#15864](https://github.com/apache/datafusion/pull/15864) (kczimm) +- Fix allow_update_branch [#15904](https://github.com/apache/datafusion/pull/15904) (xudong963) +- chore(deps): bump tokio from 1.44.1 to 1.44.2 [#15900](https://github.com/apache/datafusion/pull/15900) (dependabot[bot]) +- chore(deps): bump assert_cmd from 2.0.16 to 2.0.17 [#15909](https://github.com/apache/datafusion/pull/15909) (dependabot[bot]) +- Factor out Substrait consumers into separate files [#15794](https://github.com/apache/datafusion/pull/15794) (gabotechs) +- Unparse `UNNEST` projection with the table column alias [#15879](https://github.com/apache/datafusion/pull/15879) (goldmedal) +- Migrate Optimizer tests to insta, part3 [#15893](https://github.com/apache/datafusion/pull/15893) (qstommyshu) +- Minor: cleanup datafusion-spark scalar functions [#15921](https://github.com/apache/datafusion/pull/15921) (alamb) +- Fix ClickBench extended queries after update to APPROX_PERCENTILE_CONT [#15929](https://github.com/apache/datafusion/pull/15929) (alamb) +- Add extended query for checking improvement for blocked groups optimization [#15936](https://github.com/apache/datafusion/pull/15936) (Rachelint) +- Speedup `character_length` 
[#15931](https://github.com/apache/datafusion/pull/15931) (Dandandan) +- chore(deps): bump tokio-util from 0.7.14 to 0.7.15 [#15918](https://github.com/apache/datafusion/pull/15918) (dependabot[bot]) +- Migrate Optimizer tests to insta, part4 [#15937](https://github.com/apache/datafusion/pull/15937) (qstommyshu) +- fix query results for predicates referencing partition columns and data columns [#15935](https://github.com/apache/datafusion/pull/15935) (adriangb) +- chore(deps): bump substrait from 0.55.0 to 0.55.1 [#15941](https://github.com/apache/datafusion/pull/15941) (dependabot[bot]) +- Fix main CI by adding `rowsort` to slt test [#15942](https://github.com/apache/datafusion/pull/15942) (xudong963) +- Improve sqllogictest error reporting [#15905](https://github.com/apache/datafusion/pull/15905) (gabotechs) +- refactor filter pushdown apis [#15801](https://github.com/apache/datafusion/pull/15801) (adriangb) +- Add additional tests for filter pushdown apis [#15955](https://github.com/apache/datafusion/pull/15955) (adriangb) +- Improve filter pushdown optimizer rule performance [#15959](https://github.com/apache/datafusion/pull/15959) (adriangb) +- Reduce rehashing cost for primitive grouping by also reusing hash value [#15962](https://github.com/apache/datafusion/pull/15962) (Rachelint) +- chore(deps): bump chrono from 0.4.40 to 0.4.41 [#15956](https://github.com/apache/datafusion/pull/15956) (dependabot[bot]) +- refactor: replace `unwrap_or` with `unwrap_or_else` for improved lazy… [#15841](https://github.com/apache/datafusion/pull/15841) (NevroHelios) +- add benchmark code for `Reuse rows in row cursor stream` [#15913](https://github.com/apache/datafusion/pull/15913) (acking-you) +- [Update] : Removal of duplicate CI jobs [#15966](https://github.com/apache/datafusion/pull/15966) (Adez017) +- Segfault in ByteGroupValueBuilder [#15968](https://github.com/apache/datafusion/pull/15968) (thinkharderdev) +- make can_expr_be_pushed_down_with_schemas public again [#15971](https://github.com/apache/datafusion/pull/15971) (adriangb) +- re-export can_expr_be_pushed_down_with_schemas to be public [#15974](https://github.com/apache/datafusion/pull/15974) (adriangb) +- Migrate Optimizer tests to insta, part5 [#15945](https://github.com/apache/datafusion/pull/15945) (qstommyshu) +- Show LogicalType name for `INFORMATION_SCHEMA` [#15965](https://github.com/apache/datafusion/pull/15965) (goldmedal) +- chore(deps): bump sha2 from 0.10.8 to 0.10.9 [#15970](https://github.com/apache/datafusion/pull/15970) (dependabot[bot]) +- chore(deps): bump insta from 1.42.2 to 1.43.1 [#15988](https://github.com/apache/datafusion/pull/15988) (dependabot[bot]) +- [datafusion-spark] Add Spark-compatible hex function [#15947](https://github.com/apache/datafusion/pull/15947) (andygrove) +- refactor: remove deprecated `AvroExec` [#15987](https://github.com/apache/datafusion/pull/15987) (miroim) +- Substrait: Handle inner map fields in schema renaming [#15869](https://github.com/apache/datafusion/pull/15869) (cht42) +- refactor: remove deprecated `CsvExec` [#15991](https://github.com/apache/datafusion/pull/15991) (miroim) +- Migrate Optimizer tests to insta, part6 [#15984](https://github.com/apache/datafusion/pull/15984) (qstommyshu) +- chore(deps): bump nix from 0.29.0 to 0.30.1 [#16002](https://github.com/apache/datafusion/pull/16002) (dependabot[bot]) +- Implement RightSemi join for SortMergeJoin [#15972](https://github.com/apache/datafusion/pull/15972) (irenjj) +- Migrate Optimizer tests to insta, part7 
[#16010](https://github.com/apache/datafusion/pull/16010) (qstommyshu) +- chore(deps): bump sysinfo from 0.34.2 to 0.35.1 [#16027](https://github.com/apache/datafusion/pull/16027) (dependabot[bot]) +- refactor: move `should_enable_page_index` from `mod.rs` to `opener.rs` [#16026](https://github.com/apache/datafusion/pull/16026) (miroim) +- chore(deps): bump sqllogictest from 0.28.1 to 0.28.2 [#16037](https://github.com/apache/datafusion/pull/16037) (dependabot[bot]) +- chores: Add lint rule to enforce string formatting style [#16024](https://github.com/apache/datafusion/pull/16024) (Lordworms) +- Use human-readable byte sizes in `EXPLAIN` [#16043](https://github.com/apache/datafusion/pull/16043) (tlm365) +- Docs: Add example of creating a field in `return_field_from_args` [#16039](https://github.com/apache/datafusion/pull/16039) (alamb) +- Support `MIN` and `MAX` for `DataType::List` [#16025](https://github.com/apache/datafusion/pull/16025) (gabotechs) +- Improve docs for Exprs and scalar functions [#16036](https://github.com/apache/datafusion/pull/16036) (alamb) +- Add h2o window benchmark [#16003](https://github.com/apache/datafusion/pull/16003) (2010YOUY01) +- Fix Infer prepare statement type tests [#15743](https://github.com/apache/datafusion/pull/15743) (brayanjuls) +- style: simplify some strings for readability [#15999](https://github.com/apache/datafusion/pull/15999) (hamirmahal) +- support simple/cross lateral joins [#16015](https://github.com/apache/datafusion/pull/16015) (jayzhan211) +- Improve error message on Out of Memory [#16050](https://github.com/apache/datafusion/pull/16050) (ding-young) +- chore(deps): bump the arrow-parquet group with 7 updates [#16047](https://github.com/apache/datafusion/pull/16047) (dependabot[bot]) +- chore(deps): bump petgraph from 0.7.1 to 0.8.1 [#15669](https://github.com/apache/datafusion/pull/15669) (dependabot[bot]) +- [datafusion-spark] Add Spark-compatible `char` expression [#15994](https://github.com/apache/datafusion/pull/15994) (andygrove) +- chore(deps): bump substrait from 0.55.1 to 0.56.0 [#16091](https://github.com/apache/datafusion/pull/16091) (dependabot[bot]) +- Add test that demonstrate behavior for `collect_statistics` [#16098](https://github.com/apache/datafusion/pull/16098) (alamb) +- Refactor substrait producer into multiple files [#16089](https://github.com/apache/datafusion/pull/16089) (gabotechs) +- Fix temp dir leak in tests [#16094](https://github.com/apache/datafusion/pull/16094) (findepi) +- Label Spark functions PRs with spark label [#16095](https://github.com/apache/datafusion/pull/16095) (findepi) +- Added SLT tests for IMDB benchmark queries [#16067](https://github.com/apache/datafusion/pull/16067) (kumarlokesh) +- chore(CI) Upgrade toolchain to Rust-1.87 [#16068](https://github.com/apache/datafusion/pull/16068) (kadai0308) +- minor: Add benchmark query and corresponding documentation for Average Duration [#16105](https://github.com/apache/datafusion/pull/16105) (logan-keede) +- Use qualified names on DELETE selections [#16033](https://github.com/apache/datafusion/pull/16033) (nuno-faria) +- chore(deps): bump testcontainers from 0.23.3 to 0.24.0 [#15989](https://github.com/apache/datafusion/pull/15989) (dependabot[bot]) +- Clean up ExternalSorter and use upstream kernel [#16109](https://github.com/apache/datafusion/pull/16109) (alamb) +- Test Duration in aggregation `fuzz` tests [#16111](https://github.com/apache/datafusion/pull/16111) (alamb) +- Move PruningStatistics into datafusion::common 
[#16069](https://github.com/apache/datafusion/pull/16069) (adriangb) +- Revert use file schema in parquet pruning [#16086](https://github.com/apache/datafusion/pull/16086) (adriangb) +- Minor: Add `ScalarFunctionArgs::return_type` method [#16113](https://github.com/apache/datafusion/pull/16113) (alamb) +- Fix `contains` function expression [#16046](https://github.com/apache/datafusion/pull/16046) (liamzwbao) +- chore: Use materialized data for filter pushdown tests [#16123](https://github.com/apache/datafusion/pull/16123) (comphead) +- chore: Upgrade rand crate and some other minor crates [#16062](https://github.com/apache/datafusion/pull/16062) (comphead) +- Include data types in logical plans of inferred prepare statements [#16019](https://github.com/apache/datafusion/pull/16019) (brayanjuls) +- CI: Fix extended test failure [#16144](https://github.com/apache/datafusion/pull/16144) (2010YOUY01) +- Fix: handle column name collisions when combining UNION logical inputs & nested Column expressions in maybe_fix_physical_column_name [#16064](https://github.com/apache/datafusion/pull/16064) (LiaCastaneda) +- adding support for Min/Max over LargeList and FixedSizeList [#16071](https://github.com/apache/datafusion/pull/16071) (logan-keede) +- Move prepare/parameter handling tests into `params.rs` [#16141](https://github.com/apache/datafusion/pull/16141) (liamzwbao) +- Minor: Add `Accumulator::return_type` and `StateFieldsArgs::return_type` to help with upgrade to 48 [#16112](https://github.com/apache/datafusion/pull/16112) (alamb) +- Support filtering specific sqllogictests identified by line number [#16029](https://github.com/apache/datafusion/pull/16029) (gabotechs) +- Enrich GroupedHashAggregateStream name to ease debugging Resources exhausted errors [#16152](https://github.com/apache/datafusion/pull/16152) (ahmed-mez) +- chore(deps): bump uuid from 1.16.0 to 1.17.0 [#16162](https://github.com/apache/datafusion/pull/16162) (dependabot[bot]) +- Clarify docs and names in parquet predicate pushdown tests [#16155](https://github.com/apache/datafusion/pull/16155) (alamb) +- Minor: Fix name() for FilterPushdown physical optimizer rule [#16175](https://github.com/apache/datafusion/pull/16175) (adriangb) +- migrate tests in `pool.rs` to use insta [#16145](https://github.com/apache/datafusion/pull/16145) (lifan-ake) +- refactor(optimizer): Add support for dynamically adding test tables [#16138](https://github.com/apache/datafusion/pull/16138) (atahanyorganci) +- [Minor] Speedup TPC-H benchmark run with memtable option [#16159](https://github.com/apache/datafusion/pull/16159) (Dandandan) +- Fast path for joins with distinct values in build side [#16153](https://github.com/apache/datafusion/pull/16153) (Dandandan) +- chore: Reduce repetition in the parameter type inference tests [#16079](https://github.com/apache/datafusion/pull/16079) (jsai28) +- chore(deps): bump tokio from 1.45.0 to 1.45.1 [#16190](https://github.com/apache/datafusion/pull/16190) (dependabot[bot]) +- Improve `unproject_sort_expr` to handle arbitrary expressions [#16127](https://github.com/apache/datafusion/pull/16127) (phillipleblanc) +- chore(deps): bump rustyline from 15.0.0 to 16.0.0 [#16194](https://github.com/apache/datafusion/pull/16194) (dependabot[bot]) +- migrate `logical_plan` tests to insta [#16184](https://github.com/apache/datafusion/pull/16184) (lifan-ake) +- chore(deps): bump clap from 4.5.38 to 4.5.39 [#16204](https://github.com/apache/datafusion/pull/16204) (dependabot[bot]) +- implement 
`AggregateExec.partition_statistics` [#15954](https://github.com/apache/datafusion/pull/15954) (UBarney) +- Propagate .execute() calls immediately in `RepartitionExec` [#16093](https://github.com/apache/datafusion/pull/16093) (gabotechs) +- Set aggregation hash seed [#16165](https://github.com/apache/datafusion/pull/16165) (ctsk) +- Fix ScalarStructBuilder::build() for an empty struct [#16205](https://github.com/apache/datafusion/pull/16205) (Blizzara) +- Return an error on overflow in `do_append_val_inner` [#16201](https://github.com/apache/datafusion/pull/16201) (liamzwbao) +- chore(deps): bump testcontainers-modules from 0.12.0 to 0.12.1 [#16212](https://github.com/apache/datafusion/pull/16212) (dependabot[bot]) +- Substrait: handle identical grouping expressions [#16189](https://github.com/apache/datafusion/pull/16189) (cht42) +- Add new stats pruning helpers to allow combining partition values in file level stats [#16139](https://github.com/apache/datafusion/pull/16139) (adriangb) +- Implement schema adapter support for FileSource and add integration tests [#16148](https://github.com/apache/datafusion/pull/16148) (kosiew) +- Minor: update documentation for PrunableStatistics [#16213](https://github.com/apache/datafusion/pull/16213) (alamb) +- Remove use of deprecated dict_ordered in datafusion-proto (#16218) [#16220](https://github.com/apache/datafusion/pull/16220) (cj-zhukov) +- Minor: Print cargo command in bench script [#16236](https://github.com/apache/datafusion/pull/16236) (2010YOUY01) +- Simplify FileSource / SchemaAdapterFactory API [#16214](https://github.com/apache/datafusion/pull/16214) (alamb) +- Add dicts to aggregation fuzz testing [#16232](https://github.com/apache/datafusion/pull/16232) (blaginin) +- chore(deps): bump sysinfo from 0.35.1 to 0.35.2 [#16247](https://github.com/apache/datafusion/pull/16247) (dependabot[bot]) +- Improve performance of constant aggregate window expression [#16234](https://github.com/apache/datafusion/pull/16234) (suibianwanwank) +- Support compound identifier when parsing tuples [#16225](https://github.com/apache/datafusion/pull/16225) (hozan23) +- Schema adapter helper [#16108](https://github.com/apache/datafusion/pull/16108) (kosiew) +- Update tpch, clickbench, sort_tpch to mark failed queries [#16182](https://github.com/apache/datafusion/pull/16182) (ding-young) +- Adjust slttest to pass without RUST_BACKTRACE enabled [#16251](https://github.com/apache/datafusion/pull/16251) (alamb) +- Handle dicts for distinct count [#15871](https://github.com/apache/datafusion/pull/15871) (blaginin) +- Add `--substrait-round-trip` option in sqllogictests [#16183](https://github.com/apache/datafusion/pull/16183) (gabotechs) +- Minor: fix upgrade papercut `pub use PruningStatistics` [#16264](https://github.com/apache/datafusion/pull/16264) (alamb) +- chore: update DF48 changelog [#16269](https://github.com/apache/datafusion/pull/16269) (xudong963) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. 
+ +``` + 30 dependabot[bot] + 29 Andrew Lamb + 17 xudong.w + 14 Adrian Garcia Badaracco + 10 Chen Chongchen + 8 Gabriel + 8 Oleks V + 7 miro + 6 Tommy shu + 6 kamille + 5 Lokesh + 5 Tim Saucer + 4 Dmitrii Blaginin + 4 Jay Zhan + 4 Nuno Faria + 4 Yongting You + 4 logan-keede + 3 Andy Grove + 3 Christian + 3 Daniël Heres + 3 Liam Bao + 3 Phillip LeBlanc + 3 Piotr Findeisen + 3 ding-young + 2 Atahan Yorgancı + 2 Brayan Jules + 2 Georgi Krastev + 2 Jax Liu + 2 Jérémie Drouet + 2 LB7666 + 2 Leonardo Yvens + 2 Qi Zhu + 2 Sergey Zhukov + 2 Shruti Sharma + 2 Tai Le Manh + 2 aditya singh rathore + 2 ake + 2 cht42 + 2 gstvg + 2 kosiew + 2 niebayes + 2 张林伟 + 1 Ahmed Mezghani + 1 Alexander Droste + 1 Andy Yen + 1 Arka Dash + 1 Arttu + 1 Dan Harris + 1 David Hewitt + 1 Davy + 1 Ed Seidl + 1 Eshed Schacham + 1 Evgenii Khramkov + 1 Florent Monjalet + 1 Galim Bikbov + 1 Garam Choi + 1 Hamir Mahal + 1 Hendrik Makait + 1 Jonathan Chen + 1 Joseph Fahnestock + 1 Kevin Zimmerman + 1 Lordworms + 1 Lía Adriana + 1 Matt Butrovich + 1 Namgung Chan + 1 Nelson Antunes + 1 Patrick Sullivan + 1 Raz Luvaton + 1 Ruihang Xia + 1 Ryan Roelke + 1 Sam Hughes + 1 Shehab Amin + 1 Sile Zhou + 1 Simon Vandel Sillesen + 1 Tom Montgomery + 1 UBarney + 1 Victorien + 1 Xiangpeng Hao + 1 Zaki + 1 chen quan + 1 delamarch3 + 1 discord9 + 1 hozan23 + 1 irenjj + 1 jsai28 + 1 m09526 + 1 suibianwanwan + 1 the0ninjas + 1 wiedld +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/48.0.1.md b/dev/changelog/48.0.1.md new file mode 100644 index 000000000000..dcd4cc9c1547 --- /dev/null +++ b/dev/changelog/48.0.1.md @@ -0,0 +1,41 @@ + + +# Apache DataFusion 48.0.1 Changelog + +This release consists of 3 commits from 2 contributors. See credits at the end of this changelog for more information. + +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. + +**Bug Fixes:** + +- [branch-48] Set the default value of datafusion.execution.collect_statistics to true #16447 [#16659](https://github.com/apache/datafusion/pull/16659) (blaginin) +- [branch-48] Fix parquet filter_pushdown: respect parquet filter pushdown config i… [#16656](https://github.com/apache/datafusion/pull/16656) (alamb) +- [branch-48] fix: column indices in FFI partition evaluator (#16480) [#16657](https://github.com/apache/datafusion/pull/16657) (alamb) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 2 Andrew Lamb + 1 Dmitrii Blaginin +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py index 1349416bcaa5..830d329f73c4 100755 --- a/dev/release/generate-changelog.py +++ b/dev/release/generate-changelog.py @@ -124,6 +124,9 @@ def generate_changelog(repo, repo_name, tag1, tag2, version): print(f"This release consists of {commit_count} commits from {contributor_count} contributors. 
" f"See credits at the end of this changelog for more information.\n") + print("See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) " + "for information on how to upgrade from previous versions.\n") + print_pulls(repo_name, "Breaking changes", breaking) print_pulls(repo_name, "Performance related", performance) print_pulls(repo_name, "Implemented enhancements", enhancements) diff --git a/dev/update_runtime_config_docs.sh b/dev/update_runtime_config_docs.sh new file mode 100755 index 000000000000..0d9d0f103323 --- /dev/null +++ b/dev/update_runtime_config_docs.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +set -e + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "${SOURCE_DIR}/../" && pwd + +TARGET_FILE="docs/source/user-guide/runtime_configs.md" +PRINT_CONFIG_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_runtime_config_docs" + +echo "Inserting header" +cat <<'EOF' > "$TARGET_FILE" + + + + +# Runtime Environment Configurations + +DataFusion runtime configurations can be set via SQL using the `SET` command. + +For example, to configure `datafusion.runtime.memory_limit`: + +```sql +SET datafusion.runtime.memory_limit = '2G'; +``` + +The following runtime configuration settings are available: + +EOF + +echo "Running CLI and inserting runtime config docs table" +$PRINT_CONFIG_DOCS_COMMAND >> "$TARGET_FILE" + +echo "Running prettier" +npx prettier@2.3.2 --write "$TARGET_FILE" + +echo "'$TARGET_FILE' successfully updated!" 
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 0dc947fdea57..e920a0f036cb 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -116,6 +116,7 @@ To get started, see
user-guide/expressions
user-guide/sql/index
user-guide/configs
+ user-guide/runtime_configs
user-guide/explain-usage
user-guide/faq
diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md
index 8fb8a59fb860..cd40e664239a 100644
--- a/docs/source/library-user-guide/adding-udfs.md
+++ b/docs/source/library-user-guide/adding-udfs.md
@@ -1076,7 +1076,7 @@ pub struct EchoFunction {}
impl TableFunctionImpl for EchoFunction {
    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
-        let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else {
+        let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else {
            return plan_err!("First argument must be an integer");
        };
@@ -1117,7 +1117,7 @@ With the UDTF implemented, you can register it with the `SessionContext`:
#
# impl TableFunctionImpl for EchoFunction {
#     fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
-#         let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else {
+#         let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else {
#             return plan_err!("First argument must be an integer");
#         };
#
diff --git a/docs/source/library-user-guide/catalogs.md b/docs/source/library-user-guide/catalogs.md
index 906039ba2300..d4e6633d40ba 100644
--- a/docs/source/library-user-guide/catalogs.md
+++ b/docs/source/library-user-guide/catalogs.md
@@ -23,11 +23,14 @@ This section describes how to create and manage catalogs, schemas, and tables in
## General Concepts
-CatalogProviderList, Catalogs, schemas, and tables are organized in a hierarchy. A CatalogProviderList contains catalog providers, a catalog provider contains schemas and a schema contains tables.
+Catalog providers, catalogs, schemas, and tables are organized in a hierarchy. A `CatalogProviderList` contains `CatalogProvider`s, a `CatalogProvider` contains `SchemaProvider`s, and a `SchemaProvider` contains `TableProvider`s.
DataFusion comes with a basic in memory catalog functionality in the [`catalog` module]. You can use these in memory implementations as is, or extend DataFusion with your own catalog implementations, for example based on local files or files on remote object storage.
+DataFusion supports DDL queries (e.g. `CREATE TABLE`) using the catalog API described in this section. See the [TableProvider] section for information on DML queries (e.g. `INSERT INTO`).
+
[`catalog` module]: https://docs.rs/datafusion/latest/datafusion/catalog/index.html
+[tableprovider]: ./custom-table-providers.md
Similarly to other concepts in DataFusion, you'll implement various traits to create your own catalogs, schemas, and tables. The following sections describe the traits you'll need to implement.
diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md
index 886ac9629566..54f79a421823 100644
--- a/docs/source/library-user-guide/custom-table-providers.md
+++ b/docs/source/library-user-guide/custom-table-providers.md
@@ -19,17 +19,22 @@
# Custom Table Provider
-Like other areas of DataFusion, you extend DataFusion's functionality by implementing a trait. The `TableProvider` and associated traits, have methods that allow you to implement a custom table provider, i.e. use DataFusion's other functionality with your custom data source.
+Like other areas of DataFusion, you extend DataFusion's functionality by implementing a trait. The [`TableProvider`] and associated traits allow you to implement a custom table provider, i.e. use DataFusion's other functionality with your custom data source.
-This section will also touch on how to have DataFusion use the new `TableProvider` implementation.
+This section describes how to create a [`TableProvider`] and how to configure DataFusion to use it for reading.
## Table Provider and Scan
-The `scan` method on the `TableProvider` is likely its most important. It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution of the query.
+The [`TableProvider::scan`] method reads data from the table and is likely the most important. It returns an [`ExecutionPlan`] that DataFusion will use to read the actual data during execution of the query. The [`TableProvider::insert_into`] method is used to `INSERT` data into the table.
### Scan
-As mentioned, `scan` returns an execution plan, and in particular a `Result<Arc<dyn ExecutionPlan>>`. The core of this is returning something that can be dynamically dispatched to an `ExecutionPlan`. And as per the general DataFusion idea, we'll need to implement it.
+As mentioned, [`TableProvider::scan`] returns an execution plan, and in particular a `Result<Arc<dyn ExecutionPlan>>`. The core of this is returning something that can be dynamically dispatched to an `ExecutionPlan`. And as per the general DataFusion idea, we'll need to implement it.
+
+[`tableprovider`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html
+[`tableprovider::scan`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#tymethod.scan
+[`tableprovider::insert_into`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#tymethod.insert_into
+[`executionplan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html
#### Execution Plan
diff --git a/docs/source/library-user-guide/extending-operators.md b/docs/source/library-user-guide/extending-operators.md
index 631bdc67975a..3d491806a4e6 100644
--- a/docs/source/library-user-guide/extending-operators.md
+++ b/docs/source/library-user-guide/extending-operators.md
@@ -19,4 +19,41 @@
# Extending DataFusion's operators: custom LogicalPlan and Execution Plans
-Coming soon
+DataFusion supports extending its operators by transforming logical and execution plans through custom [optimizer rules](https://docs.rs/datafusion/latest/datafusion/optimizer/trait.OptimizerRule.html). This section uses the µWheel project to illustrate these capabilities.
+
+## About DataFusion µWheel
+
+[DataFusion µWheel](https://github.com/uwheel/datafusion-uwheel/tree/main) is a native DataFusion optimizer which improves query performance for time-based analytics through fast temporal aggregation and pruning using custom indices. The integration of µWheel into DataFusion is a joint effort with the DataFusion community.
+
+### Optimizing Logical Plan
+
+The `rewrite` function transforms logical plans by identifying temporal patterns and aggregation functions that match the stored wheel indices. When a match is found, it queries the corresponding index to retrieve pre-computed aggregate values, stores these results in a [MemTable](https://docs.rs/datafusion/latest/datafusion/datasource/memory/struct.MemTable.html), and returns the result as a new `LogicalPlan::TableScan`.
If no match is found, the original plan proceeds unchanged through DataFusion's standard execution path.
+
+```rust,ignore
+fn rewrite(
+    &self,
+    plan: LogicalPlan,
+    _config: &dyn OptimizerConfig,
+) -> Result<Transformed<LogicalPlan>> {
+    // Attempts to rewrite a logical plan to a uwheel-based plan that either provides
+    // plan-time aggregates or skips execution based on min/max pruning.
+    if let Some(rewritten) = self.try_rewrite(&plan) {
+        Ok(Transformed::yes(rewritten))
+    } else {
+        Ok(Transformed::no(plan))
+    }
+}
+```
+
+```rust,ignore
+// Converts a uwheel aggregate result to a TableScan with a MemTable as source
+fn agg_to_table_scan(result: f64, schema: SchemaRef) -> Result<LogicalPlan> {
+    let data = Float64Array::from(vec![result]);
+    let record_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(data)])?;
+    let df_schema = Arc::new(DFSchema::try_from(schema.clone())?);
+    let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?;
+    mem_table_as_table_scan(mem_table, df_schema)
+}
+```
+
+For a deeper dive into the µWheel project, see the [blog post](https://uwheel.rs/post/datafusion_uwheel/) by Max Meldrum.
diff --git a/docs/source/library-user-guide/query-optimizer.md b/docs/source/library-user-guide/query-optimizer.md
index 03cd7b5bbbbe..a1ccd0a15a7e 100644
--- a/docs/source/library-user-guide/query-optimizer.md
+++ b/docs/source/library-user-guide/query-optimizer.md
@@ -193,7 +193,7 @@ Looking at the `EXPLAIN` output we can see that the optimizer has effectively re
`3 as "1 + 2"`:
```text
-> explain select 1 + 2;
+> explain format indent select 1 + 2;
+---------------+-------------------------------------------------+
| plan_type | plan |
+---------------+-------------------------------------------------+
diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md
index 11fd49566522..3922e0d45d88 100644
--- a/docs/source/library-user-guide/upgrading.md
+++ b/docs/source/library-user-guide/upgrading.md
@@ -19,6 +19,249 @@
# Upgrade Guides
+## DataFusion `48.0.0`
+
+### The `VARCHAR` SQL type is now represented as `Utf8View` in Arrow.
+
+The mapping of the SQL `VARCHAR` type has been changed from `Utf8` to `Utf8View`,
+which improves performance for many string operations. You can read more about
+`Utf8View` in the [DataFusion blog post on German-style strings].
+
+[datafusion blog post on german-style strings]: https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-1/
+
+This means that when you create a table with a `VARCHAR` column, it will now use
+`Utf8View` as the underlying data type. For example:
+
+```sql
+> CREATE TABLE my_table (my_column VARCHAR);
+0 row(s) fetched.
+Elapsed 0.001 seconds.
+
+> DESCRIBE my_table;
++-------------+-----------+-------------+
+| column_name | data_type | is_nullable |
++-------------+-----------+-------------+
+| my_column   | Utf8View  | YES         |
++-------------+-----------+-------------+
+1 row(s) fetched.
+Elapsed 0.000 seconds.
+```
+
+You can restore the old behavior of using `Utf8` by changing the
+`datafusion.sql_parser.map_varchar_to_utf8view` configuration setting. For
+example:
+
+```sql
+> set datafusion.sql_parser.map_varchar_to_utf8view = false;
+0 row(s) fetched.
+Elapsed 0.001 seconds.
+
+> CREATE TABLE my_table (my_column VARCHAR);
+0 row(s) fetched.
+Elapsed 0.014 seconds.
+
+> DESCRIBE my_table;
++-------------+-----------+-------------+
+| column_name | data_type | is_nullable |
++-------------+-----------+-------------+
+| my_column   | Utf8      | YES         |
++-------------+-----------+-------------+
+1 row(s) fetched.
+Elapsed 0.004 seconds.
+```
+
+### `ListingOptions` default for `collect_stat` changed from `true` to `false`
+
+This makes it agree with the default for `SessionConfig`.
+Most users won't be impacted by this change, but if you were using `ListingOptions` directly
+and relied on the default value of `collect_stat` being `true`, you will need to
+explicitly set it to `true` in your code.
+
+```rust
+# /* comment to avoid running
+ListingOptions::new(Arc::new(ParquetFormat::default()))
+    .with_collect_stat(true)
+    // other options
+# */
+```
+
+### Processing `FieldRef` instead of `DataType` for user defined functions
+
+In order to support metadata handling and extension types, user defined functions are
+now switching to traits which use `FieldRef` rather than a `DataType` and nullability.
+This gives a single interface to both of these parameters and additionally allows
+access to metadata fields, which can be used for extension types.
+
+To upgrade structs which implement `ScalarUDFImpl`, if you have implemented
+`return_type_from_args` you need to implement `return_field_from_args` instead.
+If your functions do not need to handle metadata, this should be a straightforward
+repackaging of the output data into a `FieldRef`. The name you specify on the
+field is not important. It will be overwritten during planning. `ReturnInfo`
+has been removed, so you will need to remove all references to it.
+
+`ScalarFunctionArgs` now contains a field called `arg_fields`. You can use this
+to access the metadata associated with the columnar values during invocation.
+
+To upgrade user defined aggregate functions, there is now a function
+`return_field` that will allow you to specify both metadata and nullability of
+your function. You are not required to implement this if you do not need to
+handle metadata.
+
+The largest change to aggregate functions happens in the accumulator arguments.
+Both the `AccumulatorArgs` and `StateFieldsArgs` now contain `FieldRef` rather
+than `DataType`.
+
+To upgrade window functions, `ExpressionArgs` now contains input fields instead
+of input data types. When setting these fields, the name of the field is
+not important since this gets overwritten during the planning stage. All you
+should need to do is wrap your existing data types in fields with nullability
+set depending on your use case.
+
+### Physical Expressions return `Field`
+
+To support the changes to user defined functions processing metadata, the
+`PhysicalExpr` trait now must specify a return `Field` based on the input
+schema. To upgrade structs which implement `PhysicalExpr` you need to implement
+the `return_field` function. There are numerous examples in the `physical-expr`
+crate.
+
+### `FileFormat::supports_filters_pushdown` replaced with `FileSource::try_pushdown_filters`
+
+To support more general filter pushdown, the `FileFormat::supports_filters_pushdown` method was replaced with
+`FileSource::try_pushdown_filters`.
+If you implemented a custom `FileFormat` that uses a custom `FileSource` you will need to implement
+`FileSource::try_pushdown_filters`.
+See `ParquetSource::try_pushdown_filters` for an example of how to implement this.
+
+`FileFormat::supports_filters_pushdown` has been removed.
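+
+The sketch below shows the general shape of such an implementation. It is schematic only:
+the argument and return types are stand-ins rather than the real trait signature, so consult
+`ParquetSource::try_pushdown_filters` in the DataFusion source for the authoritative form.
+
+```rust
+# /* comment to avoid running
+impl FileSource for MyFileSource {
+    // ... other trait methods ...
+
+    // Schematic only: inspect the candidate filters, keep the ones this source can
+    // evaluate during the scan on a new copy of itself, and report back which
+    // filters were absorbed so DataFusion can drop the corresponding FilterExec nodes.
+    fn try_pushdown_filters(&self, filters: ..., config: ...) -> Result<...> {
+        ...
+    }
+}
+# */
+```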
+
+### `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` Removed
+
+`ParquetExec`, `AvroExec`, `CsvExec`, and `JsonExec` were deprecated in
+DataFusion 46 and are removed in DataFusion 48. This is sooner than the normal
+process described in the [API Deprecation Guidelines] because all the tests now
+cover the new `DataSourceExec` rather than the older structures. As `DataSource`
+evolved, the old structures began to show signs of "bit rotting" (not
+working, with no one noticing due to the lack of test coverage).
+
+[api deprecation guidelines]: https://datafusion.apache.org/contributor-guide/api-health.html#deprecation-guidelines
+
+## DataFusion `47.0.0`
+
+This section calls out some of the major changes in the `47.0.0` release of DataFusion.
+
+Here are some example upgrade PRs that demonstrate changes required when upgrading from DataFusion 46.0.0:
+
+- [delta-rs Upgrade to `47.0.0`](https://github.com/delta-io/delta-rs/pull/3378)
+- [DataFusion Comet Upgrade to `47.0.0`](https://github.com/apache/datafusion-comet/pull/1563)
+- [Sail Upgrade to `47.0.0`](https://github.com/lakehq/sail/pull/434)
+
+### Upgrades to `arrow-rs` and `arrow-parquet` 55.0.0 and `object_store` 0.12.0
+
+Several APIs in the underlying arrow and parquet libraries were changed to use a
+`u64` instead of `usize` to better support WASM (See [#7371] and [#6961])
+
+Additionally `ObjectStore::list` and `ObjectStore::list_with_offset` have been changed to return `'static` lifetimes (See [#6619])
+
+[#6619]: https://github.com/apache/arrow-rs/pull/6619
+[#7371]: https://github.com/apache/arrow-rs/pull/7371
+[#6961]: https://github.com/apache/arrow-rs/pull/6961
+
+This requires converting from `usize` to `u64` occasionally as well as changes to `ObjectStore` implementations such as
+
+```rust
+# /* comment to avoid running
+impl Objectstore {
+    ...
+    // The range is now a u64 instead of usize
+    async fn get_range(&self, location: &Path, range: Range<u64>) -> ObjectStoreResult<Bytes> {
+        self.inner.get_range(location, range).await
+    }
+    ...
+    // the lifetime is now 'static instead of '_ (meaning the captured closure can't contain references)
+    // (this also applies to list_with_offset)
+    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, ObjectStoreResult<ObjectMeta>> {
+        self.inner.list(prefix)
+    }
+}
+# */
+```
+
+The `ParquetObjectReader` has been updated to no longer require the object size
+(it can be fetched using a single suffix request). See [#7334] for details.
+
+[#7334]: https://github.com/apache/arrow-rs/pull/7334
+
+Pattern in DataFusion `46.0.0`:
+
+```rust
+# /* comment to avoid running
+let meta: ObjectMeta = ...;
+let reader = ParquetObjectReader::new(store, meta);
+# */
+```
+
+Pattern in DataFusion `47.0.0`:
+
+```rust
+# /* comment to avoid running
+let meta: ObjectMeta = ...;
+let reader = ParquetObjectReader::new(store, location)
+    .with_file_size(meta.size);
+# */
+```
+
+### `DisplayFormatType::TreeRender`
+
+DataFusion now supports [`tree` style explain plans]. Implementations of
+`ExecutionPlan` must also provide a description in the
+`DisplayFormatType::TreeRender` format. This can be the same as the existing
+`DisplayFormatType::Default`.
+
+[`tree` style explain plans]: https://datafusion.apache.org/user-guide/sql/explain.html#tree-format-default
+
+### Removed Deprecated APIs
+
+Several APIs have been removed in this release. These were either deprecated
+previously or were hard to use correctly, such as the multiple different
+`ScalarUDFImpl::invoke*` APIs. See [#15130], [#15123], and [#15027] for more
+details.
+
+[#15130]: https://github.com/apache/datafusion/pull/15130
+[#15123]: https://github.com/apache/datafusion/pull/15123
+[#15027]: https://github.com/apache/datafusion/pull/15027
+
+### `FileScanConfig` --> `FileScanConfigBuilder`
+
+Previously, `FileScanConfig::build()` directly created `ExecutionPlan`s. In
+DataFusion 47.0.0 this has been changed to use `FileScanConfigBuilder`. See
+[#15352] for details.
+
+[#15352]: https://github.com/apache/datafusion/pull/15352
+
+Pattern in DataFusion `46.0.0`:
+
+```rust
+# /* comment to avoid running
+let plan = FileScanConfig::new(url, schema, Arc::new(file_source))
+    .with_statistics(stats)
+    ...
+    .build()
+# */
+```
+
+Pattern in DataFusion `47.0.0`:
+
+```rust
+# /* comment to avoid running
+let config = FileScanConfigBuilder::new(url, schema, Arc::new(file_source))
+    .with_statistics(stats)
+    ...
+    .build();
+let scan = DataSourceExec::from_data_source(config);
+# */
+```
+
## DataFusion `46.0.0`

### Use `invoke_with_args` instead of `invoke()` and `invoke_batch()`
@@ -39,7 +282,7 @@ below. See [PR 14876] for an example.
Given existing code like this:
```rust
-# /*
+# /* comment to avoid running
impl ScalarUDFImpl for SparkConcat {
    ...
    fn invoke_batch(&self, args: &[ColumnarValue], number_rows: usize) -> Result<ColumnarValue> {
@@ -59,7 +302,7 @@ impl ScalarUDFImpl for SparkConcat {
To
```rust
-# /* comment out so they don't run
+# /* comment to avoid running
impl ScalarUDFImpl for SparkConcat {
    ...
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -164,7 +407,7 @@ let mut file_source = ParquetSource::new(parquet_options)
// Add filter
if let Some(predicate) = logical_filter {
    if config.enable_parquet_pushdown {
-        file_source = file_source.with_predicate(Arc::clone(&file_schema), predicate);
+        file_source = file_source.with_predicate(predicate);
    }
};
diff --git a/docs/source/user-guide/cli/usage.md b/docs/source/user-guide/cli/usage.md
index 68b09d319984..13f0e7cff175 100644
--- a/docs/source/user-guide/cli/usage.md
+++ b/docs/source/user-guide/cli/usage.md
@@ -57,6 +57,9 @@ OPTIONS:
--mem-pool-type Specify the memory pool type 'greedy' or 'fair', default to 'greedy'
+ --top-memory-consumers The number of top memory consumers to display when a query fails due to memory exhaustion. To disable memory consumer tracking, set this value to 0 [default: 3]
+
-d, --disk-limit Available disk space for spilling queries (e.g. '10g'), default to None (uses DataFusion's default value of '100g')
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index 7a46d59d893e..05cc36651a1a 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -35,100 +35,110 @@ Values are parsed according to the [same rules used in casts from Utf8](https://
If the value in the environment variable cannot be cast to the type of the configuration option, the default value will be used instead and a warning emitted.
Environment variables are read during `SessionConfig` initialisation so they must be set beforehand and will not affect running sessions.
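+
+For example, a minimal sketch of environment-based configuration (this assumes the
+`SessionConfig::from_env` constructor; the variable must be exported before the
+process starts):
+
+```rust
+# /* comment to avoid running
+// Assumes `export DATAFUSION_EXECUTION_BATCH_SIZE=1024` was done in the shell
+// before the process started (environment variables are read when the
+// SessionConfig is created, so they must be set beforehand).
+let config = SessionConfig::from_env()?;
+let ctx = SessionContext::new_with_config(config);
+# */
+```
+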
-| key | default | description | -| ----------------------------------------------------------------------- | ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. | -| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | -| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | -| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | -| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | -| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | -| datafusion.catalog.has_header | true | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | -| datafusion.catalog.newlines_in_values | false | Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. | -| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | -| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | -| datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | -| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | -| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | -| datafusion.execution.parquet.enable_page_index | true | (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. 
| -| datafusion.execution.parquet.pruning | true | (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | -| datafusion.execution.parquet.skip_metadata | true | (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | -| datafusion.execution.parquet.metadata_size_hint | NULL | (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | -| datafusion.execution.parquet.pushdown_filters | false | (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | -| datafusion.execution.parquet.reorder_filters | false | (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | -| datafusion.execution.parquet.schema_force_view_types | true | (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, and `Binary/BinaryLarge` with `BinaryView`. | -| datafusion.execution.parquet.binary_as_string | false | (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. | -| datafusion.execution.parquet.coerce_int96 | NULL | (reading) If true, parquet reader will read columns of physical type int96 as originating from a different resolution than nanosecond. This is useful for reading data from systems like Spark which stores microsecond resolution timestamps in an int96 allowing it to write values with a larger date range than 64-bit timestamps with nanosecond resolution. | -| datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | -| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | -| datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | -| datafusion.execution.parquet.skip_arrow_metadata | false | (writing) Skip encoding the embedded arrow metadata in the KV_meta This is analogous to the `ArrowWriterOptions::with_skip_arrow_metadata`. Refer to | -| datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. | -| datafusion.execution.parquet.dictionary_enabled | true | (writing) Sets if dictionary encoding is enabled. 
If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | -| datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_statistics_size | 4096 | (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting max_statistics_size is deprecated, currently it is not being used | -| datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 47.0.0 | (writing) Sets "created by" property | -| datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | -| datafusion.execution.parquet.statistics_truncate_length | NULL | (writing) Sets statictics truncate length. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | -| datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_on_read | true | (writing) Use any available bloom filters when reading parquet files | -| datafusion.execution.parquet.bloom_filter_on_write | false | (writing) Write bloom filters for all columns when creating parquet files | -| datafusion.execution.parquet.bloom_filter_fpp | NULL | (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_ndv | NULL | (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.allow_single_file_parallelism | true | (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | -| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | -| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. 
You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | -| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | -| datafusion.execution.skip_physical_aggregate_schema_check | false | When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. | -| datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | -| datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | -| datafusion.execution.meta_fetch_concurrency | 32 | Number of files to read in parallel when inferring schema and statistics | -| datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | -| datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | -| datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | -| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | -| datafusion.execution.enable_recursive_ctes | true | Should DataFusion support recursive CTEs | -| datafusion.execution.split_file_groups_by_statistics | false | Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. 
Currently experimental | -| datafusion.execution.keep_partition_by_columns | false | Should DataFusion keep the columns used for partition_by in the output RecordBatches | -| datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input | -| datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode | -| datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future. | -| datafusion.execution.enforce_batch_size_in_joins | false | Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. | -| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | -| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | -| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | -| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | -| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | -| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | -| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. 
Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. | -| datafusion.optimizer.repartition_file_scans | true | When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. | -| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | -| datafusion.optimizer.prefer_existing_sort | false | When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. | -| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail | -| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | -| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | -| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | -| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | -| datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | -| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | -| datafusion.optimizer.prefer_existing_union | false | When set to true, the optimizer will not attempt to convert Union to Interleave | -| datafusion.optimizer.expand_views_at_output | false | When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. 
| -| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | -| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | -| datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | -| datafusion.explain.show_sizes | true | When set to true, the explain statement will print the partition sizes | -| datafusion.explain.show_schema | false | When set to true, the explain statement will print schema information | -| datafusion.explain.format | indent | Display format of explain. Default is "indent". When set to "tree", it will print the plan in a tree-rendered format. | -| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | -| datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | -| datafusion.sql_parser.enable_options_value_normalization | false | When set to true, SQL parser will normalize options value (convert value to lowercase). Note that this option is ignored and will be removed in the future. All case-insensitive values are normalized automatically. | -| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. | -| datafusion.sql_parser.support_varchar_with_length | true | If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. | -| datafusion.sql_parser.map_varchar_to_utf8view | false | If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. If false, `VARCHAR` is mapped to `Utf8` during SQL planning. Default is false. | -| datafusion.sql_parser.collect_spans | false | When set to true, the source locations relative to the original SQL query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected and recorded in the logical plan nodes. 
| -| datafusion.sql_parser.recursion_limit | 50 | Specifies the recursion depth limit when parsing complex SQL Queries | +| key | default | description | +| ----------------------------------------------------------------------- | ------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. | +| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | +| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | +| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | +| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | +| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | +| datafusion.catalog.has_header | true | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | +| datafusion.catalog.newlines_in_values | false | Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. | +| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | +| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | +| datafusion.execution.collect_statistics | true | Should DataFusion collect statistics when first creating a table. Has no effect after the table is created. Applies to the default `ListingTableProvider` in DataFusion. Defaults to true. | +| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. 
Defaults to the number of CPU cores on the system | +| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | +| datafusion.execution.parquet.enable_page_index | true | (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | +| datafusion.execution.parquet.pruning | true | (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | +| datafusion.execution.parquet.skip_metadata | true | (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | +| datafusion.execution.parquet.metadata_size_hint | NULL | (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | +| datafusion.execution.parquet.pushdown_filters | false | (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | +| datafusion.execution.parquet.reorder_filters | false | (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | +| datafusion.execution.parquet.schema_force_view_types | true | (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, and `Binary/BinaryLarge` with `BinaryView`. | +| datafusion.execution.parquet.binary_as_string | false | (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. | +| datafusion.execution.parquet.coerce_int96 | NULL | (reading) If true, parquet reader will read columns of physical type int96 as originating from a different resolution than nanosecond. This is useful for reading data from systems like Spark which stores microsecond resolution timestamps in an int96 allowing it to write values with a larger date range than 64-bit timestamps with nanosecond resolution. | +| datafusion.execution.parquet.bloom_filter_on_read | true | (reading) Use any available bloom filters when reading parquet files | +| datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | +| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | +| datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | +| datafusion.execution.parquet.skip_arrow_metadata | false | (writing) Skip encoding the embedded arrow metadata in the KV_meta This is analogous to the `ArrowWriterOptions::with_skip_arrow_metadata`. 
Refer to | +| datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. | +| datafusion.execution.parquet.dictionary_enabled | true | (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | +| datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_statistics_size | 4096 | (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting max_statistics_size is deprecated, currently it is not being used | +| datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | +| datafusion.execution.parquet.created_by | datafusion version 48.0.1 | (writing) Sets "created by" property | +| datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | +| datafusion.execution.parquet.statistics_truncate_length | NULL | (writing) Sets statictics truncate length. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | +| datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_on_write | false | (writing) Write bloom filters for all columns when creating parquet files | +| datafusion.execution.parquet.bloom_filter_fpp | NULL | (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_ndv | NULL | (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.allow_single_file_parallelism | true | (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | +| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. 
Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | +| datafusion.execution.skip_physical_aggregate_schema_check | false | When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. | +| datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | +| datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | +| datafusion.execution.meta_fetch_concurrency | 32 | Number of files to read in parallel when inferring schema and statistics | +| datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | +| datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | +| datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | +| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). 
| +| datafusion.execution.enable_recursive_ctes | true | Should DataFusion support recursive CTEs | +| datafusion.execution.split_file_groups_by_statistics | false | Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental | +| datafusion.execution.keep_partition_by_columns | false | Should DataFusion keep the columns used for partition_by in the output RecordBatches | +| datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the observed ratio is greater than this threshold, partial aggregation will skip aggregation for further input | +| datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows a partial aggregation partition should process before checking the aggregation ratio and trying to switch to skipping aggregation mode | +| datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans if the source of statistics is accurate. We plan to make this the default in the future. | +| datafusion.execution.enforce_batch_size_in_joins | false | Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. | +| datafusion.execution.objectstore_writer_buffer_size | 10485760 | Size (bytes) of data buffer DataFusion uses when writing output files. This affects the size of the data chunks that are uploaded to remote object stores (e.g. AWS S3). If very large (>= 100 GiB) output files are being written, it may be necessary to increase this size to avoid errors from the remote endpoint. | +| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | +| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | +| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | +| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | +| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total file size in bytes to perform file scan repartitioning.
| +| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | +| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when their inputs do not have any ordering or filtering. If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long-running execution, all types of joins may encounter out-of-memory errors. | +| datafusion.optimizer.repartition_file_scans | true | When set to `true`, datasource partitions will be repartitioned to achieve maximum parallelism. This applies to both in-memory partitions and FileSource's file groups (1 group is 1 partition). For FileSources, only Parquet and CSV formats are currently supported. If set to `true` for a FileSource, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false` for a FileSource, different files will be read in parallel, but repartitioning won't happen within a single file. If set to `true` for an in-memory source, all memtable's partitions will have their batches repartitioned evenly to the desired number of `target_partitions`. Repartitioning can change the total number of partitions and batches per partition, but does not slice the initial record tables provided to the MemTable on creation. | +| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partition keys to execute window functions in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. When this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | +| datafusion.optimizer.prefer_existing_sort | false | When true, DataFusion will opportunistically remove sorts when the data is already sorted (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`). When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. | +| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule.
When set to false, any rules that produce errors will cause the query to fail | +| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | +| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top-down process to reorder the join keys | +| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | +| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes at which one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows at which one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | +| datafusion.optimizer.prefer_existing_union | false | When set to true, the optimizer will not attempt to convert Union to Interleave | +| datafusion.optimizer.expand_views_at_output | false | When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. | +| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | +| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | +| datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | +| datafusion.explain.show_sizes | true | When set to true, the explain statement will print the partition sizes | +| datafusion.explain.show_schema | false | When set to true, the explain statement will print schema information | +| datafusion.explain.format | indent | Display format of explain. Default is "indent". When set to "tree", it will print the plan in a tree-rendered format. | +| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, the SQL parser will parse floats as decimal type | +| datafusion.sql_parser.enable_ident_normalization | true | When set to true, the SQL parser will normalize identifiers (convert them to lowercase when not quoted) | +| datafusion.sql_parser.enable_options_value_normalization | false | When set to true, the SQL parser will normalize option values (convert values to lowercase). Note that this option is ignored and will be removed in the future. All case-insensitive values are normalized automatically. | +| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. | +| datafusion.sql_parser.support_varchar_with_length | true | If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion cannot enforce such limits.
| +| datafusion.sql_parser.map_varchar_to_utf8view | true | If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. If false, `VARCHAR` is mapped to `Utf8` during SQL planning. | +| datafusion.sql_parser.collect_spans | false | When set to true, the source locations relative to the original SQL query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected and recorded in the logical plan nodes. | +| datafusion.sql_parser.recursion_limit | 50 | Specifies the recursion depth limit when parsing complex SQL queries | +| datafusion.format.safe | true | If set to `true`, any formatting errors will be written to the output instead of being converted into a [`std::fmt::Error`] | +| datafusion.format.null | | Format string for nulls | +| datafusion.format.date_format | %Y-%m-%d | Date format for date arrays | +| datafusion.format.datetime_format | %Y-%m-%dT%H:%M:%S%.f | Format for DateTime arrays | +| datafusion.format.timestamp_format | %Y-%m-%dT%H:%M:%S%.f | Timestamp format for timestamp arrays | +| datafusion.format.timestamp_tz_format | NULL | Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. | +| datafusion.format.time_format | %H:%M:%S%.f | Time format for time arrays | +| datafusion.format.duration_format | pretty | Duration format. Can be either `"pretty"` or `"ISO8601"` | +| datafusion.format.types_info | false | Show data types in the visual representation of batches | diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index 96be1bb9e256..82f1eeb2823d 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -50,13 +50,38 @@ use datafusion::prelude::*; Here is a minimal example showing the execution of a query using the DataFrame API. +Create a DataFrame from in-memory rows using the macro API + ```rust use datafusion::prelude::*; use datafusion::error::Result; + +#[tokio::main] +async fn main() -> Result<()> { + // Create a new dataframe with in-memory data using the macro + let df = dataframe!( + "a" => [1, 2, 3], + "b" => [true, true, false], + "c" => [Some("foo"), Some("bar"), None] + )?; + df.show().await?; + Ok(()) +} +``` + +Create a DataFrame from a file or from in-memory rows using the standard API + +```rust +use datafusion::arrow::array::{Int32Array, RecordBatch, StringArray}; +use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::error::Result; use datafusion::functions_aggregate::expr_fn::min; +use datafusion::prelude::*; +use std::sync::Arc; #[tokio::main] async fn main() -> Result<()> { + // Read the data from a CSV file let ctx = SessionContext::new(); let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; let df = df.filter(col("a").lt_eq(col("b")))?
@@ -64,6 +89,22 @@ async fn main() -> Result<()> { .limit(0, Some(100))?; // Print results df.show().await?; + + // Create a new dataframe with in-memory data + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["foo", "bar", "baz"])), + ], + )?; + let df = ctx.read_batch(batch)?; + df.show().await?; + Ok(()) } ``` diff --git a/docs/source/user-guide/explain-usage.md b/docs/source/user-guide/explain-usage.md index d89ed5f0e7ea..68712012f43f 100644 --- a/docs/source/user-guide/explain-usage.md +++ b/docs/source/user-guide/explain-usage.md @@ -40,7 +40,7 @@ Let's see how DataFusion runs a query that selects the top 5 watch lists for the site `http://domcheloveplanet.ru/`: ```sql -EXPLAIN SELECT "WatchID" AS wid, "hits.parquet"."ClientIP" AS ip +EXPLAIN FORMAT INDENT SELECT "WatchID" AS wid, "hits.parquet"."ClientIP" AS ip FROM 'hits.parquet' WHERE starts_with("URL", 'http://domcheloveplanet.ru/') ORDER BY wid ASC, ip DESC @@ -268,7 +268,7 @@ LIMIT 10; We can again see the query plan by using `EXPLAIN`: ```sql -> EXPLAIN SELECT "UserID", COUNT(*) FROM 'hits.parquet' GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; +> EXPLAIN FORMAT INDENT SELECT "UserID", COUNT(*) FROM 'hits.parquet' GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | plan_type | plan | +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 14d6ab177dc3..040405f8f63e 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -40,9 +40,9 @@ Arrow](https://arrow.apache.org/). ## Features - Feature-rich [SQL support](https://datafusion.apache.org/user-guide/sql/index.html) and [DataFrame API](https://datafusion.apache.org/user-guide/dataframe.html) -- Blazingly fast, vectorized, multi-threaded, streaming execution engine. +- Blazingly fast, vectorized, multithreaded, streaming execution engine. - Native support for Parquet, CSV, JSON, and Avro file formats. Support - for custom file formats and non file datasources via the `TableProvider` trait. + for custom file formats and non-file datasources via the `TableProvider` trait. 
- Many extension points: user defined scalar/aggregate/window functions, DataSources, SQL, other query languages, custom plan and execution nodes, optimizer passes, and more. - Streaming, asynchronous IO directly from popular object stores, including AWS S3, @@ -68,14 +68,14 @@ DataFusion can be used without modification as an embedded SQL engine or can be customized and used as a foundation for building new systems. -While most current usecases are "analytic" or (throughput) some +While most current use cases are "analytic" (throughput oriented), some components of DataFusion, such as the plan representations, are suitable for "streaming" and "transaction" style systems (low latency). Here are some example systems built using DataFusion: -- Specialized Analytical Database systems such as [HoraeDB] and more general Apache Spark like system such a [Ballista]. +- Specialized Analytical Database systems such as [HoraeDB] and more general Apache Spark-like systems such as [Ballista] - New query language engines such as [prql-query] and accelerators such as [VegaFusion] - Research platform for new Database Systems, such as [Flock] - SQL support to another library, such as [dask sql] @@ -95,19 +95,22 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust +- [ArkFlow](https://github.com/arkflow-rs/arkflow) High-performance Rust stream processing engine - [Ballista](https://github.com/apache/datafusion-ballista) Distributed SQL Query Engine - [Blaze](https://github.com/kwai/blaze) The Blaze accelerator for Apache Spark leverages native vectorized execution to accelerate query processing - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Comet](https://github.com/apache/datafusion-comet) Apache Spark native query execution plugin -- [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) +- [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) Cube’s universal semantic layer platform is the next evolution of OLAP technology for AI, BI, spreadsheets, and embedded analytics - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python - [datafusion-dft](https://github.com/datafusion-contrib/datafusion-dft) Batteries included CLI, TUI, and server implementations for DataFusion. - [delta-rs](https://github.com/delta-io/delta-rs) Native Rust implementation of Delta Lake - [Exon](https://github.com/wheretrue/exon) Analysis toolkit for life-science applications +- [Feldera](https://github.com/feldera/feldera) Fast query engine for incremental computation - [Funnel](https://funnel.io/) Data Platform powering Marketing Intelligence applications. - [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. - [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database - [HoraeDB](https://github.com/apache/incubator-horaedb) Distributed Time-Series Database +- [Iceberg-rust](https://github.com/apache/iceberg-rust) Rust implementation of Apache Iceberg - [InfluxDB](https://github.com/influxdata/influxdb) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline - [LakeSoul](https://github.com/lakesoul-io/LakeSoul) Open source LakeHouse framework with native IO in Rust.
@@ -118,11 +121,11 @@ Here are some active projects using DataFusion: - [Polygon.io](https://polygon.io/) Stock Market API - [qv](https://github.com/timvw/qv) Quickly view your data - [Restate](https://github.com/restatedev) Easily build resilient applications using distributed durable async/await -- [ROAPI](https://github.com/roapi/roapi) -- [Sail](https://github.com/lakehq/sail) Unifying stream, batch, and AI workloads with Apache Spark compatibility +- [ROAPI](https://github.com/roapi/roapi) Create full-fledged APIs for slowly moving datasets without writing a single line of code +- [Sail](https://github.com/lakehq/sail) Unifying stream, batch and AI workloads with Apache Spark compatibility - [Seafowl](https://github.com/splitgraph/seafowl) CDN-friendly analytical database - [Sleeper](https://github.com/gchq/sleeper) Serverless, cloud-native, log-structured merge tree based, scalable key-value store -- [Spice.ai](https://github.com/spiceai/spiceai) Unified SQL query interface & materialization engine +- [Spice.ai](https://github.com/spiceai/spiceai) Building blocks for data-driven AI applications - [Synnada](https://synnada.ai/) Streaming-first framework for data products - [VegaFusion](https://vegafusion.io/) Server-side acceleration for the [Vega](https://vega.github.io/) visualization grammar - [Telemetry](https://telemetry.sh/) Structured logging made easy @@ -179,6 +182,20 @@ provide integrations with other systems, some of which are described below: ## Why DataFusion? - _High Performance_: Leveraging Rust and Arrow's memory model, DataFusion is very fast. -- _Easy to Connect_: Being part of the Apache Arrow ecosystem (Arrow, Parquet and Flight), DataFusion works well with the rest of the big data ecosystem +- _Easy to Connect_: Being part of the Apache Arrow ecosystem (Arrow, Parquet, and Flight), DataFusion works well with the rest of the big data ecosystem - _Easy to Embed_: Allowing extension at almost any point in its design, and published regularly as a crate on [crates.io](http://crates.io), DataFusion can be integrated and tailored for your specific use case. - _High Quality_: Extensively tested, both by itself and with the rest of the Arrow ecosystem, DataFusion can be and is used as the foundation for production systems. + +## Rust Version Compatibility Policy + +The Rust toolchain releases are tracked at [Rust Versions](https://releases.rs) and follow +[semantic versioning](https://semver.org/). A Rust toolchain release can be identified +by a version string like `1.80.0`, or more generally `major.minor.patch`. + +DataFusion supports the last 4 stable Rust minor versions released and any such versions released within the last 4 months. + +For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0`, DataFusion will support `1.78.0`, which is 3 minor versions prior to the most recent minor version `1.81`. + +Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes; this takes precedence over the policies above.
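In practice, an MSRV like this is declared through the `rust-version` field of a crate's `Cargo.toml`, which `cargo` compares against the active toolchain and refuses to build with anything older. A minimal sketch follows; the crate name and version number are illustrative only and are not DataFusion's actual values.

```toml
[package]
name = "example-crate"   # hypothetical crate, for illustration only
version = "0.1.0"
edition = "2021"
# Minimum Supported Rust Version: cargo errors out when the
# installed toolchain is older than this value.
rust-version = "1.82.0"  # illustrative value, not DataFusion's actual MSRV
```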
+ +DataFusion enforces MSRV policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) diff --git a/docs/source/user-guide/runtime_configs.md b/docs/source/user-guide/runtime_configs.md new file mode 100644 index 000000000000..feef709db992 --- /dev/null +++ b/docs/source/user-guide/runtime_configs.md @@ -0,0 +1,40 @@ + + + + +# Runtime Environment Configurations + +DataFusion runtime configurations can be set via SQL using the `SET` command. + +For example, to configure `datafusion.runtime.memory_limit`: + +```sql +SET datafusion.runtime.memory_limit = '2G'; +``` + +The following runtime configuration settings are available: + +| key | default | description | +| ------------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------- | +| datafusion.runtime.memory_limit | NULL | Maximum memory limit for query execution. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes. | diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 684db52e6323..774a4fae6bf3 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -808,7 +808,7 @@ approx_distinct(expression) ### `approx_median` -Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`. +Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY x)`. ```sql approx_median(expression) @@ -834,7 +834,7 @@ approx_median(expression) Returns the approximate percentile of input values using the t-digest algorithm. ```sql -approx_percentile_cont(expression, percentile, centroids) +approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -846,12 +846,12 @@ approx_percentile_cont(expression, percentile, centroids) #### Example ```sql -> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; -+-------------------------------------------------+ -| approx_percentile_cont(column_name, 0.75, 100) | -+-------------------------------------------------+ -| 65.0 | -+-------------------------------------------------+ +> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++-----------------------------------------------------------------------+ +| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | ++-----------------------------------------------------------------------+ +| 65.0 | ++-----------------------------------------------------------------------+ ``` ### `approx_percentile_cont_with_weight` @@ -859,7 +859,7 @@ approx_percentile_cont(expression, percentile, centroids) Returns the weighted approximate percentile of input values using the t-digest algorithm. 
```sql -approx_percentile_cont_with_weight(expression, weight, percentile) +approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -871,10 +871,10 @@ approx_percentile_cont_with_weight(expression, weight, percentile) #### Example ```sql -> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; -+----------------------------------------------------------------------+ -| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | -+----------------------------------------------------------------------+ -| 78.5 | -+----------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++---------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) | ++---------------------------------------------------------------------------------------------+ +| 78.5 | ++---------------------------------------------------------------------------------------------+ ``` diff --git a/docs/source/user-guide/sql/ddl.md b/docs/source/user-guide/sql/ddl.md index 71475cff9a39..1d971594ada9 100644 --- a/docs/source/user-guide/sql/ddl.md +++ b/docs/source/user-guide/sql/ddl.md @@ -74,7 +74,7 @@ LOCATION := ( , ...) ``` -For a detailed list of write related options which can be passed in the OPTIONS key_value_list, see [Write Options](write_options). +For a comprehensive list of format-specific options that can be specified in the `OPTIONS` clause, see [Format Options](format_options.md). `file_type` is one of `CSV`, `ARROW`, `PARQUET`, `AVRO` or `JSON` @@ -82,6 +82,8 @@ For a detailed list of write related options which can be passed in the OPTIONS a path to a file or directory of partitioned files locally or on an object store. +### Example: Parquet + Parquet data sources can be registered by executing a `CREATE EXTERNAL TABLE` SQL statement such as the following. It is not necessary to provide schema information for Parquet files. @@ -91,6 +93,23 @@ STORED AS PARQUET LOCATION '/mnt/nyctaxi/tripdata.parquet'; ``` +:::{note} +Statistics +: By default, when a table is created, DataFusion will read the files +to gather statistics, which can be expensive but can accelerate subsequent +queries substantially. If you don't want to gather statistics +when creating a table, set the `datafusion.execution.collect_statistics` +configuration option to `false` before creating the table. For example: + +```sql +SET datafusion.execution.collect_statistics = false; +``` + +See the [config settings docs](../configs.md) for more details. +::: + +### Example: Comma Separated Value (CSV) + CSV data sources can also be registered by executing a `CREATE EXTERNAL TABLE` SQL statement. The schema will be inferred based on scanning a subset of the file. @@ -101,6 +120,8 @@ LOCATION '/path/to/aggregate_simple.csv' OPTIONS ('has_header' 'true'); ``` +### Example: Compression + It is also possible to use compressed files, such as `.csv.gz`: ```sql @@ -111,6 +132,8 @@ LOCATION '/path/to/aggregate_simple.csv.gz' OPTIONS ('has_header' 'true'); ``` +### Example: Specifying Schema + It is also possible to specify the schema manually. 
```sql @@ -134,6 +157,8 @@ LOCATION '/path/to/aggregate_test_100.csv' OPTIONS ('has_header' 'true'); ``` +### Example: Partitioned Tables + It is also possible to specify a directory that contains a partitioned table (multiple files with the same schema) @@ -144,7 +169,9 @@ LOCATION '/path/to/directory/of/files' OPTIONS ('has_header' 'true'); ``` -With `CREATE UNBOUNDED EXTERNAL TABLE` SQL statement. We can create unbounded data sources such as following: +### Example: Unbounded Data Sources + +We can create unbounded data sources using the `CREATE UNBOUNDED EXTERNAL TABLE` SQL statement. ```sql CREATE UNBOUNDED EXTERNAL TABLE taxi @@ -154,6 +181,8 @@ LOCATION '/mnt/nyctaxi/tripdata.parquet'; Note that this statement actually reads data from a fixed-size file, so a better example would involve reading from a FIFO file. Nevertheless, once Datafusion sees the `UNBOUNDED` keyword in a data source, it tries to execute queries that refer to this unbounded source in streaming fashion. If this is not possible according to query specifications, plan generation fails stating it is not possible to execute given query in streaming fashion. Note that queries that can run with unbounded sources (i.e. in streaming mode) are a subset of those that can with bounded sources. A query that fails with unbounded source(s) may work with bounded source(s). +### Example: `WITH ORDER` Clause + When creating an output from a data source that is already ordered by an expression, you can pre-specify the order of the data using the `WITH ORDER` clause. This applies even if the expression used for @@ -190,7 +219,7 @@ WITH ORDER (sort_expression1 [ASC | DESC] [NULLS { FIRST | LAST }] [, sort_expression2 [ASC | DESC] [NULLS { FIRST | LAST }] ...]) ``` -### Cautions when using the WITH ORDER Clause +#### Cautions when using the WITH ORDER Clause - It's important to understand that using the `WITH ORDER` clause in the `CREATE EXTERNAL TABLE` statement only specifies the order in which the data should be read from the external file. If the data in the file is not already sorted according to the specified order, then the results may not be correct. diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md index 4eda59d6dea1..c29447f23cd9 100644 --- a/docs/source/user-guide/sql/dml.md +++ b/docs/source/user-guide/sql/dml.md @@ -49,7 +49,7 @@ The output format is determined by the first match of the following rules: 1. Value of `STORED AS` 2. Filename extension (e.g. `foo.parquet` implies `PARQUET` format) -For a detailed list of valid OPTIONS, see [Write Options](write_options). +For a detailed list of valid OPTIONS, see [Format Options](format_options.md). ### Examples diff --git a/docs/source/user-guide/sql/explain.md b/docs/source/user-guide/sql/explain.md index 9984de147ecc..c5e2e215a6b6 100644 --- a/docs/source/user-guide/sql/explain.md +++ b/docs/source/user-guide/sql/explain.md @@ -118,7 +118,7 @@ See [Reading Explain Plans](../explain-usage.md) for more information on how to 0 row(s) fetched. Elapsed 0.004 seconds. 
-> EXPLAIN SELECT SUM(x) FROM t GROUP BY b; +> EXPLAIN FORMAT INDENT SELECT SUM(x) FROM t GROUP BY b; +---------------+-------------------------------------------------------------------------------+ | plan_type | plan | +---------------+-------------------------------------------------------------------------------+ diff --git a/docs/source/user-guide/sql/format_options.md b/docs/source/user-guide/sql/format_options.md new file mode 100644 index 000000000000..e8008eafb166 --- /dev/null +++ b/docs/source/user-guide/sql/format_options.md @@ -0,0 +1,180 @@ + + +# Format Options + +DataFusion supports customizing how data is read from or written to disk as a result of `COPY`, `INSERT INTO`, or `CREATE EXTERNAL TABLE` statements. There are a few special options, file format (e.g., CSV or Parquet) specific options, and Parquet column-specific options. In some cases, options can be specified in multiple ways with a set order of precedence. + +## Specifying Options and Order of Precedence + +Format-related options can be specified in three ways, in decreasing order of precedence: + +- `CREATE EXTERNAL TABLE` syntax +- `COPY` option tuples +- Session-level config defaults + +For a list of supported session-level config defaults, see [Configuration Settings](../configs). These defaults apply to all operations but have the lowest level of precedence. + +If creating an external table, table-specific format options can be specified when the table is created using the `OPTIONS` clause: + +```sql +CREATE EXTERNAL TABLE + my_table(a bigint, b bigint) + STORED AS csv + LOCATION '/tmp/my_csv_table/' + OPTIONS( + NULL_VALUE 'NAN', + 'has_header' 'true', + 'format.delimiter' ';' + ); +``` + +When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (e.g., the special delimiter and header row configured above). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified, the `OPTIONS` setting will be ignored. + +For example, with the table defined above, running the following command: + +```sql +INSERT INTO my_table VALUES(1,2); +``` + +Results in a new CSV file with the specified options: + +```shell +$ cat /tmp/my_csv_table/bmC8zWFvLMtWX68R_0.csv +a;b +1;2 +``` + +Finally, options can be passed when running a `COPY` command. + +```sql +COPY source_table + TO 'test/table_with_options' + PARTITIONED BY (column3, column4) + OPTIONS ( + format parquet, + compression snappy, + 'compression::column1' 'zstd(5)', + ) +``` + +In this example, we write the entire `source_table` out to a folder of Parquet files. One Parquet file will be written in parallel to the folder for each partition in the query. The next option `compression` set to `snappy` indicates that unless otherwise specified, all columns should use the snappy compression codec. The option `compression::column1` sets an override, so that the column `column1` in the Parquet file will use the ZSTD compression codec with compression level `5`. In general, Parquet options that support column-specific settings can be specified with the syntax `OPTION::COLUMN.NESTED.PATH`. + +# Available Options + +## JSON Format Options + +The following options are available when reading or writing JSON files. Note: If any unsupported option is specified, an error will be raised and the query will fail.
+ +| Option | Description | Default Value | +| ----------- | ---------------------------------------------------------------------------------------------------------------------------------- | ------------- | +| COMPRESSION | Sets the compression that should be applied to the entire JSON file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | + +**Example:** + +```sql +CREATE EXTERNAL TABLE t(a int) +STORED AS JSON +LOCATION '/tmp/foo/' +OPTIONS('COMPRESSION' 'gzip'); +``` + +## CSV Format Options + +The following options are available when reading or writing CSV files. Note: If any unsupported option is specified, an error will be raised and the query will fail. + +| Option | Description | Default Value | +| -------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ------------------ | +| COMPRESSION | Sets the compression that should be applied to the entire CSV file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | +| HAS_HEADER | Sets if the CSV file should include column headers. If not set, uses session or system default. | None | +| DELIMITER | Sets the character which should be used as the column delimiter within the CSV file. | `,` (comma) | +| QUOTE | Sets the character which should be used for quoting values within the CSV file. | `"` (double quote) | +| TERMINATOR | Sets the character which should be used as the line terminator within the CSV file. | None | +| ESCAPE | Sets the character which should be used for escaping special characters within the CSV file. | None | +| DOUBLE_QUOTE | Sets if quotes within quoted fields should be escaped by doubling them (e.g., `"aaa""bbb"`). | None | +| NEWLINES_IN_VALUES | Sets if newlines in quoted values are supported. If not set, uses session or system default. | None | +| DATE_FORMAT | Sets the format that dates should be encoded in within the CSV file. | None | +| DATETIME_FORMAT | Sets the format that datetimes should be encoded in within the CSV file. | None | +| TIMESTAMP_FORMAT | Sets the format that timestamps should be encoded in within the CSV file. | None | +| TIMESTAMP_TZ_FORMAT | Sets the format that timestamps with timezone should be encoded in within the CSV file. | None | +| TIME_FORMAT | Sets the format that times should be encoded in within the CSV file. | None | +| NULL_VALUE | Sets the string which should be used to indicate null values within the CSV file. | None | +| NULL_REGEX | Sets the regex pattern to match null values when loading CSVs. | None | +| SCHEMA_INFER_MAX_REC | Sets the maximum number of records to scan to infer the schema. | None | +| COMMENT | Sets the character which should be used to indicate comment lines in the CSV file. | None | + +**Example:** + +```sql +CREATE EXTERNAL TABLE t (col1 varchar, col2 int, col3 boolean) +STORED AS CSV +LOCATION '/tmp/foo/' +OPTIONS('DELIMITER' '|', 'HAS_HEADER' 'true', 'NEWLINES_IN_VALUES' 'true'); +``` + +## Parquet Format Options + +The following options are available when reading or writing Parquet files. If any unsupported option is specified, an error will be raised and the query will fail. If a column-specific option is specified for a column that does not exist, the option will be ignored without error. + +| Option | Can be Column Specific? 
| Description | OPTIONS Key | Default Value | +| ------------------------------------------ | ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- | ------------------------ | +| COMPRESSION | Yes | Sets the internal Parquet **compression codec** for data pages, optionally including the compression level. Applies globally if set without `::col`, or specifically to a column if set using `'compression::column_name'`. Valid values: `uncompressed`, `snappy`, `gzip(level)`, `lzo`, `brotli(level)`, `lz4`, `zstd(level)`, `lz4_raw`. | `'compression'` or `'compression::col'` | zstd(3) | +| ENCODING | Yes | Sets the **encoding** scheme for data pages. Valid values: `plain`, `plain_dictionary`, `rle`, `bit_packed`, `delta_binary_packed`, `delta_length_byte_array`, `delta_byte_array`, `rle_dictionary`, `byte_stream_split`. Use key `'encoding'` or `'encoding::col'` in OPTIONS. | `'encoding'` or `'encoding::col'` | None | +| DICTIONARY_ENABLED | Yes | Sets whether dictionary encoding should be enabled globally or for a specific column. | `'dictionary_enabled'` or `'dictionary_enabled::col'` | true | +| STATISTICS_ENABLED | Yes | Sets the level of statistics to write (`none`, `chunk`, `page`). | `'statistics_enabled'` or `'statistics_enabled::col'` | page | +| BLOOM_FILTER_ENABLED | Yes | Sets whether a bloom filter should be written for a specific column. | `'bloom_filter_enabled::column_name'` | None | +| BLOOM_FILTER_FPP | Yes | Sets bloom filter false positive probability (global or per column). | `'bloom_filter_fpp'` or `'bloom_filter_fpp::col'` | None | +| BLOOM_FILTER_NDV | Yes | Sets bloom filter number of distinct values (global or per column). | `'bloom_filter_ndv'` or `'bloom_filter_ndv::col'` | None | +| MAX_ROW_GROUP_SIZE | No | Sets the maximum number of rows per row group. Larger groups require more memory but can improve compression and scan efficiency. | `'max_row_group_size'` | 1048576 | +| ENABLE_PAGE_INDEX | No | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce I/O and decoding. | `'enable_page_index'` | true | +| PRUNING | No | If true, enables row group pruning based on min/max statistics. | `'pruning'` | true | +| SKIP_METADATA | No | If true, skips optional embedded metadata in the file schema. | `'skip_metadata'` | true | +| METADATA_SIZE_HINT | No | Sets the size hint (in bytes) for fetching Parquet file metadata. | `'metadata_size_hint'` | None | +| PUSHDOWN_FILTERS | No | If true, enables filter pushdown during Parquet decoding. | `'pushdown_filters'` | false | +| REORDER_FILTERS | No | If true, enables heuristic reordering of filters during Parquet decoding. | `'reorder_filters'` | false | +| SCHEMA_FORCE_VIEW_TYPES | No | If true, reads Utf8/Binary columns as view types. | `'schema_force_view_types'` | true | +| BINARY_AS_STRING | No | If true, reads Binary columns as strings. | `'binary_as_string'` | false | +| DATA_PAGESIZE_LIMIT | No | Sets best effort maximum size of data page in bytes. | `'data_pagesize_limit'` | 1048576 | +| DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in data page. 
| `'data_page_row_count_limit'` | 20000 | +| DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size, in bytes. | `'dictionary_page_size_limit'` | 1048576 | +| WRITE_BATCH_SIZE | No | Sets write_batch_size in bytes. | `'write_batch_size'` | 1024 | +| WRITER_VERSION | No | Sets the Parquet writer version (`1.0` or `2.0`). | `'writer_version'` | 1.0 | +| SKIP_ARROW_METADATA | No | If true, skips writing Arrow schema information into the Parquet file metadata. | `'skip_arrow_metadata'` | false | +| CREATED_BY | No | Sets the "created by" string in the Parquet file metadata. | `'created_by'` | datafusion version X.Y.Z | +| COLUMN_INDEX_TRUNCATE_LENGTH | No | Sets the length (in bytes) to truncate min/max values in column indexes. | `'column_index_truncate_length'` | 64 | +| STATISTICS_TRUNCATE_LENGTH | No | Sets statistics truncate length. | `'statistics_truncate_length'` | None | +| BLOOM_FILTER_ON_WRITE | No | Sets whether bloom filters should be written for all columns by default (can be overridden per column). | `'bloom_filter_on_write'` | false | +| ALLOW_SINGLE_FILE_PARALLELISM | No | Enables parallel serialization of columns in a single file. | `'allow_single_file_parallelism'` | true | +| MAXIMUM_PARALLEL_ROW_GROUP_WRITERS | No | Maximum number of parallel row group writers. | `'maximum_parallel_row_group_writers'` | 1 | +| MAXIMUM_BUFFERED_RECORD_BATCHES_PER_STREAM | No | Maximum number of buffered record batches per stream. | `'maximum_buffered_record_batches_per_stream'` | 2 | +| KEY_VALUE_METADATA | No (Key is specific) | Adds custom key-value pairs to the file metadata. Use the format `'metadata::your_key_name' 'your_value'`. Multiple entries allowed. | `'metadata::key_name'` | None | + +**Example:** + +```sql +CREATE EXTERNAL TABLE t (id bigint, value double, category varchar) +STORED AS PARQUET +LOCATION '/tmp/parquet_data/' +OPTIONS( + 'COMPRESSION::user_id' 'snappy', + 'ENCODING::col_a' 'delta_binary_packed', + 'MAX_ROW_GROUP_SIZE' '1000000', + 'BLOOM_FILTER_ENABLED::id' 'true' +); +``` diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index 8e3f51bf8b0b..a13d40334b63 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -33,5 +33,5 @@ SQL Reference window_functions scalar_functions special_functions - write_options + format_options prepared_statements diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 0f08934c8a9c..cbcec710e267 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -2133,7 +2133,7 @@ _Alias of [date_trunc](#date_trunc)._ ### `from_unixtime` -Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. +Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) and return the corresponding timestamp. ```sql from_unixtime(expression[, timezone]) @@ -4404,6 +4404,7 @@ sha512(expression) Functions to work with the union data type, also known as tagged unions, variant types, enums or sum types.
Note: Not related to the SQL UNION operator - [union_extract](#union_extract) +- [union_tag](#union_tag) ### `union_extract` @@ -4433,6 +4434,33 @@ union_extract(union, field_name) +--------------+----------------------------------+----------------------------------+ ``` +### `union_tag` + +Returns the name of the currently selected field in the union + +```sql +union_tag(union_expression) +``` + +#### Arguments + +- **union**: Union expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +❯ select union_column, union_tag(union_column) from table_with_union; ++--------------+-------------------------+ +| union_column | union_tag(union_column) | ++--------------+-------------------------+ +| {a=1} | a | +| {b=3.0} | b | +| {a=4} | a | +| {b=} | b | +| {a=} | a | ++--------------+-------------------------+ +``` + ## Other Functions - [arrow_cast](#arrow_cast) diff --git a/docs/source/user-guide/sql/window_functions.md b/docs/source/user-guide/sql/window_functions.md index 68a700380312..bcb33bad7fb5 100644 --- a/docs/source/user-guide/sql/window_functions.md +++ b/docs/source/user-guide/sql/window_functions.md @@ -193,6 +193,29 @@ Returns the rank of the current row without gaps. This function ranks rows in a dense_rank() ``` +#### Example + +```sql + --Example usage of the dense_rank window function: + SELECT department, + salary, + dense_rank() OVER (PARTITION BY department ORDER BY salary DESC) AS dense_rank + FROM employees; +``` + +```sql ++-------------+--------+------------+ +| department | salary | dense_rank | ++-------------+--------+------------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 2 | +| Sales | 30000 | 3 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+------------+ +``` + ### `ntile` Integer ranging from 1 to the argument value, dividing the partition as equally as possible @@ -205,6 +228,31 @@ ntile(expression) - **expression**: An integer describing the number groups the partition should be split into +#### Example + +```sql + --Example usage of the ntile window function: + SELECT employee_id, + salary, + ntile(4) OVER (ORDER BY salary DESC) AS quartile + FROM employees; +``` + +```sql ++-------------+--------+----------+ +| employee_id | salary | quartile | ++-------------+--------+----------+ +| 1 | 90000 | 1 | +| 2 | 85000 | 1 | +| 3 | 80000 | 2 | +| 4 | 70000 | 2 | +| 5 | 60000 | 3 | +| 6 | 50000 | 3 | +| 7 | 40000 | 4 | +| 8 | 30000 | 4 | ++-------------+--------+----------+ +``` + ### `percent_rank` Returns the percentage rank of the current row within its partition. The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`. @@ -213,6 +261,26 @@ Returns the percentage rank of the current row within its partition. The value r percent_rank() ``` +#### Example + +```sql + --Example usage of the percent_rank window function: + SELECT employee_id, + salary, + percent_rank() OVER (ORDER BY salary) AS percent_rank + FROM employees; +``` + +```sql ++-------------+--------+---------------+ +| employee_id | salary | percent_rank | ++-------------+--------+---------------+ +| 1 | 30000 | 0.00 | +| 2 | 50000 | 0.50 | +| 3 | 70000 | 1.00 | ++-------------+--------+---------------+ +``` + ### `rank` Returns the rank of the current row within its partition, allowing gaps between ranks. This function provides a ranking similar to `row_number`, but skips ranks for identical values. 
@@ -221,6 +289,29 @@ Returns the rank of the current row within its partition, allowing gaps between rank() ``` +#### Example + +```sql + --Example usage of the rank window function: + SELECT department, + salary, + rank() OVER (PARTITION BY department ORDER BY salary DESC) AS rank + FROM employees; +``` + +```sql ++-------------+--------+------+ +| department | salary | rank | ++-------------+--------+------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 2 | +| Sales | 30000 | 4 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+------+ +``` + ### `row_number` Number of the current row within its partition, counting from 1. @@ -229,6 +320,30 @@ Number of the current row within its partition, counting from 1. row_number() ``` +#### Example + +```sql + --Example usage of the row_number window function: + SELECT department, + salary, + row_number() OVER (PARTITION BY department ORDER BY salary DESC) AS row_num + FROM employees; +``` + +```sql ++-------------+--------+---------+ +| department | salary | row_num | ++-------------+--------+---------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 3 | +| Sales | 30000 | 4 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+---------+ +``` + + ## Analytical Functions - [first_value](#first_value) @@ -243,12 +358,35 @@ Returns value evaluated at the row that is the first row of the window frame. ```sql first_value(expression) ``` #### Arguments - **expression**: Expression to operate on +#### Example + +```sql + --Example usage of the first_value window function: + SELECT department, + employee_id, + salary, + first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary + FROM employees; +``` + +```sql ++-------------+-------------+--------+------------+ +| department | employee_id | salary | top_salary | ++-------------+-------------+--------+------------+ +| Sales | 1 | 70000 | 70000 | +| Sales | 2 | 50000 | 70000 | +| Sales | 3 | 30000 | 70000 | +| Engineering | 4 | 90000 | 90000 | +| Engineering | 5 | 80000 | 90000 | ++-------------+-------------+--------+------------+ +``` + ### `lag` Returns value evaluated at the row that is offset rows before the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). @@ -263,6 +401,27 @@ lag(expression, offset, default) - **offset**: Integer. Specifies how many rows back the value of expression should be retrieved. Defaults to 1. - **default**: The default value if the offset is not within the partition. Must be of the same type as expression. +#### Example + +```sql + --Example usage of the lag window function: + SELECT employee_id, + salary, + lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary + FROM employees; +``` + +```sql ++-------------+--------+-------------+ +| employee_id | salary | prev_salary | ++-------------+--------+-------------+ +| 1 | 30000 | 0 | +| 2 | 50000 | 30000 | +| 3 | 70000 | 50000 | +| 4 | 60000 | 70000 | ++-------------+--------+-------------+ +``` + ### `last_value` Returns value evaluated at the row that is the last row of the window frame.
@@ -275,6 +434,29 @@ last_value(expression) - **expression**: Expression to operate on +#### Example + +```sql +-- SQL example of last_value: +SELECT department, + employee_id, + salary, + last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary +FROM employees; +``` + +```sql ++-------------+-------------+--------+---------------------+ +| department | employee_id | salary | running_last_salary | ++-------------+-------------+--------+---------------------+ +| Sales | 1 | 30000 | 30000 | +| Sales | 2 | 50000 | 50000 | +| Sales | 3 | 70000 | 70000 | +| Engineering | 4 | 40000 | 40000 | +| Engineering | 5 | 60000 | 60000 | ++-------------+-------------+--------+---------------------+ +``` + ### `lead` Returns value evaluated at the row that is offset rows after the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). @@ -289,6 +471,30 @@ lead(expression, offset, default) - **offset**: Integer. Specifies how many rows forward the value of expression should be retrieved. Defaults to 1. - **default**: The default value if the offset is not within the partition. Must be of the same type as expression. +#### Example + +```sql +-- Example usage of lead() : +SELECT + employee_id, + department, + salary, + lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary +FROM employees; +``` + +```sql ++-------------+-------------+--------+--------------+ +| employee_id | department | salary | next_salary | ++-------------+-------------+--------+--------------+ +| 1 | Sales | 30000 | 50000 | +| 2 | Sales | 50000 | 70000 | +| 3 | Sales | 70000 | 0 | +| 4 | Engineering | 40000 | 60000 | +| 5 | Engineering | 60000 | 0 | ++-------------+-------------+--------+--------------+ +``` + ### `nth_value` Returns the value evaluated at the nth row of the window frame (counting from 1). Returns NULL if no such row exists. diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md deleted file mode 100644 index 521e29436212..000000000000 --- a/docs/source/user-guide/sql/write_options.md +++ /dev/null @@ -1,127 +0,0 @@ - - -# Write Options - -DataFusion supports customizing how data is written out to disk as a result of a `COPY` or `INSERT INTO` query. There are a few special options, file format (e.g. CSV or parquet) specific options, and parquet column specific options. Options can also in some cases be specified in multiple ways with a set order of precedence. - -## Specifying Options and Order of Precedence - -Write related options can be specified in the following ways: - -- Session level config defaults -- `CREATE EXTERNAL TABLE` options -- `COPY` option tuples - -For a list of supported session level config defaults see [Configuration Settings](../configs). These defaults apply to all write operations but have the lowest level of precedence. - -If inserting to an external table, table specific write options can be specified when the table is created using the `OPTIONS` clause: - -```sql -CREATE EXTERNAL TABLE - my_table(a bigint, b bigint) - STORED AS csv - COMPRESSION TYPE gzip - LOCATION '/test/location/my_csv_table/' - OPTIONS( - NULL_VALUE 'NAN', - 'has_header' 'true', - 'format.delimiter' ';' - ) -``` - -When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). 
There will be a single output file if the output path doesn't have folder format, i.e. ending with a `\`. Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. - -Finally, options can be passed when running a `COPY` command. - - - -```sql -COPY source_table - TO 'test/table_with_options' - PARTITIONED BY (column3, column4) - OPTIONS ( - format parquet, - compression snappy, - 'compression::column1' 'zstd(5)', - ) -``` - -In this example, we write the entirety of `source_table` out to a folder of parquet files. One parquet file will be written in parallel to the folder for each partition in the query. The next option `compression` set to `snappy` indicates that unless otherwise specified all columns should use the snappy compression codec. The option `compression::col1` sets an override, so that the column `col1` in the parquet file will use `ZSTD` compression codec with compression level `5`. In general, parquet options which support column specific settings can be specified with the syntax `OPTION::COLUMN.NESTED.PATH`. - -## Available Options - -### Execution Specific Options - -The following options are available when executing a `COPY` query. - -| Option | Description | Default Value | -| ----------------------------------- | ---------------------------------------------------------------------------------- | ------------- | -| execution.keep_partition_by_columns | Flag to retain the columns in the output data when using `PARTITIONED BY` queries. | false | - -Note: `execution.keep_partition_by_columns` flag can also be enabled through `ExecutionOptions` within `SessionConfig`. - -### JSON Format Specific Options - -The following options are available when writing JSON files. Note: If any unsupported option is specified, an error will be raised and the query will fail. - -| Option | Description | Default Value | -| ----------- | ---------------------------------------------------------------------------------------------------------------------------------- | ------------- | -| COMPRESSION | Sets the compression that should be applied to the entire JSON file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | - -### CSV Format Specific Options - -The following options are available when writing CSV files. Note: if any unsupported options is specified an error will be raised and the query will fail. - -| Option | Description | Default Value | -| --------------- | --------------------------------------------------------------------------------------------------------------------------------- | ---------------- | -| COMPRESSION | Sets the compression that should be applied to the entire CSV file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. 
| UNCOMPRESSED | -| HEADER | Sets if the CSV file should include column headers | false | -| DATE_FORMAT | Sets the format that dates should be encoded in within the CSV file | arrow-rs default | -| DATETIME_FORMAT | Sets the format that datetimes should be encoded in within the CSV file | arrow-rs default | -| TIME_FORMAT | Sets the format that times should be encoded in within the CSV file | arrow-rs default | -| RFC3339 | If true, uses RFC339 format for date and time encodings | arrow-rs default | -| NULL_VALUE | Sets the string which should be used to indicate null values within the CSV file. | arrow-rs default | -| DELIMITER | Sets the character which should be used as the column delimiter within the CSV file. | arrow-rs default | - -### Parquet Format Specific Options - -The following options are available when writing parquet files. If any unsupported option is specified an error will be raised and the query will fail. If a column specific option is specified for a column which does not exist, the option will be ignored without error. For default values, see: [Configuration Settings](https://datafusion.apache.org/user-guide/configs.html). - -| Option | Can be Column Specific? | Description | -| ---------------------------- | ----------------------- | ----------------------------------------------------------------------------------------------------------------------------------- | -| COMPRESSION | Yes | Sets the compression codec and if applicable compression level to use | -| MAX_ROW_GROUP_SIZE | No | Sets the maximum number of rows that can be encoded in a single row group. Larger row groups require more memory to write and read. | -| DATA_PAGESIZE_LIMIT | No | Sets the best effort maximum page size in bytes | -| WRITE_BATCH_SIZE | No | Maximum number of rows written for each column in a single batch | -| WRITER_VERSION | No | Parquet writer version (1.0 or 2.0) | -| DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size in bytes | -| CREATED_BY | No | Sets the "created by" property in the parquet file | -| COLUMN_INDEX_TRUNCATE_LENGTH | No | Sets the max length of min/max value fields in the column index. | -| DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in a data page. | -| BLOOM_FILTER_ENABLED | Yes | Sets whether a bloom filter should be written into the file. | -| ENCODING | Yes | Sets the encoding that should be used (e.g. PLAIN or RLE) | -| DICTIONARY_ENABLED | Yes | Sets if dictionary encoding is enabled. Use this instead of ENCODING to set dictionary encoding. | -| STATISTICS_ENABLED | Yes | Sets if statistics are enabled at PAGE or ROW_GROUP level. | -| MAX_STATISTICS_SIZE | Yes | Sets the maximum size in bytes that statistics can take up. | -| BLOOM_FILTER_FPP | Yes | Sets the false positive probability (fpp) for the bloom filter. Implicitly sets BLOOM_FILTER_ENABLED to true. | -| BLOOM_FILTER_NDV | Yes | Sets the number of distinct values (ndv) for the bloom filter. Implicitly sets bloom_filter_enabled to true. | diff --git a/parquet-testing b/parquet-testing index 6e851ddd768d..107b36603e05 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit 6e851ddd768d6af741c7b15dc594874399fc3cff +Subproject commit 107b36603e051aee26bd93e04b871034f6c756c0 diff --git a/rust-toolchain.toml b/rust-toolchain.toml index a85e6fa54299..c52dd7322d9a 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -19,5 +19,5 @@ # to compile this workspace and run CI jobs. 
[toolchain] -channel = "1.86.0" +channel = "1.87.0" components = ["rustfmt", "clippy"] diff --git a/test-utils/src/array_gen/binary.rs b/test-utils/src/array_gen/binary.rs index d342118fa85d..9740eeae5e7f 100644 --- a/test-utils/src/array_gen/binary.rs +++ b/test-utils/src/array_gen/binary.rs @@ -46,11 +46,11 @@ impl BinaryArrayGenerator { // Pick num_binaries randomly from the distinct binary table let indices: UInt32Array = (0..self.num_binaries) .map(|_| { - if self.rng.gen::<f64>() < self.null_pct { + if self.rng.random::<f64>() < self.null_pct { None } else if self.num_distinct_binaries > 1 { let range = 0..(self.num_distinct_binaries as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } @@ -68,11 +68,11 @@ impl BinaryArrayGenerator { let indices: UInt32Array = (0..self.num_binaries) .map(|_| { - if self.rng.gen::<f64>() < self.null_pct { + if self.rng.random::<f64>() < self.null_pct { None } else if self.num_distinct_binaries > 1 { let range = 0..(self.num_distinct_binaries as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } @@ -88,7 +88,7 @@ fn random_binary(rng: &mut StdRng, max_len: usize) -> Vec<u8> { if max_len == 0 { Vec::new() } else { - let len = rng.gen_range(1..=max_len); - (0..len).map(|_| rng.gen()).collect() + let len = rng.random_range(1..=max_len); + (0..len).map(|_| rng.random()).collect() } } diff --git a/test-utils/src/array_gen/boolean.rs b/test-utils/src/array_gen/boolean.rs index f3b83dd245f7..004d615b4caa 100644 --- a/test-utils/src/array_gen/boolean.rs +++ b/test-utils/src/array_gen/boolean.rs @@ -34,7 +34,7 @@ impl BooleanArrayGenerator { // Table of booleans from which to draw (distinct means 1 or 2) let distinct_booleans: BooleanArray = match self.num_distinct_booleans { 1 => { - let value = self.rng.gen::<bool>(); + let value = self.rng.random::<bool>(); let mut builder = BooleanBuilder::with_capacity(1); builder.append_value(value); builder.finish() } @@ -51,10 +51,10 @@ impl BooleanArrayGenerator { // Generate indices to select from the distinct booleans let indices: UInt32Array = (0..self.num_booleans) .map(|_| { - if self.rng.gen::<f64>() < self.null_pct { + if self.rng.random::<f64>() < self.null_pct { None } else if self.num_distinct_booleans > 1 { - Some(self.rng.gen_range(0..self.num_distinct_booleans as u32)) + Some(self.rng.random_range(0..self.num_distinct_booleans as u32)) } else { Some(0) } diff --git a/test-utils/src/array_gen/decimal.rs b/test-utils/src/array_gen/decimal.rs index d46ea9fe5457..c5ec8ac5e893 100644 --- a/test-utils/src/array_gen/decimal.rs +++ b/test-utils/src/array_gen/decimal.rs @@ -62,11 +62,11 @@ impl DecimalArrayGenerator { // pick num_decimals randomly from the distinct decimal table let indices: UInt32Array = (0..self.num_decimals) .map(|_| { - if self.rng.gen::<f64>() < self.null_pct { + if self.rng.random::<f64>() < self.null_pct { None } else if self.num_distinct_decimals > 1 { let range = 1..(self.num_distinct_decimals as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } diff --git a/test-utils/src/array_gen/primitive.rs b/test-utils/src/array_gen/primitive.rs index 58d39c14e65d..62a38a1b4ce1 100644 --- a/test-utils/src/array_gen/primitive.rs +++ b/test-utils/src/array_gen/primitive.rs @@ -18,7 +18,8 @@ use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, UInt32Array}; use arrow::datatypes::DataType; use chrono_tz::{Tz, TZ_VARIANTS}; -use rand::{rngs::StdRng, seq::SliceRandom, thread_rng, Rng}; +use
rand::prelude::IndexedRandom; +use rand::{rng, rngs::StdRng, Rng}; use std::sync::Arc; use super::random_data::RandomNativeData; @@ -66,6 +67,7 @@ impl PrimitiveArrayGenerator { | DataType::Time32(_) | DataType::Time64(_) | DataType::Interval(_) + | DataType::Duration(_) | DataType::Binary | DataType::LargeBinary | DataType::BinaryView @@ -81,11 +83,11 @@ impl PrimitiveArrayGenerator { // pick num_primitives randomly from the distinct string table let indices: UInt32Array = (0..self.num_primitives) .map(|_| { - if self.rng.gen::() < self.null_pct { + if self.rng.random::() < self.null_pct { None } else if self.num_distinct_primitives > 1 { let range = 1..(self.num_distinct_primitives as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } @@ -102,7 +104,7 @@ impl PrimitiveArrayGenerator { /// - `Some(Arc)` containing the timezone name. /// - `None` if no timezone is selected. fn generate_timezone() -> Option> { - let mut rng = thread_rng(); + let mut rng = rng(); // Allows for timezones + None let mut timezone_options: Vec> = vec![None]; diff --git a/test-utils/src/array_gen/random_data.rs b/test-utils/src/array_gen/random_data.rs index a7297d45fdf0..78518b7bf9dc 100644 --- a/test-utils/src/array_gen/random_data.rs +++ b/test-utils/src/array_gen/random_data.rs @@ -17,15 +17,16 @@ use arrow::array::ArrowPrimitiveType; use arrow::datatypes::{ - i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime, - IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, - IntervalYearMonthType, Time32MillisecondType, Time32SecondType, - Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, + i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, + DurationSecondType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, + Int8Type, IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, + IntervalMonthDayNanoType, IntervalYearMonthType, Time32MillisecondType, + Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use rand::distributions::Standard; +use rand::distr::StandardUniform; use rand::prelude::Distribution; use rand::rngs::StdRng; use rand::Rng; @@ -40,11 +41,11 @@ macro_rules! 
basic_random_data { ($ARROW_TYPE: ty) => { impl RandomNativeData for $ARROW_TYPE where - Standard: Distribution, + StandardUniform: Distribution, { #[inline] fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { - rng.gen::() + rng.random::() } } }; @@ -71,11 +72,16 @@ basic_random_data!(TimestampSecondType); basic_random_data!(TimestampMillisecondType); basic_random_data!(TimestampMicrosecondType); basic_random_data!(TimestampNanosecondType); +// Note DurationSecondType is restricted to i64::MIN / 1000 to i64::MAX / 1000 +// due to https://github.com/apache/arrow-rs/issues/7533 so handle it specially below +basic_random_data!(DurationMillisecondType); +basic_random_data!(DurationMicrosecondType); +basic_random_data!(DurationNanosecondType); impl RandomNativeData for Date64Type { fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { // TODO: constrain this range to valid dates if necessary - let date_value = rng.gen_range(i64::MIN..=i64::MAX); + let date_value = rng.random_range(i64::MIN..=i64::MAX); let millis_per_day = 86_400_000; date_value - (date_value % millis_per_day) } @@ -84,8 +90,8 @@ impl RandomNativeData for Date64Type { impl RandomNativeData for IntervalDayTimeType { fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { IntervalDayTime { - days: rng.gen::(), - milliseconds: rng.gen::(), + days: rng.random::(), + milliseconds: rng.random::(), } } } @@ -93,15 +99,24 @@ impl RandomNativeData for IntervalDayTimeType { impl RandomNativeData for IntervalMonthDayNanoType { fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { IntervalMonthDayNano { - months: rng.gen::(), - days: rng.gen::(), - nanoseconds: rng.gen::(), + months: rng.random::(), + days: rng.random::(), + nanoseconds: rng.random::(), } } } +// Restrict Duration(Seconds) to i64::MIN / 1000 to i64::MAX / 1000 to +// avoid panics on pretty printing. 
See +// https://github.com/apache/arrow-rs/issues/7533 +impl RandomNativeData for DurationSecondType { + fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { + rng.random::() / 1000 + } +} + impl RandomNativeData for Decimal256Type { fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { - i256::from_parts(rng.gen::(), rng.gen::()) + i256::from_parts(rng.random::(), rng.random::()) } } diff --git a/test-utils/src/array_gen/string.rs b/test-utils/src/array_gen/string.rs index ac659ae67bc0..546485fd8dc1 100644 --- a/test-utils/src/array_gen/string.rs +++ b/test-utils/src/array_gen/string.rs @@ -18,6 +18,7 @@ use arrow::array::{ ArrayRef, GenericStringArray, OffsetSizeTrait, StringViewArray, UInt32Array, }; +use rand::distr::StandardUniform; use rand::rngs::StdRng; use rand::Rng; @@ -47,11 +48,11 @@ impl StringArrayGenerator { // pick num_strings randomly from the distinct string table let indices: UInt32Array = (0..self.num_strings) .map(|_| { - if self.rng.gen::() < self.null_pct { + if self.rng.random::() < self.null_pct { None } else if self.num_distinct_strings > 1 { let range = 1..(self.num_distinct_strings as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } @@ -71,11 +72,11 @@ impl StringArrayGenerator { // pick num_strings randomly from the distinct string table let indices: UInt32Array = (0..self.num_strings) .map(|_| { - if self.rng.gen::() < self.null_pct { + if self.rng.random::() < self.null_pct { None } else if self.num_distinct_strings > 1 { let range = 1..(self.num_distinct_strings as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } @@ -92,10 +93,10 @@ fn random_string(rng: &mut StdRng, max_len: usize) -> String { // pick characters at random (not just ascii) match max_len { 0 => "".to_string(), - 1 => String::from(rng.gen::()), + 1 => String::from(rng.random::()), _ => { - let len = rng.gen_range(1..=max_len); - rng.sample_iter::(rand::distributions::Standard) + let len = rng.random_range(1..=max_len); + rng.sample_iter::(StandardUniform) .take(len) .collect() } diff --git a/test-utils/src/data_gen.rs b/test-utils/src/data_gen.rs index 7ac6f3d3e255..2228010b28dd 100644 --- a/test-utils/src/data_gen.rs +++ b/test-utils/src/data_gen.rs @@ -104,10 +104,11 @@ impl BatchBuilder { } fn append(&mut self, rng: &mut StdRng, host: &str, service: &str) { - let num_pods = rng.gen_range(self.options.pods_per_host.clone()); + let num_pods = rng.random_range(self.options.pods_per_host.clone()); let pods = generate_sorted_strings(rng, num_pods, 30..40); for pod in pods { - let num_containers = rng.gen_range(self.options.containers_per_pod.clone()); + let num_containers = + rng.random_range(self.options.containers_per_pod.clone()); for container_idx in 0..num_containers { let container = format!("{service}_container_{container_idx}"); let image = format!( @@ -115,7 +116,7 @@ impl BatchBuilder { ); let num_entries = - rng.gen_range(self.options.entries_per_container.clone()); + rng.random_range(self.options.entries_per_container.clone()); for i in 0..num_entries { if self.is_finished() { return; @@ -154,7 +155,7 @@ impl BatchBuilder { if self.options.include_nulls { // Append a null value if the option is set // Use both "NULL" as a string and a null value - if rng.gen_bool(0.5) { + if rng.random_bool(0.5) { self.client_addr.append_null(); } else { self.client_addr.append_value("NULL"); @@ -162,26 +163,26 @@ impl BatchBuilder { } else { self.client_addr.append_value(format!( 
"{}.{}.{}.{}", - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::() + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::() )); } - self.request_duration.append_value(rng.gen()); + self.request_duration.append_value(rng.random()); self.request_user_agent .append_value(random_string(rng, 20..100)); self.request_method - .append_value(methods[rng.gen_range(0..methods.len())]); + .append_value(methods[rng.random_range(0..methods.len())]); self.request_host .append_value(format!("https://{service}.mydomain.com")); self.request_bytes - .append_option(rng.gen_bool(0.9).then(|| rng.gen())); + .append_option(rng.random_bool(0.9).then(|| rng.random())); self.response_bytes - .append_option(rng.gen_bool(0.9).then(|| rng.gen())); + .append_option(rng.random_bool(0.9).then(|| rng.random())); self.response_status - .append_value(status[rng.gen_range(0..status.len())]); + .append_value(status[rng.random_range(0..status.len())]); self.prices_status.append_value(self.row_count as i128); } @@ -216,9 +217,9 @@ impl BatchBuilder { } fn random_string(rng: &mut StdRng, len_range: Range) -> String { - let len = rng.gen_range(len_range); + let len = rng.random_range(len_range); (0..len) - .map(|_| rng.gen_range(b'a'..=b'z') as char) + .map(|_| rng.random_range(b'a'..=b'z') as char) .collect::() } @@ -364,7 +365,7 @@ impl Iterator for AccessLogGenerator { self.host_idx += 1; for service in &["frontend", "backend", "database", "cache"] { - if self.rng.gen_bool(0.5) { + if self.rng.random_bool(0.5) { continue; } if builder.is_finished() { diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 47f23de4951e..be2bc0712afb 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -67,9 +67,9 @@ pub fn add_empty_batches( .flat_map(|batch| { // insert 0, or 1 empty batches before and after the current batch let empty_batch = RecordBatch::new_empty(schema.clone()); - std::iter::repeat_n(empty_batch.clone(), rng.gen_range(0..2)) + std::iter::repeat_n(empty_batch.clone(), rng.random_range(0..2)) .chain(std::iter::once(batch)) - .chain(std::iter::repeat_n(empty_batch, rng.gen_range(0..2))) + .chain(std::iter::repeat_n(empty_batch, rng.random_range(0..2))) }) .collect() } @@ -100,7 +100,7 @@ pub fn stagger_batch_with_seed(batch: RecordBatch, seed: u64) -> Vec 0 { - let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + let batch_size = rng.random_range(0..remainder.num_rows() + 1); batches.push(remainder.slice(0, batch_size)); remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); diff --git a/test-utils/src/string_gen.rs b/test-utils/src/string_gen.rs index b598241db1e9..75ed03898a27 100644 --- a/test-utils/src/string_gen.rs +++ b/test-utils/src/string_gen.rs @@ -19,7 +19,7 @@ use crate::array_gen::StringArrayGenerator; use crate::stagger_batch; use arrow::record_batch::RecordBatch; use rand::rngs::StdRng; -use rand::{thread_rng, Rng, SeedableRng}; +use rand::{rng, Rng, SeedableRng}; /// Randomly generate strings pub struct StringBatchGenerator(StringArrayGenerator); @@ -56,18 +56,18 @@ impl StringBatchGenerator { stagger_batch(batch) } - /// Return an set of `BatchGenerator`s that cover a range of interesting + /// Return a set of `BatchGenerator`s that cover a range of interesting /// cases pub fn interesting_cases() -> Vec { let mut cases = vec![]; - let mut rng = thread_rng(); + let mut rng = rng(); for null_pct in [0.0, 0.01, 0.1, 0.5] { for _ in 0..10 { // max length of generated strings - let max_len = rng.gen_range(1..50); - let num_strings = 
rng.gen_range(1..100); + let max_len = rng.random_range(1..50); + let num_strings = rng.random_range(1..100); let num_distinct_strings = if num_strings > 1 { - rng.gen_range(1..num_strings) + rng.random_range(1..num_strings) } else { num_strings }; @@ -76,7 +76,7 @@ impl StringBatchGenerator { num_strings, num_distinct_strings, null_pct, - rng: StdRng::from_seed(rng.gen()), + rng: StdRng::from_seed(rng.random()), })) } }
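For readers unfamiliar with the rand 0.9 API that these test-utils changes migrate to, the following is a minimal, self-contained sketch (not part of the patch; names such as `local_rng` are illustrative) of the renamed entry points used above: `rand::rng()` in place of `thread_rng()`, `random()` / `random_range()` / `random_bool()` in place of `gen()` / `gen_range()` / `gen_bool()`, `rand::distr::StandardUniform` in place of `rand::distributions::Standard`, and `IndexedRandom` in place of the old `SliceRandom` import for slice sampling. The `gen*` renames also sidestep `gen` becoming a reserved keyword in the Rust 2024 edition.

```rust
// A minimal sketch of the rand 0.9 API surface adopted above (assumes `rand = "0.9"`).
// This is illustrative only and not part of the patch.
use rand::distr::StandardUniform;
use rand::prelude::IndexedRandom;
use rand::{rng, rngs::StdRng, Rng, SeedableRng};

fn main() {
    // `rand::rng()` replaces `rand::thread_rng()`
    let mut local_rng = rng();

    // `random()` / `random_range()` / `random_bool()` replace `gen()` / `gen_range()` / `gen_bool()`
    let null_pct: f64 = local_rng.random();
    let len: usize = local_rng.random_range(1..=20);
    let include_nulls = local_rng.random_bool(0.5);

    // `rand::distr::StandardUniform` replaces `rand::distributions::Standard`
    let random_chars: String = (&mut local_rng)
        .sample_iter::<char, _>(StandardUniform)
        .take(len)
        .collect();

    // `IndexedRandom` (split out of the old `SliceRandom` trait) provides `choose` for slices
    let methods = ["GET", "POST", "PUT"];
    let method = methods.choose(&mut local_rng);

    // A reproducible StdRng seeded from the thread-local generator,
    // mirroring `StdRng::from_seed(rng.random())` in string_gen.rs
    let _seeded = StdRng::from_seed(local_rng.random());

    println!("{null_pct} {include_nulls} {random_chars} {method:?}");
}
```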